Mirror of https://github.com/ml-explore/mlx.git, synced 2025-09-12 23:34:36 +08:00

Compare commits (100 commits)
SHA1 list (author and date columns were empty in this mirror view):

bb303c45a5, 6f7986d592, 7cbb4aef17, 02bec0bb6d, c79f6a4a8c, 0c5eea226b, dcca0d7477, 0d5e7716ad,
d8c824c594, cb431dfc9f, 61d787726a, 5e89aace9b, 2af7e8a9a6, 2419edd5b2, bf481e8e5d, 9d7fa6b8e6,
073076ac7d, 9bd03dd9b4, 6931f84412, 16ec0556a0, 610af352d4, b35f1e3c9c, dfa0b9aab4, a4c47b0276,
111fefd5e9, c1fe1ef081, 8c34c9dac4, 91c0277356, 9f0d5c12fc, 59247c2b62, 9a3842a2d9, 726dbd9267,
54f05e7195, 26be608470, 248431eb3c, 76f275b4df, f1951d6cce, 62f297b51d, 09bc32f62f, 46d8b16ab4,
42533931fa, 9bd3a7102f, 9e516b71ea, eac961ddb1, 57c6aa7188, cde5b4ad80, 4f72c66911, 960e3f0f05,
884af42da2, 048fabdabd, 917252a5a1, 1a992e31e8, d2ff04a4f2, 015c247393, d3cd26820e, 91f6c499d7,
35e9c87ab9, 8e88e30d95, 0eb56d5be0, f70764a162, dad1b00b13, 430ffef58a, 3d17077187, c9b41d460f,
32972a5924, f6afb9c09b, 3ddc07e936, c26208f67d, d15fa13daf, 58a855682c, 92d7cb71f8, 50d8bed468,
9dd72cd421, 343aa46b78, b8ab89b413, f9f8c167d4, 3f86399922, 2b8ace6a03, 0ab8e099e8, 020f048cd0,
881615b072, 0eef4febfd, b54a70ec2d, bf6ec92216, c21331d47f, e1c9600da3, 1fa0d20a30, 3274c6a087,
9b12093739, f374b6ca4d, 0070e1db40, 95d04805b3, e4534dac17, fef3c4ec1d, 1bdc038bf9, 5523d9c426,
d878015228, 5900e3249f, bacced53d3, 4a64d4bff1
@@ -13,8 +13,62 @@ parameters:
  test_release:
    type: boolean
    default: false
  linux_release:
    type: boolean
    default: false

jobs:
  build_documentation:
    parameters:
      upload-docs:
        type: boolean
        default: false
    macos:
      xcode: "15.2.0"
    resource_class: macos.m1.medium.gen1
    steps:
      - checkout
      - run:
          name: Install
          command: |
            brew install python@3.9
            brew install doxygen
            python3.9 -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install -r docs/requirements.txt
            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
      - when:
          condition:
            not: << parameters.upload-docs >>
          steps:
            - run:
                name: Build documentation
                command: |
                  source env/bin/activate
                  cd docs && doxygen && make html O=-W
      - when:
          condition: << parameters.upload-docs >>
          steps:
            - add_ssh_keys:
                fingerprints:
                  - "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
            - run:
                name: Upload documentation
                command: |
                  source env/bin/activate
                  git config user.email "mlx@group.apple.com"
                  git config user.name "CircleCI Docs"
                  git checkout gh-pages
                  git rebase main
                  cd docs
                  git rm -rf build/html
                  doxygen && make html O=-W
                  git add -f build/html
                  git commit -m "rebase"
                  git push -f origin gh-pages

  linux_build_and_test:
    docker:
      - image: cimg/python:3.9

@@ -31,7 +85,7 @@ jobs:
          name: Install dependencies
          command: |
            pip install --upgrade cmake
            pip install nanobind==2.1.0
            pip install nanobind==2.2.0
            pip install numpy
            sudo apt-get update
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev

@@ -77,13 +131,13 @@ jobs:
      - run:
          name: Install dependencies
          command: |
            brew install python@3.8
            brew install python@3.9
            brew install openmpi
            python3.8 -m venv env
            python3.9 -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install nanobind==2.1.0
            pip install nanobind==2.2.0
            pip install numpy
            pip install torch
            pip install tensorflow

@@ -105,7 +159,7 @@ jobs:
            source env/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
            mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
      - run:
          name: Build example extension
          command: |

@@ -172,7 +226,7 @@ jobs:
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install nanobind==2.1.0
            pip install nanobind==2.2.0
            pip install --upgrade setuptools
            pip install numpy
            pip install twine

@@ -208,7 +262,7 @@ jobs:
      - store_artifacts:
          path: dist/

  build_linux_test_release:
  build_linux_release:
    parameters:
      python_version:
        type: string

@@ -237,12 +291,13 @@ jobs:
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install nanobind==2.1.0
            pip install nanobind==2.2.0
            pip install --upgrade setuptools
            pip install numpy
            pip install auditwheel
            pip install patchelf
            pip install build
            pip install twine
            << parameters.extra_env >> \
            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
            pip install . -v

@@ -253,6 +308,11 @@ jobs:
            python -m build --wheel
            auditwheel show dist/*
            auditwheel repair dist/* --plat manylinux_2_31_x86_64
      - run:
          name: Upload package
          command: |
            source env/bin/activate
            twine upload wheelhouse/*
      - store_artifacts:
          path: wheelhouse/

@@ -272,6 +332,7 @@ workflows:
          parameters:
            xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
      - linux_build_and_test
      - build_documentation

  build_pypi_release:
    when:

@@ -288,9 +349,17 @@ workflows:
              ignore: /.*/
          matrix:
            parameters:
              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              xcode_version: ["15.0.0", "15.2.0"]
              build_env: ["PYPI_RELEASE=1"]
      - build_documentation:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          upload-docs: true

  prb:
    when:
      matches:

@@ -317,7 +386,7 @@ workflows:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              xcode_version: ["15.0.0", "15.2.0"]
  weekly_build:
    when:

@@ -328,17 +397,17 @@ workflows:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
              build_env: ["DEV_RELEASE=1"]
  linux_test_release:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.test_release >>
        - << pipeline.parameters.linux_release >>
    jobs:
      - build_linux_test_release:
      - build_linux_release:
          matrix:
            parameters:
              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              extra_env: ["PYPI_RELEASE=1"]
@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

if(NOT MLX_VERSION)
  set(MLX_VERSION 0.18.0)
  set(MLX_VERSION 0.21.0)
endif()

# --------------------- Processor tests -------------------------

@@ -47,17 +47,20 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
        "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
    )
  else()
    set(MLX_BUILD_METAL OFF)
    message(WARNING "Building for x86_64 arch is not officially supported.")
  endif()
    set(MLX_BUILD_METAL OFF)
  elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
    set(MLX_BUILD_ARM ON)
  endif()

else()
  set(MLX_BUILD_METAL OFF)
  message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()

if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
  set(MLX_BUILD_ARM ON)
endif()

# ----------------------------- Lib -----------------------------

include(FetchContent)

@@ -86,25 +89,27 @@ elseif(MLX_BUILD_METAL)
  # Throw an error if xcrun not found
  execute_process(
    COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
    OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
    OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)

  if(${MACOS_VERSION} LESS 14.0)
  if(${MACOS_SDK_VERSION} LESS 14.0)
    message(
      FATAL_ERROR
        "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
  endif()
  message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
  message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")

  set(METAL_CPP_URL
      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
  )
  # Get the metal version

  if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
    set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
  endif()
  execute_process(
    COMMAND
      zsh "-c"
      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
    OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)

  FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})

  FetchContent_MakeAvailable(metal_cpp)

@@ -112,8 +117,6 @@ elseif(MLX_BUILD_METAL)
    mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
               $<INSTALL_INTERFACE:include/metal_cpp>)
  target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})

  add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
endif()

if(MLX_BUILD_CPU)
@@ -6,7 +6,7 @@

[](https://circleci.com/gh/ml-explore/mlx)

MLX is an array framework for machine learning research on Apple silicon,
MLX is an array framework for machine learning on Apple silicon,
brought to you by Apple machine learning research.

Some key features of MLX include:
@@ -144,6 +144,13 @@ def reduction(op, axis, x):
    mx.eval(ys)


def sum_and_add(axis, x, y):
    z = x.sum(axis=axis, keepdims=True)
    for i in range(50):
        z = (z + y).sum(axis=axis, keepdims=True)
    mx.eval(z)


def softmax(axis, x):
    ys = []
    for i in range(100):

@@ -505,5 +512,8 @@ if __name__ == "__main__":
    elif args.benchmark == "selu":
        print(bench(selu, x))

    elif args.benchmark == "sum_and_add":
        print(bench(sum_and_add, axis, *xs))

    else:
        raise ValueError("Unknown benchmark")
@@ -9,7 +9,7 @@ from time_utils import measure_runtime

def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
    def scatter(dst, x, idx):
        dst[*idx] = x
        dst[tuple(idx)] = x
        mx.eval(dst)

    idx = []

@@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):


def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
    def gather(dst, x, idx, device):
        dst[*idx] = x
    def scatter(dst, x, idx, device):
        dst[tuple(idx)] = x
        if device == torch.device("mps"):
            torch.mps.synchronize()

@@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
    x = torch.randn(x_shape, dtype=torch.float32).to(device)
    dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

    runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
    print(f"PyTorch: {runtime:.3f}ms")


@@ -54,7 +54,7 @@ if __name__ == "__main__":
        (100_000, 64),
        (1_000_000, 64),
        (100_000,),
        (2_000_00,),
        (200_000,),
        (20_000_000,),
        (10000, 64),
        (100, 64),

@@ -91,6 +91,6 @@ if __name__ == "__main__":

    for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
        print("=" * 20)
        print(f"X {x_shape}, Indices {idx_shape}")
        print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
        benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
        benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
@@ -1,62 +1,189 @@
# Copyright © 2024 Apple Inc.

import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
from time_utils import time_fn
import numpy as np

MAX_SEQ = 300
START_SEQ = 100
SEQ_INCREMENT = 50
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 5
N_iter_bench = 40
N_iter_func = 8


def time_self_attention_primitives():
    mx.random.seed(3)
    B = 2
    H = 38
    D = 64
    for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
        q = mx.random.uniform(shape=(B, H, R, D))
        k = mx.random.uniform(shape=(B, H, R, D))
        v = mx.random.uniform(shape=(B, H, R, D))
        scale = 1.0 / math.sqrt(float(D))
        mx.eval(q, k, v)
def bench(f, *args):
    for i in range(N_warmup):
        f(*args)

        def sdpa_primitives(qs, ks, vs, alpha):
            s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
            p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
            o = p @ vs
            return o

        time_fn(sdpa_primitives, q, k, v, scale)
    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(*args)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def time_self_attention_sdpa():
    mx.random.seed(3)
    B = 2
    H = 38
    D = 64
    for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
        q = mx.random.uniform(shape=(B, H, R, D))
        k = mx.random.uniform(shape=(B, H, R, D))
        v = mx.random.uniform(shape=(B, H, R, D))
        scale = 1.0 / math.sqrt(float(D))
        mx.eval(q, k, v)
def mlx_sdpa_fused_inner(q, k, v, scale):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)

        def sdpa_fused(qs, ks, vs, alpha):
            o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
            return o

        time_fn(sdpa_fused, q, k, v, scale)
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
    q_dtype = q.dtype
    q = q * mx.array(scale, q_dtype)
    n_q_heads = q.shape[-3]
    n_kv_heads = k.shape[-3]
    n_repeats = n_q_heads // n_kv_heads

    B = q.shape[0]
    L = q.shape[2]

    if n_repeats > 1:
        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
        k = mx.expand_dims(k, 2)
        v = mx.expand_dims(v, 2)

    scores = q @ mx.swapaxes(k, -1, -2)
    if f32softmax:
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
    else:
        scores = mx.softmax(scores, axis=-1)

    out = scores @ v
    if n_repeats > 1:
        out = mx.reshape(out, [B, n_q_heads, L, -1])

    return out


def mlx_spda_unfused(q, k, v, scale, transpose):
    q_out = q
    if transpose:
        k = mx.transpose(k, (0, 2, 1, 3))
        v = mx.transpose(v, (0, 2, 1, 3))

    for i in range(N_iter_func):
        if transpose:
            q_out = mx.transpose(q_out, (0, 2, 1, 3))
        q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
        if transpose:
            q_out = mx.transpose(q_out, (0, 2, 1, 3))

    mx.eval(q_out)
    return q_out


def mlx_spda_fused(q, k, v, scale, transpose):
    q_out = q
    if transpose:
        k = mx.transpose(k, (0, 2, 1, 3))
        v = mx.transpose(v, (0, 2, 1, 3))

    for i in range(N_iter_func):
        if transpose:
            q_out = mx.transpose(q_out, (0, 2, 1, 3))
        q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
        if transpose:
            q_out = mx.transpose(q_out, (0, 2, 1, 3))

    mx.eval(q_out)
    return q_out


def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
    shape_q = (
        (B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
    )
    shape_kv = (
        (B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
    )

    q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
    k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
    v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)

    scale = math.sqrt(1.0 / head_dim)

    q_mx = mx.array(q_np)
    k_mx = mx.array(k_np)
    v_mx = mx.array(v_np)

    time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
    time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)

    if transpose:
        q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
        k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
        v_mx = mx.transpose(v_mx, (0, 2, 1, 3))

    o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
    o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)

    atol = 1e-5 if np_dtype == np.float32 else 1e-4

    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
        print(
            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
        )

    return time_mlx_fused, time_mlx_unfused


def get_gflop_count(B, M, N, K):
    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("MLX benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()
    if args.gpu:
        mx.set_default_device(mx.gpu)
    else:
        mx.set_default_device(mx.cpu)
    parser = argparse.ArgumentParser(description="Run gemm benchmarks")

    time_self_attention_sdpa()
    time_self_attention_primitives()
    dtypes = ("float16", "float32")[:1]
    transposes = (False,)

    # fmt: off
    shapes_64 = (
        # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
        ( 1, 32, 32, 64, 32, 32),
        ( 1, 64, 64, 64, 32, 32),
        ( 1, 128, 128, 64, 32, 32),
        ( 1, 256, 256, 64, 32, 32),
        ( 1, 512, 512, 64, 32, 32),
        ( 1, 1024, 1024, 64, 32, 32),
        ( 1, 2048, 2048, 64, 32, 32),
        ( 1, 4096, 4096, 64, 32, 32),
    )

    shapes_80 = (
        # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
        ( 1, 1024, 1024, 80, 32, 32),
        ( 1, 2048, 2048, 80, 32, 32),
        ( 1, 4096, 4096, 80, 32, 32),
    )

    shapes_128 = (
        # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
        ( 1, 1024, 1024, 128, 32, 32),
        ( 1, 2048, 2048, 128, 32, 32),
        ( 1, 4096, 4096, 128, 32, 32),
    )
    # fmt: on

    shapes = shapes_64 + shapes_80 + shapes_128

    print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")

    for dtype in dtypes:
        for transpose in transposes:
            for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
                np_dtype = getattr(np, dtype)
                time_mlx_fused, time_mlx_unfused = bench_shape(
                    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
                )
                diff = time_mlx_unfused / time_mlx_fused - 1.0
                t_str = 1 if transpose else 0
                print(
                    f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
                )
benchmarks/python/sdpa_vector_bench.py (new file, 58 lines)

@@ -0,0 +1,58 @@
import argparse
import math

import mlx.core as mx
from time_utils import time_fn

L = 16384
H = 32
H_k = H // 4
D = 128
dtype = mx.float16
loops = 10


def attention(q, k, v):
    def _sdpa(q, k, v):
        B, Hq, L, D = q.shape
        _, Hk, S, _ = k.shape
        q = q.reshape(B, Hk, Hq // Hk, L, D)
        k = k[:, :, None, :, :]
        v = v[:, :, None, :, :]
        s = q @ k.transpose(0, 1, 2, 4, 3)
        p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
        o = p @ v
        return o.reshape(B, Hq, L, D)

    for i in range(loops):
        q = _sdpa(q, k, v)
    return q


def sdpa(q, k, v):
    for i in range(loops):
        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
    return q


def time_self_attention_primitives():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    mx.eval(q, k, v)
    time_fn(attention, q, k, v)


def time_self_attention_sdpa():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    mx.eval(q, k, v)
    time_fn(sdpa, q, k, v)


if __name__ == "__main__":
    time_self_attention_sdpa()
    time_self_attention_primitives()
@@ -60,6 +60,7 @@ html_theme_options = {
    },
}

html_favicon = html_theme_options["logo"]["image_light"]

# -- Options for HTMLHelp output ---------------------------------------------
@@ -1,3 +1,5 @@
.. _custom_metal_kernels:

Custom Metal Kernels
====================

@@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as

  template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.

Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
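For context, a call to such a kernel might look like the minimal sketch below. It assumes the ``myexp`` example from this doc page; the kernel body, template value, and sizes here are illustrative assumptions rather than part of this diff, and the keyword arguments may differ slightly between MLX versions.

.. code-block:: python

   import mlx.core as mx

   # Hypothetical element-wise exp kernel; only the body is supplied, MLX
   # generates the full function signature shown above.
   source = """
       uint elem = thread_position_in_grid.x;
       out[elem] = metal::exp(inp[elem]);
   """
   kernel = mx.fast.metal_kernel(
       name="myexp",
       input_names=["inp"],
       output_names=["out"],
       source=source,
   )

   a = mx.random.normal(shape=(4096,))
   out = kernel(
       inputs=[a],
       template=[("T", mx.float32)],
       grid=(a.size, 1, 1),        # mx.prod(grid) threads are launched
       threadgroup=(256, 1, 1),    # grouped into threadgroups of this size
       output_shapes=[a.shape],
       output_dtypes=[a.dtype],
       verbose=True,               # print the generated Metal code
   )[0]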
Using Shape/Strides
@@ -494,7 +494,7 @@ below.

  // Prepare to encode kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Kernel parameters are registered with buffer indices corresponding to
  // those in the kernel declaration at axpby.metal

@@ -509,14 +509,14 @@ below.
  compute_encoder.set_output_array(out, 2);

  // Encode alpha and beta
  compute_encoder->setBytes(&alpha_, sizeof(float), 3);
  compute_encoder->setBytes(&beta_, sizeof(float), 4);
  compute_encoder.set_bytes(alpha_, 3);
  compute_encoder.set_bytes(beta_, 4);

  // Encode shape, strides and ndim
  compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
  compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
  compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
  compute_encoder->setBytes(&ndim, sizeof(int), 8);
  compute_encoder.set_vector_bytes(x.shape(), 5);
  compute_encoder.set_vector_bytes(x.strides(), 6);
  compute_encoder.set_bytes(y.strides(), 7);
  compute_encoder.set_bytes(ndim, 8);

  // We launch 1 thread for each input and make sure that the number of
  // threads in any given threadgroup is not higher than the max allowed

@@ -530,7 +530,7 @@ below.

  // Launch the grid with the given number of threads divided among
  // the given threadgroups
  compute_encoder.dispatchThreads(grid_dims, group_dims);
  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

We can now call the :meth:`axpby` operation on both the CPU and the GPU!
@@ -14,7 +14,7 @@ silicon computer is
To install from PyPI you must meet the following requirements:

- Using an M series chip (Apple silicon)
- Using a native Python >= 3.8
- Using a native Python >= 3.9
- macOS >= 13.5

.. note::

@@ -209,7 +209,7 @@ Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists accross reboots.
Metal kernel cache persists across reboots.

Troubleshooting
^^^^^^^^^^^^^^^

@@ -240,7 +240,7 @@ x86 Shell

.. _build shell:

If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.

To fix this, find the application in Finder (``/Applications`` for iTerm,

@@ -264,4 +264,4 @@ Also check that cmake is using the correct architecture:

If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cahce with ``rm -rf build/`` and try again.
wipe your build cache with ``rm -rf build/`` and try again.
@@ -12,5 +12,4 @@ Fast
   layer_norm
   rope
   scaled_dot_product_attention
   affine_quantize
   metal_kernel
@@ -16,3 +16,5 @@ Linear Algebra
   cross
   qr
   svd
   eigvalsh
   eigh
@@ -14,6 +14,7 @@ Metal
   get_cache_memory
   set_memory_limit
   set_cache_limit
   set_wired_limit
   clear_cache
   start_capture
   stop_capture
@@ -12,6 +12,7 @@ Layers
   ALiBi
   AvgPool1d
   AvgPool2d
   AvgPool3d
   BatchNorm
   CELU
   Conv1d

@@ -41,6 +42,7 @@ Layers
   LSTM
   MaxPool1d
   MaxPool2d
   MaxPool3d
   Mish
   MultiHeadAttention
   PReLU
@@ -80,6 +80,7 @@ Operations
   greater_equal
   hadamard_transform
   identity
   imag
   inner
   isfinite
   isclose

@@ -125,11 +126,13 @@ Operations
   quantize
   quantized_matmul
   radians
   real
   reciprocal
   remainder
   repeat
   reshape
   right_shift
   roll
   round
   rsqrt
   save
@@ -45,3 +45,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
   truncated_normal
   uniform
   laplace
   permutation
@@ -33,12 +33,12 @@ Let's start with a simple example:
   # Compile the function
   compiled_fun = mx.compile(fun)

   # Prints: array(2.36788, dtype=float32)
   print(compiled_fun(x, y))

The output of both the regular function and the compiled function is the same
up to numerical precision.

The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled

@@ -96,7 +96,7 @@ element-wise operations:

.. code-block:: python

   def gelu(x):
       return x * (1 + mx.erf(x / math.sqrt(2))) / 2

If you use this function with small arrays, it will be overhead bound. If you

@@ -136,13 +136,6 @@ Now make an array, and benchmark both functions:

On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
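Those numbers come from timing the uncompiled and compiled versions of ``gelu``; a minimal sketch of such a comparison is below. The array shape and iteration count are assumptions for illustration, not values taken from this diff.

.. code-block:: python

   import math
   import timeit

   import mlx.core as mx

   def gelu(x):
       return x * (1 + mx.erf(x / math.sqrt(2))) / 2

   compiled_gelu = mx.compile(gelu)

   x = mx.random.uniform(shape=(32, 1000, 4096))
   mx.eval(x)

   # mx.eval forces the lazy graph to run so the timing covers the actual compute.
   print(timeit.timeit(lambda: mx.eval(gelu(x)), number=10))
   print(timeit.timeit(lambda: mx.eval(compiled_gelu(x)), number=10))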
.. note::

   As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
   functions can still be helpful, but won't typically result in as large a
   speedup as compiling operations that run on the GPU.


Debugging
---------

@@ -287,7 +280,7 @@ to the function. In some cases this can be pretty inconvenient. Hence,
   print(fun(mx.array(1.0)))


Compiling Training Graphs
-------------------------

This section will step through how to use :func:`compile` with a simple example

@@ -297,7 +290,7 @@ full forward, backward, and update with :func:`compile`.

To start, here is the simple example without any compilation:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

@@ -330,7 +323,7 @@ To start, here is the simple example without any compilation:
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

@@ -355,7 +348,7 @@ appropriate input and output captures. Here's the same example but compiled:

   # The state that will be captured as input and output
   state = [model.state, optimizer.state]

   @partial(mx.compile, inputs=state, outputs=state)
   def step(x, y):
       loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

@@ -410,7 +403,7 @@ Compiling transformed functions works just as expected:

In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.

You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`
@@ -25,7 +25,7 @@ Here is a simple example:

The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:

.. code-block:: shell

@@ -50,7 +50,7 @@ Automatic Differentiation

.. _auto diff:

Automatic differentiation in MLX works on functions rather than on implicit
graphs.

.. note::

@@ -114,7 +114,7 @@ way to do that is the following:

   def loss_fn(params, x, y):
       w, b = params["weight"], params["bias"]
       h = w * x + b
       return mx.mean(mx.square(h - y))

   params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}

@@ -132,7 +132,7 @@ way to do that is the following:

Notice the tree structure of the parameters is preserved in the gradients.

In some cases you may want to stop gradients from propagating through a
part of the function. You can use the :func:`stop_gradient` for that.


@@ -161,19 +161,19 @@ A naive way to add the elements from two sets of vectors is with a loop:
   ys = mx.random.uniform(shape=(100, 4096))

   def naive_add(xs, ys):
       return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
       return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

Instead you can use :func:`vmap` to automatically vectorize the addition:

.. code-block:: python

   # Vectorize over the second dimension of x and the
   # first dimension of y
   vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
   vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.

Let's time these two different versions:

@@ -184,8 +184,8 @@ Let's time these two different versions:
   print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
   print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))

On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster.
On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
vectorized version takes only ``0.024`` seconds, more than 200 times faster.

Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
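For reference, the corrected pieces assemble into a runnable comparison roughly like the sketch below. The shape of ``xs`` is an assumption inferred from the surrounding docs, since this hunk only shows ``ys``.

.. code-block:: python

   import timeit

   import mlx.core as mx

   xs = mx.random.uniform(shape=(4096, 100))
   ys = mx.random.uniform(shape=(100, 4096))

   def naive_add(xs, ys):
       return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

   # Vectorize over the first axis of xs and the second axis of ys
   vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

   print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
   print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))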
@@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`:

.. code-block:: shell

   >>> arr = mx.arange(10)
   >>> idx = mx.array([5, 7])
   >>> arr[idx]
   array([5, 7], dtype=int32)

@@ -77,12 +77,12 @@ from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.

Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which outputs
general, MLX has limited support for operations for which output
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.

In Place Updates
----------------

In place updates to indexed arrays are possible in MLX. For example:
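The doc's own example falls outside this hunk; as a stand-in illustration (assumed here, not part of the diff), an in-place update looks like:

.. code-block:: python

   a = mx.array([1, 2, 3])
   a[2] = 0      # update one element in place
   # a is now array([1, 2, 0], dtype=int32)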
@@ -13,7 +13,7 @@ compute graph is recorded. The actual computation only happens if an
:func:`eval` is performed.

MLX uses lazy evaluation because it has some nice features, some of which we
describe below.

Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -109,14 +109,14 @@ Here is a concrete example:

An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.


Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.
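A small, self-contained illustration of this implicit-evaluation behavior (a sketch, not part of the diff):

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   x = mx.array([1.0, 2.0, 3.0])
   y = mx.square(x)   # only the compute graph is recorded here

   print(y)           # printing implicitly evaluates the graph
   np.array(y)        # converting to NumPy evaluates it as well
   mx.eval(y)         # or evaluate explicitly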
@@ -3,10 +3,10 @@
Conversion to NumPy and Other Frameworks
========================================

MLX array supports conversion between other frameworks with either:

* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.

Let's convert an array to NumPy and back.

@@ -66,7 +66,7 @@ even though no in-place operations on MLX memory are executed.
PyTorch
-------

.. warning::

   PyTorch Support for :obj:`memoryview` is experimental and can break for
   multi-dimensional arrays. Casting to NumPy first is advised for now.
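As a concrete sketch of both points above — the NumPy round trip and going through NumPy for PyTorch as the warning suggests (an assumed illustration, not part of the diff):

.. code-block:: python

   import mlx.core as mx
   import numpy as np
   import torch

   a = mx.arange(3.0)
   b = np.array(a)        # MLX -> NumPy via the buffer protocol
   c = mx.array(b)        # and back to an MLX array

   t = torch.tensor(np.array(a))   # for PyTorch, cast to NumPy first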
@@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products
and :func:`jvp` for Jacobian-vector products.

Use :func:`value_and_grad` to efficiently compute both a function's output and
gradient with respect to the function's input.
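A minimal sketch of that pattern (an assumed example, not from the diff):

.. code-block:: python

   import mlx.core as mx

   def loss_fn(w, x):
       return mx.mean(mx.square(w * x))

   w = mx.array(2.0)
   x = mx.array([1.0, 2.0, 3.0])

   # One call returns both the loss value and dloss/dw
   loss, dloss_dw = mx.value_and_grad(loss_fn)(w, x)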
@@ -8,33 +8,33 @@ Saving and Loading Arrays
MLX supports multiple array serialization formats.

.. list-table:: Serialization Formats
   :widths: 20 8 25 25
   :header-rows: 1

   * - Format
     - Extension
     - Function
     - Notes
   * - NumPy
     - ``.npy``
     - :func:`save`
     - Single arrays only
   * - NumPy archive
     - ``.npz``
     - :func:`savez` and :func:`savez_compressed`
     - Multiple arrays
   * - Safetensors
     - ``.safetensors``
     - :func:`save_safetensors`
     - Multiple arrays
   * - GGUF
     - ``.gguf``
     - :func:`save_gguf`
     - Multiple arrays

The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extensions. The output of
:func:`load` depends on the format.

Here's an example of saving a single array to a file:
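The example itself falls outside this hunk; a stand-in sketch of saving and loading a single array (assumed, not from the diff):

.. code-block:: python

   a = mx.array([1.0, 2.0])
   mx.save("array", a)        # writes array.npy
   b = mx.load("array.npy")   # load determines the format from the extension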
@@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory.

In MLX, rather than moving arrays to devices, you specify the device when you
run the operation. Any device can perform any operation on ``a`` and ``b``
without needing to move them from one memory location to another. For example:

.. code-block:: python
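The code block itself is outside this hunk; a sketch of what such a call looks like, assuming the ``stream`` argument that MLX operations accept:

.. code-block:: python

   a = mx.random.normal((100,))
   b = mx.random.normal((100,))

   mx.add(a, b, stream=mx.cpu)   # run the op on the CPU
   mx.add(a, b, stream=mx.gpu)   # run the same op on the GPU; no copies needed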
@@ -257,7 +257,7 @@ void Axpby::eval_gpu(

  // Prepare to encode kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Kernel parameters are registered with buffer indices corresponding to
  // those in the kernel declaration at axpby.metal

@@ -272,15 +272,15 @@ void Axpby::eval_gpu(
  compute_encoder.set_output_array(out, 2);

  // Encode alpha and beta
  compute_encoder->setBytes(&alpha_, sizeof(float), 3);
  compute_encoder->setBytes(&beta_, sizeof(float), 4);
  compute_encoder.set_bytes(alpha_, 3);
  compute_encoder.set_bytes(beta_, 4);

  // Encode shape, strides and ndim if needed
  if (!contiguous_kernel) {
    compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
    compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
    compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
    compute_encoder->setBytes(&ndim, sizeof(int), 8);
    compute_encoder.set_vector_bytes(x.shape(), 5);
    compute_encoder.set_vector_bytes(x.strides(), 6);
    compute_encoder.set_bytes(y.strides(), 7);
    compute_encoder.set_bytes(ndim, 8);
  }

  // We launch 1 thread for each input and make sure that the number of

@@ -295,7 +295,7 @@ void Axpby::eval_gpu(

  // Launch the grid with the given number of threads divided among
  // the given threadgroups
  compute_encoder.dispatchThreads(grid_dims, group_dims);
  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

#else // Metal is not available
@@ -2,7 +2,6 @@

#include <metal_stdlib>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

template <typename T>

@@ -60,4 +59,4 @@ template <typename T>
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
@@ -2,7 +2,7 @@
requires = [
    "setuptools>=42",
    "cmake>=3.24",
    "mlx>=0.17.0",
    "nanobind==2.1.0",
    "mlx>=0.18.0",
    "nanobind==2.2.0",
]
build-backend = "setuptools.build_meta"
@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.17.0
nanobind==2.1.0
mlx>=0.21.0
nanobind==2.2.0
mlx.pc.in (15 changed lines)

@@ -28,10 +28,19 @@ endif()
if (@MLX_BUILD_METAL@)
  set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
  set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
  set_and_check(MLX_INCLUDE_DIRS
    ${MLX_INCLUDE_DIRS}
  set(MLX_INCLUDE_DIRS
    "${MLX_INCLUDE_DIRS};"
    @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
  )
  if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
    set(MLX_INCLUDE_DIRS
      "${MLX_INCLUDE_DIRS};"
      @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
  else()
    set(MLX_INCLUDE_DIRS
      "${MLX_INCLUDE_DIRS};"
      @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
  endif()
endif()

set_target_properties(mlx PROPERTIES

@@ -40,4 +49,4 @@ set_target_properties(mlx PROPERTIES
)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
@@ -19,7 +19,7 @@ Buffer malloc(size_t size) {
}

void free(Buffer buffer) {
  return allocator().free(buffer);
  allocator().free(buffer);
}

Buffer CommonAllocator::malloc(size_t size, bool) {
@@ -95,13 +95,29 @@ void array::detach() {
  array_desc_->primitive = nullptr;
}

void array::eval() {
  // Ensure the array is ready to be read
  if (status() == Status::scheduled) {
bool array::is_available() const {
  if (status() == Status::available) {
    return true;
  } else if (status() == Status::evaluated && event().is_signaled()) {
    set_status(Status::available);
    return true;
  }
  return false;
}

void array::wait() {
  if (!is_available()) {
    event().wait();
    set_status(Status::available);
  } else if (status() == Status::unscheduled) {
  }
}

void array::eval() {
  // Ensure the array is ready to be read
  if (status() == Status::unscheduled) {
    mlx::core::eval({*this});
  } else {
    wait();
  }
}

@@ -162,8 +178,10 @@ void array::move_shared_buffer(
  array_desc_->flags = flags;
  array_desc_->data_size = data_size;
  auto char_offset = sizeof(char) * itemsize() * offset;
  array_desc_->data_ptr = static_cast<void*>(
      static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
  auto data_ptr = other.array_desc_->data_ptr;
  other.array_desc_->data_ptr = nullptr;
  array_desc_->data_ptr =
      static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
}

void array::move_shared_buffer(array other) {

@@ -196,6 +214,8 @@ array::~array() {
    if (do_detach) {
      for (auto& s : siblings()) {
        for (auto& ss : s.siblings()) {
          // Set to null here to avoid descending into array destructor
          // for siblings
          ss.array_desc_ = nullptr;
        }
        s.array_desc_->siblings.clear();

@@ -242,25 +262,46 @@ array::ArrayDesc::~ArrayDesc() {
  // This calls recursively the destructor and can result in stack overflow, we
  // instead put them in a vector and destroy them one at a time resulting in a
  // max stack depth of 2.
  if (inputs.empty()) {
    return;
  }

  std::vector<std::shared_ptr<ArrayDesc>> for_deletion;

  for (array& a : inputs) {
    if (a.array_desc_.use_count() == 1) {
      for_deletion.push_back(std::move(a.array_desc_));
  auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {
    std::unordered_map<std::uintptr_t, array> input_map;
    for (array& a : ad.inputs) {
      if (a.array_desc_) {
        input_map.insert({a.id(), a});
        for (auto& s : a.siblings()) {
          input_map.insert({s.id(), s});
        }
      }
    }
  }
    ad.inputs.clear();
    for (auto& [_, a] : input_map) {
      if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
        for_deletion.push_back(std::move(a.array_desc_));
      }
    }
  };

  append_deletable_inputs(*this);

  while (!for_deletion.empty()) {
    // top is going to be deleted at the end of the block *after* the arrays
    // with inputs have been moved into the vector
    auto top = std::move(for_deletion.back());
    for_deletion.pop_back();
    append_deletable_inputs(*top);

    for (array& a : top->inputs) {
      if (a.array_desc_.use_count() == 1) {
        for_deletion.push_back(std::move(a.array_desc_));
      }
    // Clear out possible siblings to break circular references
    for (auto& s : top->siblings) {
      // Set to null here to avoid descending into top-level
      // array destructor for siblings
      s.array_desc_ = nullptr;
    }
    top->siblings.clear();
  }
}
mlx/array.h (30 changed lines)

@@ -344,11 +344,33 @@ class array {
    return static_cast<T*>(array_desc_->data_ptr);
  }

  enum Status { unscheduled, scheduled, available };
  enum Status {
    // The ouptut of a computation which has not been scheduled.
    // For example, the status of `x` in `auto x = a + b`.
    unscheduled,

  bool is_available() const {
    return status() == Status::available;
  }
    // The ouptut of a computation which has been scheduled but `eval_*` has
    // not yet been called on the array's primitive. A possible
    // status of `x` in `auto x = a + b; eval(x);`
    scheduled,

    // The array's `eval_*` function has been run, but the computation is not
    // necessarily complete. The array will have memory allocated and if it is
    // not a tracer then it will be detached from the graph.
    evaluated,

    // If the array is the output of a computation then the computation
    // is complete. Constant arrays are always available (e.g. `array({1, 2,
    // 3})`)
    available
  };

  // Check if the array is safe to read.
  bool is_available() const;

  // Wait on the array to be available. After this `is_available` returns
  // `true`.
  void wait();

  Status status() const {
    return array_desc_->status;
@@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)

void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
@@ -18,49 +18,61 @@ void _qmm_t_4_64(
    const float* biases,
    int M,
    int N,
    int K) {
    int K,
    int B,
    bool batched_w) {
  constexpr int bits = 4;
  constexpr int group_size = 64;
  constexpr int bitmask = (1 << bits) - 1;
  constexpr int pack_factor = 32 / bits;
  constexpr int packs_in_group = group_size / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint32_t* w_local = w;
    const float* scales_local = scales;
    const float* biases_local = biases;
  int w_els = N * K / pack_factor;
  int g_els = w_els * pack_factor / group_size;

    for (int n = 0; n < N; n++) {
      const simd_float16* x_local = (simd_float16*)x;
      simd_float16 sum = 0;
      for (int k = 0; k < K; k += group_size) {
        float scale = *scales_local++;
        float bias = *biases_local++;
  for (int i = 0; i < B; i++) {
    for (int m = 0; m < M; m++) {
      const uint32_t* w_local = w;
      const float* scales_local = scales;
      const float* biases_local = biases;

        for (int kw = 0; kw < packs_in_group; kw += 2) {
          // TODO: vectorize this properly
          simd_uint16 wi;
          for (int e = 0; e < 2; e++) {
            uint32_t wii = *w_local++;
            for (int p = 0; p < 8; p++) {
              wi[e * 8 + p] = wii & bitmask;
              wii >>= bits;
      for (int n = 0; n < N; n++) {
        const simd_float16* x_local = (simd_float16*)x;
        simd_float16 sum = 0;
        for (int k = 0; k < K; k += group_size) {
          float scale = *scales_local++;
          float bias = *biases_local++;

          for (int kw = 0; kw < packs_in_group; kw += 2) {
            // TODO: vectorize this properly
            simd_uint16 wi;
            for (int e = 0; e < 2; e++) {
              uint32_t wii = *w_local++;
              for (int p = 0; p < 8; p++) {
                wi[e * 8 + p] = wii & bitmask;
                wii >>= bits;
              }
            }
          }
          simd_float16 wf = simd_float(wi);
          wf *= scale;
          wf += bias;
            simd_float16 wf = simd_float(wi);
            wf *= scale;
            wf += bias;

          sum += (*x_local) * wf;
          x_local++;
            sum += (*x_local) * wf;
            x_local++;
          }
        }

        *result = simd_reduce_add(sum);
        result++;
      }

      *result = simd_reduce_add(sum);
      result++;
      x += K;
    }
    if (batched_w) {
      w += w_els;
      scales += g_els;
      biases += g_els;
    }

    x += K;
  }
}

@@ -82,8 +94,10 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  if (condition) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    int K = x.shape(-1);
    int M = x.size() / K;
    int M = x.shape(-2);
    int N = out.shape(-1);
    int B = x.size() / K / M;
    bool batched_w = w.ndim() > 2;
    _qmm_t_4_64(
        out.data<float>(),
        x.data<float>(),

@@ -92,7 +106,9 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
        biases.data<float>(),
        M,
        N,
        K);
        K,
        B,
        batched_w);
  } else {
    eval(inputs, out);
  }
@@ -33,8 +33,8 @@ namespace {
 * Note: The implementation below is a general fast exp. There could be faster
 * implementations for numbers strictly < 0.
 */
inline simd_float16 simd_fast_exp(simd_float16 x) {
  x *= 1.442695; // multiply with log_2(e)
inline simd_float16 simd_fast_exp(simd_float16 x_init) {
  auto x = x_init * 1.442695; // multiply with log_2(e)
  simd_float16 ipart, fpart;
  simd_int16 epart;
  x = simd_clamp(x, -80, 80);

@@ -53,7 +53,9 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
  // bitshifting
  epart = (simd_int(ipart) + 127) << 23;

  return (*(simd_float16*)&epart) * x;
  // Avoid supressing NaNs
  simd_int16 eq = (x_init == x_init);
  return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -31,6 +31,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
@@ -2,46 +2,12 @@

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif

namespace mlx::core {

namespace {

// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
  int info;

#ifdef LAPACK_FORTRAN_STRLEN_END
  spotrf_(
      /* uplo = */ &uplo,
      /* n = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* info = */ &info,
      /* uplo_len = */ static_cast<size_t>(1));
#else
  spotrf_(
      /* uplo = */ &uplo,
      /* n = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* info = */ &info);
#endif

  return info;
}

} // namespace

void cholesky_impl(const array& a, array& factor, bool upper) {
  // Lapack uses the column-major convention. We take advantage of the fact that
  // the matrix should be symmetric:

@@ -66,7 +32,14 @@ void cholesky_impl(const array& a, array& factor, bool upper) {

  for (int i = 0; i < num_matrices; i++) {
    // Compute Cholesky factorization.
    int info = spotrf_wrapper(uplo, matrix, N);
    int info;
    MLX_LAPACK_FUNC(spotrf)
    (
        /* uplo = */ &uplo,
        /* n = */ &N,
        /* a = */ matrix,
        /* lda = */ &N,
        /* info = */ &info);

    // TODO: We do nothing when the matrix is not positive semi-definite
    // because throwing an error would result in a crash. If we figure out how
@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
// rely on data_size anyway.
|
||||
size_t data_size = out.size();
|
||||
|
||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||
return move_or_copy(in, out, strides_, flags, data_size, offset_);
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -58,12 +58,12 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
move_or_copy(in, out, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
move_or_copy(inputs[0], out);
|
||||
}
|
||||
|
||||
void CustomTransforms::eval(
|
||||
@@ -72,7 +72,7 @@ void CustomTransforms::eval(
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
|
||||
i++, j++) {
|
||||
outputs[i].copy_shared_buffer(inputs[j]);
|
||||
move_or_copy(inputs[j], outputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ void Depends::eval(
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
outputs[i].copy_shared_buffer(inputs[i]);
|
||||
move_or_copy(inputs[i], outputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ void Reshape::shared_buffer_reshape(
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
move_or_copy(in, out, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Split::eval(
|
||||
@@ -263,7 +263,7 @@ std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
|
||||
|
||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
move_or_copy(inputs[0], out);
|
||||
}
|
||||
|
||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -297,7 +297,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
b_stride *= out.shape(ri);
|
||||
}
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
move_or_copy(in, out, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -4,6 +4,8 @@
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <list>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/compiled_preamble.h"
|
||||
@@ -12,22 +14,7 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// GPU compile is always available if the GPU is available and since we are in
|
||||
// this file CPU compile is also available.
|
||||
namespace detail {
|
||||
bool compile_available_for_device(const Device& device) {
|
||||
return true;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
std::string get_temp_file(const std::string& name) {
|
||||
return std::filesystem::temp_directory_path().append(name);
|
||||
}
|
||||
|
||||
// Return a pointer to a compiled function
|
||||
void* compile(
|
||||
const std::string& kernel_name,
|
||||
const std::string& source_code = "") {
|
||||
struct CompilerCache {
|
||||
struct DLib {
|
||||
DLib(const std::string& libname) {
|
||||
lib = dlopen(libname.c_str(), RTLD_NOW);
|
||||
@@ -44,15 +31,41 @@ void* compile(
|
||||
void* lib;
|
||||
};
|
||||
// Statics to cache compiled libraries and functions
|
||||
static std::list<DLib> libs;
|
||||
static std::unordered_map<std::string, void*> kernels;
|
||||
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
if (source_code.empty()) {
|
||||
return nullptr;
|
||||
std::list<DLib> libs;
|
||||
std::unordered_map<std::string, void*> kernels;
|
||||
std::shared_mutex mtx;
|
||||
};
|
||||
|
||||
static CompilerCache cache{};
|
||||
|
||||
// GPU compile is always available if the GPU is available and since we are in
|
||||
// this file CPU compile is also available.
|
||||
namespace detail {
|
||||
bool compile_available_for_device(const Device& device) {
|
||||
return true;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
std::string get_temp_file(const std::string& name) {
|
||||
return std::filesystem::temp_directory_path().append(name);
|
||||
}
|
||||
|
||||
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) {
{
std::shared_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
}

std::unique_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
std::string source_code = source_builder();
std::string kernel_file_name;

// Deal with long kernel names. Maximum length for files on macOS is 255
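The rewritten compile takes a source_builder callback and guards the static cache with a std::shared_mutex: a shared lock covers the common cache-hit path, and the exclusive lock re-checks the cache before building so only one thread compiles a given kernel. A minimal sketch of that double-checked pattern, with hypothetical names (get_or_build, build) rather than the MLX symbols:

#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

// Sketch only: 'build' stands in for the expensive compile-and-dlopen step.
void* get_or_build(
    std::unordered_map<std::string, void*>& cache,
    std::shared_mutex& mtx,
    const std::string& key,
    const std::function<void*()>& build) {
  {
    std::shared_lock lock(mtx); // many readers may probe the cache at once
    if (auto it = cache.find(key); it != cache.end()) {
      return it->second;
    }
  }
  std::unique_lock lock(mtx); // exclusive for the slow build and insert
  if (auto it = cache.find(key); it != cache.end()) {
    return it->second; // another thread finished the build while we waited
  }
  void* fn = build();
  cache.insert({key, fn});
  return fn;
}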
@@ -90,8 +103,8 @@ void* compile(
|
||||
source_file.close();
|
||||
|
||||
std::ostringstream build_command;
|
||||
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
|
||||
<< source_file_path << " -o " << shared_lib_path;
|
||||
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
|
||||
<< source_file_path << "' -o '" << shared_lib_path << "'";
|
||||
std::string build_command_str = build_command.str();
|
||||
auto return_code = system(build_command_str.c_str());
|
||||
if (return_code) {
|
||||
@@ -103,10 +116,10 @@ void* compile(
|
||||
}
|
||||
|
||||
// load library
|
||||
libs.emplace_back(shared_lib_path);
|
||||
cache.libs.emplace_back(shared_lib_path);
|
||||
|
||||
// Load function
|
||||
void* fun = dlsym(libs.back().lib, kernel_name.c_str());
|
||||
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
|
||||
if (!fun) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
||||
@@ -114,7 +127,7 @@ void* compile(
|
||||
<< dlerror();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
kernels.insert({kernel_name, fun});
|
||||
cache.kernels.insert({kernel_name, fun});
|
||||
return fun;
|
||||
}
|
||||
|
||||
@@ -266,7 +279,7 @@ void Compiled::eval_cpu(
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& shape = outputs[0].shape();
|
||||
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||
auto contiguous = compiled_check_contiguity(inputs, shape);
|
||||
|
||||
// Handle all broadcasting and collect function input arguments
|
||||
std::vector<void*> args;
|
||||
@@ -316,10 +329,7 @@ void Compiled::eval_cpu(
|
||||
}
|
||||
|
||||
// Get the function
|
||||
auto fn_ptr = compile(kernel_name);
|
||||
|
||||
// If it doesn't exist, compile it
|
||||
if (fn_ptr == nullptr) {
|
||||
auto fn_ptr = compile(kernel_name, [&]() {
|
||||
std::ostringstream kernel;
|
||||
kernel << get_kernel_preamble() << std::endl;
|
||||
kernel << "extern \"C\" {" << std::endl;
|
||||
@@ -334,10 +344,8 @@ void Compiled::eval_cpu(
|
||||
ndim);
|
||||
// Close extern "C"
|
||||
kernel << "}" << std::endl;
|
||||
|
||||
// Compile and get function pointer
|
||||
fn_ptr = compile(kernel_name, kernel.str());
|
||||
}
|
||||
return kernel.str();
|
||||
});
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, false);
|
||||
|
@@ -3,13 +3,8 @@
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
|
@@ -1,14 +1,10 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
#include <cstring>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -114,6 +110,7 @@ DEFAULT(Tanh)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT(Inverse)
|
||||
DEFAULT(Cholesky)
|
||||
DEFAULT_MULTI(Eigh)
|
||||
|
||||
namespace {
|
||||
|
||||
|
mlx/backend/common/eigh.cpp (new file, 117 lines)
@@ -0,0 +1,117 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void ssyevd(
|
||||
char jobz,
|
||||
char uplo,
|
||||
float* a,
|
||||
int N,
|
||||
float* w,
|
||||
float* work,
|
||||
int lwork,
|
||||
int* iwork,
|
||||
int liwork) {
|
||||
int info;
|
||||
MLX_LAPACK_FUNC(ssyevd)
|
||||
(
|
||||
/* jobz = */ &jobz,
|
||||
/* uplo = */ &uplo,
|
||||
/* n = */ &N,
|
||||
/* a = */ a,
|
||||
/* lda = */ &N,
|
||||
/* w = */ w,
|
||||
/* work = */ work,
|
||||
/* lwork = */ &lwork,
|
||||
/* iwork = */ iwork,
|
||||
/* liwork = */ &liwork,
|
||||
/* info = */ &info);
|
||||
if (info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
||||
const auto& a = inputs[0];
|
||||
auto& values = outputs[0];
|
||||
|
||||
auto vectors = compute_eigenvectors_
|
||||
? outputs[1]
|
||||
: array(a.shape(), a.dtype(), nullptr, {});
|
||||
|
||||
values.set_data(allocator::malloc_or_wait(values.nbytes()));
|
||||
|
||||
copy(
|
||||
a,
|
||||
vectors,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
|
||||
if (compute_eigenvectors_) {
|
||||
// Set the strides and flags so the eigenvectors
|
||||
// are in the columns of the output
|
||||
auto flags = vectors.flags();
|
||||
auto strides = vectors.strides();
|
||||
auto ndim = a.ndim();
|
||||
std::swap(strides[ndim - 1], strides[ndim - 2]);
|
||||
|
||||
if (a.size() > 1) {
|
||||
flags.row_contiguous = false;
|
||||
if (ndim > 2) {
|
||||
flags.col_contiguous = false;
|
||||
} else {
|
||||
flags.col_contiguous = true;
|
||||
}
|
||||
}
|
||||
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
|
||||
}
|
||||
|
||||
auto vec_ptr = vectors.data<float>();
|
||||
auto eig_ptr = values.data<float>();
|
||||
|
||||
char jobz = compute_eigenvectors_ ? 'V' : 'N';
|
||||
auto N = a.shape(-1);
|
||||
|
||||
// Work query
|
||||
int lwork;
|
||||
int liwork;
|
||||
{
|
||||
float work;
|
||||
int iwork;
|
||||
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
}
|
||||
|
||||
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
|
||||
for (size_t i = 0; i < a.size() / (N * N); ++i) {
|
||||
ssyevd(
|
||||
jobz,
|
||||
uplo_[0],
|
||||
vec_ptr,
|
||||
N,
|
||||
eig_ptr,
|
||||
static_cast<float*>(work_buf.buffer.raw_ptr()),
|
||||
lwork,
|
||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
||||
liwork);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
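The new Eigh implementation follows the usual LAPACK workspace-query convention: a first ssyevd call with lwork = liwork = -1 only reports the optimal work-array sizes, and the real factorization then runs with buffers of that size. A hedged sketch of the pattern against the raw Fortran symbol (assumes an ssyevd_ declaration matching the linked LAPACK and ignores the hidden string-length arguments some builds append):

#include <vector>

// Sketch only; the real code goes through MLX_LAPACK_FUNC(ssyevd).
extern "C" void ssyevd_(
    const char* jobz, const char* uplo, const int* n, float* a, const int* lda,
    float* w, float* work, const int* lwork, int* iwork, const int* liwork,
    int* info);

void eigh_example(float* a, float* w, int n) {
  char jobz = 'V', uplo = 'L';
  int info = 0;
  // 1) Workspace query: lwork = liwork = -1 reports sizes, computes nothing.
  float work_query = 0;
  int iwork_query = 0, lwork = -1, liwork = -1;
  ssyevd_(&jobz, &uplo, &n, a, &n, w, &work_query, &lwork, &iwork_query, &liwork, &info);
  // 2) Allocate the reported workspace and run the real decomposition.
  lwork = static_cast<int>(work_query);
  liwork = iwork_query;
  std::vector<float> work(lwork);
  std::vector<int> iwork(liwork);
  ssyevd_(&jobz, &uplo, &n, a, &n, w, work.data(), &lwork, iwork.data(), &liwork, &info);
}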
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
@@ -81,11 +80,18 @@ void gather(
|
||||
T* dst_ptr = out.data<T>();
|
||||
size_t out_idx = 0;
|
||||
|
||||
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
|
||||
ContiguousIterator<size_t> src_it;
|
||||
if (!can_copy && src.ndim() > 0) {
|
||||
src_it = std::move(
|
||||
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
|
||||
}
|
||||
for (int idx = 0; idx < ind_size; idx++) {
|
||||
size_t src_idx = 0;
|
||||
for (int ii = 0; ii < inds.size(); ++ii) {
|
||||
auto ax = axes[ii];
|
||||
auto idx_loc = elem_to_loc(idx, inds[ii]);
|
||||
auto idx_loc = its[ii].loc;
|
||||
its[ii].step();
|
||||
auto idx_val =
|
||||
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
|
||||
src_idx += (idx_val * src.strides()[ax]);
|
||||
@@ -99,9 +105,10 @@ void gather(
|
||||
out_idx += slice_size;
|
||||
} else {
|
||||
for (int jj = 0; jj < slice_size; jj++) {
|
||||
auto src_offset = elem_to_loc(jj, slice_sizes, src.strides());
|
||||
dst_ptr[out_idx++] = src_ptr[src_idx + src_offset];
|
||||
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
|
||||
src_it.step();
|
||||
}
|
||||
src_it.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -223,21 +230,29 @@ void scatter(
|
||||
update_size *= us;
|
||||
}
|
||||
|
||||
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
|
||||
ContiguousIterator<size_t> update_it(updates);
|
||||
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
|
||||
|
||||
for (int i = 0; i < n_updates; ++i) {
|
||||
size_t out_offset = 0;
|
||||
for (int j = 0; j < nind; ++j) {
|
||||
auto ax = axes[j];
|
||||
auto idx_loc = elem_to_loc(i, inds[j]);
|
||||
auto idx_loc = its[j].loc;
|
||||
its[j].step();
|
||||
auto idx_val =
|
||||
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
|
||||
out_offset += (idx_val * out.strides()[ax]);
|
||||
}
|
||||
update_it.seek(i * update_size);
|
||||
for (int j = 0; j < update_size; ++j) {
|
||||
auto update_loc = elem_to_loc(i * update_size + j, updates);
|
||||
auto out_loc = elem_to_loc(j, update_shape, out.strides());
|
||||
op(updates.data<InT>()[update_loc],
|
||||
out.data<InT>() + out_offset + out_loc);
|
||||
op(updates.data<InT>()[update_it.loc],
|
||||
out.data<InT>() + out_offset + out_it.loc);
|
||||
update_it.step();
|
||||
out_it.step();
|
||||
}
|
||||
out_it.reset();
|
||||
update_it.reset();
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -2,39 +2,19 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <lapack.h>
|
||||
#endif
|
||||
|
||||
// Wrapper to account for differences in
|
||||
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
|
||||
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
|
||||
int info;
|
||||
|
||||
#ifdef LAPACK_FORTRAN_STRLEN_END
|
||||
strtri_(
|
||||
/* uplo = */ &uplo,
|
||||
/* diag = */ &diag,
|
||||
/* N = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info,
|
||||
/* uplo_len = */ static_cast<size_t>(1),
|
||||
/* diag_len = */ static_cast<size_t>(1));
|
||||
#else
|
||||
strtri_(
|
||||
MLX_LAPACK_FUNC(strtri)
|
||||
(
|
||||
/* uplo = */ &uplo,
|
||||
/* diag = */ &diag,
|
||||
/* N = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info);
|
||||
#endif
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
|
@@ -1,10 +1,11 @@
// Copyright © 2024 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#pragma once

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#include <lapack.h>
#endif
@@ -18,10 +18,12 @@ if [ "$CLANG" = "TRUE" ]; then
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
EOM
|
||||
|
||||
CC_FLAGS=""
|
||||
else
|
||||
CC_FLAGS="-std=c++17"
|
||||
fi
|
||||
|
||||
CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
|
||||
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
|
||||
|
||||
cat << EOF > "$OUTPUT_FILE"
|
||||
const char* get_kernel_preamble() {
|
||||
|
@@ -1,15 +1,10 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
|
@@ -295,6 +295,13 @@ struct Floor {
|
||||
}
|
||||
};
|
||||
|
||||
struct Imag {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::imag(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
@@ -337,6 +344,13 @@ struct Negative {
|
||||
}
|
||||
};
|
||||
|
||||
struct Real {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::real(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@@ -159,6 +159,17 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy(in, out, CopyType::General);
|
||||
}
|
||||
}
|
||||
|
||||
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
@@ -273,6 +284,10 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||
copy(in, out, ctype);
|
||||
}
|
||||
|
||||
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
|
||||
}
|
||||
|
||||
void Log::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
@@ -398,6 +413,10 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
|
||||
}
|
||||
|
||||
void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
@@ -598,7 +617,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
|
||||
in.flags().row_contiguous) {
|
||||
auto strides = in.strides();
|
||||
for (int i = 0; i < strides.size() - 1; ++i) {
|
||||
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
|
||||
strides[i] *= ibytes;
|
||||
strides[i] /= obytes;
|
||||
}
|
||||
|
@@ -2,14 +2,9 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <lapack.h>
|
||||
#endif
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
|
@@ -2,13 +2,38 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) {
assert(bits == 3 || bits == 6);
if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
w_out[2] = static_cast<T>(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));
w_out[3] = static_cast<T>((w_in[1] & 0xe) >> 1);
w_out[4] = static_cast<T>((w_in[1] & 0x70) >> 4);
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
} else if (bits == 6) {
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
w_out[1] = static_cast<T>(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));
w_out[2] = static_cast<T>(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));
w_out[3] = static_cast<T>((w_in[2] >> 2) & 0x3f);
}
}
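With 3-bit quantization a group of eight weights occupies three bytes, and with 6-bit a group of four occupies three bytes, which is why the loops below read bytes_per_pack = 3 raw bytes per step instead of one uint32. A small usage sketch of the unpacking helper above (assumes float weights and arbitrary example bytes):

// Usage sketch: decode one 3-byte pack of eight 3-bit codes.
uint8_t packed[3] = {0xC8, 0xFA, 0x1B};
float w[8];
extract_bits<float, 3>(packed, w); // w[0] is packed[0] & 0x7, and so on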
template <typename T, int bits, int group_size>
|
||||
void _qmm(
|
||||
T* result,
|
||||
@@ -20,13 +45,12 @@ void _qmm(
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
const int Ng = N / group_size;
|
||||
const int Nw = N / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint32_t* w_local = w;
|
||||
const uint8_t* w_local = (const uint8_t*)w;
|
||||
const T* scales_local = scales;
|
||||
const T* biases_local = biases;
|
||||
|
||||
@@ -40,13 +64,25 @@ void _qmm(
|
||||
T scale = *scales_local++;
|
||||
T bias = *biases_local++;
|
||||
for (int ng = 0; ng < packs_in_group; ng++) {
|
||||
uint32_t wi = *w_local++;
|
||||
|
||||
if (bits == 3 || bits == 6) {
|
||||
T wl[pack_factor];
|
||||
extract_bits<T, bits>(w_local, wl);
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
(*result_local++) +=
|
||||
xi * (scale * static_cast<T>(wi & bitmask) + bias);
|
||||
wi >>= bits;
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
(*result_local++) += xi * (scale * wl[p] + bias);
|
||||
}
|
||||
w_local += bytes_per_pack;
|
||||
|
||||
} else {
|
||||
uint8_t wi = *w_local++;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
(*result_local++) +=
|
||||
xi * (scale * static_cast<T>(wi & bitmask) + bias);
|
||||
if (bits != 8) {
|
||||
wi >>= bits;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -67,13 +103,12 @@ void _qmm_t(
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
const int Kg = K / group_size;
|
||||
const int Kw = K / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint32_t* w_local = w;
|
||||
const uint8_t* w_local = (const uint8_t*)w;
|
||||
const T* scales_local = scales;
|
||||
const T* biases_local = biases;
|
||||
|
||||
@@ -85,12 +120,26 @@ void _qmm_t(
|
||||
T bias = *biases_local++;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw++) {
|
||||
uint32_t wi = *w_local++;
|
||||
|
||||
if (bits == 3 || bits == 6) {
|
||||
T wl[pack_factor];
|
||||
extract_bits<T, bits>(w_local, wl);
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
|
||||
wi >>= bits;
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
sum += x_local[p] * (scale * wl[p] + bias);
|
||||
}
|
||||
w_local += bytes_per_pack;
|
||||
x_local += pack_factor;
|
||||
|
||||
} else {
|
||||
uint8_t wi = *w_local++;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
sum +=
|
||||
(*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
|
||||
if (bits != 8) {
|
||||
wi >>= bits;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -102,6 +151,55 @@ void _qmm_t(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int bits, int group_size>
|
||||
void _qmm_dispatch_transpose(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const T* scales,
|
||||
const T* biases,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
bool transposed_w) {
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int bits>
|
||||
void _qmm_dispatch_group(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const T* scales,
|
||||
const T* biases,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
_qmm_dispatch_transpose<T, bits, 32>(
|
||||
result, x, w, scales, biases, M, N, K, transposed_w);
|
||||
break;
|
||||
case 64:
|
||||
_qmm_dispatch_transpose<T, bits, 64>(
|
||||
result, x, w, scales, biases, M, N, K, transposed_w);
|
||||
break;
|
||||
case 128:
|
||||
_qmm_dispatch_transpose<T, bits, 128>(
|
||||
result, x, w, scales, biases, M, N, K, transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"Quantization group size must be 32, 64 or 128.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void _qmm_dispatch_typed(
|
||||
T* result,
|
||||
@@ -116,79 +214,29 @@ void _qmm_dispatch_typed(
|
||||
int bits,
|
||||
bool transposed_w) {
|
||||
switch (bits) {
|
||||
case 2: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 128:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
case 4: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 128:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
case 8: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 128:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
case 2:
|
||||
_qmm_dispatch_group<T, 2>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 3:
|
||||
_qmm_dispatch_group<T, 3>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 4:
|
||||
_qmm_dispatch_group<T, 4>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 6:
|
||||
_qmm_dispatch_group<T, 6>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 8:
|
||||
_qmm_dispatch_group<T, 8>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8.");
|
||||
}
|
||||
std::ostringstream msg;
|
||||
msg << "Quantization type not supported. Provided bits=" << bits
|
||||
<< " and group_size=" << group_size
|
||||
<< ". The supported options are bits in "
|
||||
<< "{2, 4, 8} and group_size in {64, 128}.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
void _qmm_dispatch(
|
||||
@@ -201,55 +249,61 @@ void _qmm_dispatch(
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.size() / K;
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_qmm_dispatch_typed<float>(
|
||||
out.data<float>(),
|
||||
x.data<float>(),
|
||||
w.data<uint32_t>(),
|
||||
scales.data<float>(),
|
||||
biases.data<float>(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_dispatch_typed<float16_t>(
|
||||
out.data<float16_t>(),
|
||||
x.data<float16_t>(),
|
||||
w.data<uint32_t>(),
|
||||
scales.data<float16_t>(),
|
||||
biases.data<float16_t>(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_dispatch_typed<bfloat16_t>(
|
||||
out.data<bfloat16_t>(),
|
||||
x.data<bfloat16_t>(),
|
||||
w.data<uint32_t>(),
|
||||
scales.data<bfloat16_t>(),
|
||||
biases.data<bfloat16_t>(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
|
||||
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
|
||||
|
||||
int batch_size = x.size() / x.shape(-1) / x.shape(-2);
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_qmm_dispatch_typed<float>(
|
||||
out.data<float>() + i * M * N,
|
||||
x.data<float>() + elem_to_loc(i * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
|
||||
scales.data<float>() + elem_to_loc(i * g_els, scales),
|
||||
biases.data<float>() + elem_to_loc(i * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_dispatch_typed<float16_t>(
|
||||
out.data<float16_t>() + i * M * N,
|
||||
x.data<float16_t>() + elem_to_loc(i * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
|
||||
scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
|
||||
biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_dispatch_typed<bfloat16_t>(
|
||||
out.data<bfloat16_t>() + i * M * N,
|
||||
x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
|
||||
scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
|
||||
biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -398,4 +452,114 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
transpose_);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
void quantize(
|
||||
const array& w_,
|
||||
array& out_,
|
||||
array& scales_,
|
||||
array& biases_,
|
||||
int bits,
|
||||
int group_size) {
|
||||
const T* w = w_.data<T>();
|
||||
|
||||
auto out = out_.data<U>();
|
||||
T* scales = scales_.data<T>();
|
||||
T* biases = biases_.data<T>();
|
||||
|
||||
T n_bins = (1 << bits) - 1;
|
||||
T eps = 1e-7;
|
||||
bool power_of_2_bits = is_power_of_2(bits);
|
||||
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
|
||||
int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||
int int_per_group = group_size * bytes_per_pack / el_per_int;
|
||||
size_t n_groups = w_.size() / group_size;
|
||||
|
||||
for (size_t i = 0; i < n_groups; ++i) {
|
||||
size_t w_idx = i * group_size;
|
||||
T w_min = std::numeric_limits<float>::infinity();
|
||||
T w_max = -w_min;
|
||||
for (int j = 0; j < group_size; ++j) {
|
||||
w_max = std::max(w_max, w[w_idx + j]);
|
||||
w_min = std::min(w_min, w[w_idx + j]);
|
||||
}
|
||||
bool mask = std::abs(w_min) > std::abs(w_max);
|
||||
T scale = std::max(T((w_max - w_min) / n_bins), eps);
|
||||
scale = mask ? scale : -scale;
|
||||
|
||||
auto edge = mask ? w_min : w_max;
|
||||
auto q0 = std::rint(edge / scale);
|
||||
if (q0 == 0) {
|
||||
scales[i] = scale;
|
||||
biases[i] = 0;
|
||||
} else {
|
||||
scales[i] = edge / q0;
|
||||
biases[i] = edge;
|
||||
}
|
||||
size_t out_idx = i * int_per_group;
|
||||
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
|
||||
uint32_t out_el = 0;
|
||||
for (int k = 0; k < el_per_int; ++k) {
|
||||
T w_el = w[w_idx + j * el_per_int + k];
|
||||
w_el = std::rint((w_el - biases[i]) / scales[i]);
|
||||
w_el = std::min(std::max(w_el, T(0)), n_bins);
|
||||
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
|
||||
}
|
||||
if (power_of_2_bits) {
|
||||
out[out_idx + j] = out_el;
|
||||
} else {
|
||||
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
|
||||
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
|
||||
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
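For bits = 3 the loop above packs el_per_int = 8 codes into one 24-bit accumulator and writes it out as bytes_per_pack = 3 bytes, the inverse of the extract_bits helper shown earlier. A self-contained sketch of just that packing step (a hypothetical standalone helper, not the MLX function):

#include <cstdint>

// Pack eight 3-bit codes (values 0..7) into 3 output bytes, low bits first.
void pack_3bit(const uint8_t q[8], uint8_t out[3]) {
  uint32_t acc = 0;
  for (int k = 0; k < 8; ++k) {
    acc |= static_cast<uint32_t>(q[k] & 0x7) << (k * 3);
  }
  out[0] = acc & 0xff;
  out[1] = (acc >> 8) & 0xff;
  out[2] = (acc >> 16) & 0xff;
}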
void fast::AffineQuantize::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto ensure_row_contiguous = [](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
auto w = ensure_row_contiguous(inputs[0]);
|
||||
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
|
||||
if (w.dtype() == float16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == bfloat16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
quantize<bfloat16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == float32) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -120,48 +120,56 @@ struct MinReduce {
|
||||
};
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_out(
|
||||
void reduce_dispatch_and_or(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
switch (rtype) {
|
||||
case Reduce::And: {
|
||||
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
|
||||
break;
|
||||
if (rtype == Reduce::And) {
|
||||
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
|
||||
} else {
|
||||
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_sum_prod(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Sum) {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, op);
|
||||
}
|
||||
case Reduce::Or: {
|
||||
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
|
||||
break;
|
||||
}
|
||||
case Reduce::Sum: {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
||||
if (out.dtype() == int32) {
|
||||
// special case since the input type can be bool
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, op);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Prod: {
|
||||
auto op = [](auto y, auto x) { (*y) *= x; };
|
||||
} else {
|
||||
auto op = [](auto y, auto x) { (*y) *= x; };
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 1, op);
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 1, op);
|
||||
break;
|
||||
}
|
||||
case Reduce::Max: {
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
|
||||
break;
|
||||
}
|
||||
case Reduce::Min: {
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_min_max(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Max) {
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
|
||||
} else {
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
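reduce_dispatch_sum_prod now widens small integer inputs (including bool) to an int32_t accumulator via the if constexpr branch instead of checking the output dtype, which keeps narrow integer sums and products from overflowing mid-reduction. A minimal illustration of that compile-time choice, using a hypothetical sum_widened helper rather than reduction_op:

#include <cstdint>
#include <type_traits>

// Sum 'n' values of type InT, widening small integer types to int32_t.
template <typename InT>
auto sum_widened(const InT* x, int n) {
  if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
    int32_t acc = 0;
    for (int i = 0; i < n; ++i) acc += static_cast<int32_t>(x[i]);
    return acc;
  } else {
    InT acc = InT(0);
    for (int i = 0; i < n; ++i) acc += x[i];
    return acc;
  }
}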
void nd_loop(
|
||||
@@ -190,46 +198,114 @@ void nd_loop(
|
||||
void Reduce::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
case Reduce::Or: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
case float16:
|
||||
case bfloat16:
|
||||
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
case int32:
|
||||
case float32:
|
||||
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
case int64:
|
||||
case complex64:
|
||||
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
|
||||
}
|
||||
case Reduce::Sum:
|
||||
case Reduce::Prod: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
|
||||
}
|
||||
case Reduce::Max:
|
||||
case Reduce::Min: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -34,7 +34,7 @@ void shared_buffer_slice(
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
flags.contiguous = (no_bsx_size == data_size);
|
||||
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -111,7 +111,8 @@ void sort(const array& in, array& out, int axis) {
|
||||
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
|
||||
size_t n_rows = in_size / in.shape(axis);
|
||||
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
@@ -123,14 +124,16 @@ void sort(const array& in, array& out, int axis) {
|
||||
int axis_size = out.shape(axis);
|
||||
|
||||
// Perform sorting in place
|
||||
ContiguousIterator<size_t> src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
|
||||
T* data_ptr = out.data<T>() + loc;
|
||||
T* data_ptr = out.data<T>() + src_it.loc;
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
|
||||
std::stable_sort(st, ed);
|
||||
src_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,11 +163,15 @@ void argsort(const array& in, array& out, int axis) {
|
||||
int axis_size = in.shape(axis);
|
||||
|
||||
// Perform sorting
|
||||
ContiguousIterator<size_t> in_it(
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator<size_t> out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
|
||||
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
|
||||
const T* data_ptr = in.data<T>() + in_loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
|
||||
const T* data_ptr = in.data<T>() + in_it.loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
@@ -192,7 +199,8 @@ void partition(const array& in, array& out, int axis, int kth) {
|
||||
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
|
||||
size_t n_rows = in_size / in.shape(axis);
|
||||
|
||||
auto remaining_shape = in.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
@@ -206,9 +214,11 @@ void partition(const array& in, array& out, int axis, int kth) {
|
||||
kth = kth < 0 ? kth + axis_size : kth;
|
||||
|
||||
// Perform partition in place
|
||||
ContiguousIterator<size_t> src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
|
||||
T* data_ptr = out.data<T>() + loc;
|
||||
T* data_ptr = out.data<T>() + src_it.loc;
|
||||
src_it.step();
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator md(data_ptr, axis_stride, kth);
|
||||
@@ -227,37 +237,49 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
|
||||
auto remaining_shape = in.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
auto in_remaining_shape = in.shape();
|
||||
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
|
||||
|
||||
auto remaining_strides = in.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
auto in_remaining_strides = in.strides();
|
||||
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
|
||||
|
||||
size_t axis_stride = in.strides()[axis];
|
||||
auto out_remaining_shape = out.shape();
|
||||
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
|
||||
|
||||
auto out_remaining_strides = out.strides();
|
||||
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
|
||||
|
||||
size_t in_stride = in.strides()[axis];
|
||||
size_t out_stride = out.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
|
||||
kth = kth < 0 ? kth + axis_size : kth;
|
||||
|
||||
// Perform partition
|
||||
ContiguousIterator<size_t> in_it(
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator<size_t> out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
|
||||
const T* data_ptr = in.data<T>() + loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + loc;
|
||||
const T* data_ptr = in.data<T>() + in_it.loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
|
||||
StridedIterator st_(idx_ptr, axis_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, axis_stride, 0);
|
||||
StridedIterator md(idx_ptr, axis_stride, kth);
|
||||
StridedIterator ed(idx_ptr, axis_stride, axis_size);
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator md(idx_ptr, out_stride, kth);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
|
||||
std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * axis_stride];
|
||||
auto v2 = data_ptr[b * axis_stride];
|
||||
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack_helper.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
@@ -24,26 +24,26 @@ void set_unary_output_data(const array& in, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void unary_op(const T* a, T* out, Op op, size_t shape, size_t stride) {
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
|
||||
for (size_t i = 0; i < shape; i += 1) {
|
||||
out[i] = op(*a);
|
||||
a += stride;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const array& a, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
if (a.flags().contiguous) {
|
||||
set_unary_output_data(a, out);
|
||||
T* dst = out.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||
dst[i] = op(a_ptr[i]);
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
T* dst = out.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
|
||||
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
|
||||
if (a.ndim() <= 1) {
|
||||
|
@@ -4,6 +4,28 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void move_or_copy(const array& in, array& out) {
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.copy_shared_buffer(in);
}
}

void move_or_copy(
const array& in,
array& out,
const std::vector<size_t>& strides,
array::Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
if (in.is_donatable()) {
out.move_shared_buffer(in, strides, flags, data_size, offset);
} else {
out.copy_shared_buffer(in, strides, flags, data_size, offset);
}
}
template <typename StrideT>
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
|
||||
collapse_contiguous_dims_impl(
|
||||
|
@@ -88,7 +88,11 @@ std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
template <typename StrideT>
struct ContiguousIterator {
inline void step() {
int i = dims_;
int dims = shape_.size();
if (dims == 0) {
return;
}
int i = dims - 1;
while (pos_[i] == (shape_[i] - 1) && i > 0) {
pos_[i] = 0;
loc -= (shape_[i] - 1) * strides_[i];
@@ -98,15 +102,41 @@ struct ContiguousIterator {
|
||||
loc += strides_[i];
|
||||
}
|
||||
|
||||
void seek(StrideT n) {
loc = 0;
for (int i = shape_.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(n, shape_[i]);
loc += q_and_r.rem * strides_[i];
pos_[i] = q_and_r.rem;
n = q_and_r.quot;
}
}

void reset() {
loc = 0;
std::fill(pos_.begin(), pos_.end(), 0);
}
|
||||
ContiguousIterator() {};
|
||||
|
||||
explicit ContiguousIterator(const array& a)
|
||||
: shape_(a.shape()), strides_(a.strides()) {
|
||||
if (!shape_.empty()) {
|
||||
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
|
||||
pos_ = std::vector<int>(shape_.size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
explicit ContiguousIterator(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<StrideT>& strides,
|
||||
int dims)
|
||||
: shape_(shape.begin(), shape.begin() + dims),
|
||||
strides_(strides.begin(), strides.begin() + dims) {
|
||||
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
|
||||
dims_ = shape_.size() - 1;
|
||||
pos_ = std::vector<int>(dims_ + 1, 0);
|
||||
if (!shape_.empty()) {
|
||||
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
|
||||
pos_ = std::vector<int>(shape_.size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
StrideT loc{0};
|
||||
@@ -115,7 +145,6 @@ struct ContiguousIterator {
|
||||
std::vector<int> shape_;
|
||||
std::vector<StrideT> strides_;
|
||||
std::vector<int> pos_;
|
||||
int dims_;
|
||||
};
|
||||
|
||||
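The iterator now zero-initializes safely for empty shapes and gains seek (jump to the n-th logical element, used by scatter) and reset alongside step; loc always holds the strided memory offset of the current logical element. A short usage sketch, assuming an mlx::core::array a of floats as in the surrounding code:

// Sketch: visit every element of a (possibly non-contiguous) float array 'a'.
ContiguousIterator<size_t> it(a);
for (size_t i = 0; i < a.size(); ++i) {
  float v = a.data<float>()[it.loc]; // element i, wherever its strides put it
  (void)v;
  it.step();
}
it.seek(0); // same effect as reset() here: loc back to 0, positions cleared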
template <typename StrideT>
|
||||
@@ -149,4 +178,13 @@ inline bool is_donatable(const array& in, const array& out) {
|
||||
in.buffer_size() <= out.nbytes() + donation_extra;
|
||||
}
|
||||
|
||||
void move_or_copy(const array& in, array& out);
|
||||
void move_or_copy(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<size_t>& strides,
|
||||
array::Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -14,20 +14,27 @@ function(make_jit_source SRC_FILE)
|
||||
COMMAND
|
||||
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
|
||||
${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
|
||||
${SRC_FILE}
|
||||
DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
|
||||
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
|
||||
add_dependencies(mlx ${SRC_NAME})
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
|
||||
endfunction(make_jit_source)
|
||||
|
||||
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
|
||||
make_jit_source(
|
||||
utils
|
||||
kernels/jit/bf16.h
|
||||
kernels/metal_3_0/bf16.h
|
||||
kernels/metal_3_1/bf16.h
|
||||
kernels/bf16_math.h
|
||||
kernels/complex.h
|
||||
kernels/defines.h)
|
||||
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
|
||||
make_jit_source(binary_ops)
|
||||
make_jit_source(ternary_ops)
|
||||
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
|
||||
make_jit_source(scatter)
|
||||
make_jit_source(gather)
|
||||
make_jit_source(scatter kernels/indexing.h)
|
||||
make_jit_source(gather kernels/indexing.h)
|
||||
make_jit_source(hadamard)
|
||||
|
||||
if(MLX_METAL_JIT)
|
||||
@@ -99,6 +106,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||
|
||||
if(NOT MLX_METAL_PATH)
|
||||
|
@@ -2,6 +2,7 @@
|
||||
#include "mlx/backend/metal/allocator.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
|
||||
#include <mach/vm_page_size.h>
|
||||
#include <unistd.h>
|
||||
@@ -140,6 +141,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
|
||||
|
||||
MetalAllocator::MetalAllocator()
|
||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||
residency_set_(device_),
|
||||
buffer_cache_(device_) {
|
||||
auto memsize = std::get<size_t>(device_info()["memory_size"]);
|
||||
block_limit_ =
|
||||
@@ -148,6 +150,8 @@ MetalAllocator::MetalAllocator()
|
||||
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
|
||||
block_limit_);
|
||||
max_pool_size_ = block_limit_;
|
||||
device(mlx::core::Device::gpu)
|
||||
.set_residency_set(residency_set_.mtl_residency_set());
|
||||
}
|
||||
|
||||
size_t MetalAllocator::set_cache_limit(size_t limit) {
|
||||
@@ -164,6 +168,12 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
|
||||
return limit;
|
||||
};
|
||||
|
||||
size_t MetalAllocator::set_wired_limit(size_t limit) {
|
||||
std::swap(limit, wired_limit_);
|
||||
residency_set_.resize(wired_limit_);
|
||||
return limit;
|
||||
};
|
||||
|
||||
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
// Metal doesn't like empty buffers
|
||||
if (size == 0) {
|
||||
@@ -205,7 +215,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
|
||||
// Allocate new buffer if needed
|
||||
size_t res_opt = MTL::ResourceStorageModeShared;
|
||||
res_opt |= MTL::ResourceHazardTrackingModeTracked;
|
||||
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
|
||||
lk.unlock();
|
||||
buf = device_->newBuffer(size, res_opt);
|
||||
lk.lock();
|
||||
@@ -220,6 +230,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
residency_set_.insert(buf);
|
||||
|
||||
return Buffer{static_cast<void*>(buf)};
|
||||
}
|
||||
|
||||
@@ -230,7 +242,11 @@ void MetalAllocator::clear_cache() {
|
||||
|
||||
void MetalAllocator::free(Buffer buffer) {
|
||||
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
|
||||
if (buf == nullptr) {
|
||||
return;
|
||||
}
|
||||
std::unique_lock lk(mutex_);
|
||||
residency_set_.erase(buf);
|
||||
active_memory_ -= buf->length();
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
@@ -246,15 +262,9 @@ size_t MetalAllocator::size(Buffer buffer) const {
}

MetalAllocator& allocator() {
  // By creating the |allocator_| on heap, the destructor of MetalAllocator will
  // not be called on exit and all the buffers will be leaked. This is necessary
  // because releasing buffers can take more than 30sec when the program holds a
  // lot of RAM (for example inferencing a LLM), and it would feel frozen to
  // users when exiting.
  // TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
  // when applying this pattern to more places, or when introducing sanitizers
  // to MLX.
  // https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
  // By creating the |allocator_| on heap, the destructor of MetalAllocator
  // will not be called on exit and buffers in the cache will be leaked. This
  // can save some time at program exit.
  static MetalAllocator* allocator_ = new MetalAllocator;
  return *allocator_;
}
@@ -265,6 +275,15 @@ size_t set_cache_limit(size_t limit) {
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
  return allocator().set_memory_limit(limit, relaxed);
}
size_t set_wired_limit(size_t limit) {
  if (limit >
      std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
    throw std::invalid_argument(
        "[metal::set_wired_limit] Setting a wired limit larger than "
        "the maximum working set size is not allowed.");
  }
  return allocator().set_wired_limit(limit);
}
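// Illustrative call site (not part of this diff): clamp a requested wired limit
// to the device's recommended working set before applying it. The 2 GiB target
// is a hypothetical value, and <algorithm> is assumed for std::min.
inline size_t set_wired_limit_clamped_example() {
  size_t desired_bytes = size_t(2) << 30;
  auto max_ws =
      std::get<size_t>(device_info()["max_recommended_working_set_size"]);
  return set_wired_limit(std::min(desired_bytes, max_ws));
}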
size_t get_active_memory() {
  return allocator().get_active_memory();
}

@@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
@@ -72,6 +73,7 @@ class MetalAllocator : public allocator::Allocator {
|
||||
};
|
||||
size_t set_cache_limit(size_t limit);
|
||||
size_t set_memory_limit(size_t limit, bool relaxed);
|
||||
size_t set_wired_limit(size_t limit);
|
||||
void clear_cache();
|
||||
|
||||
private:
|
||||
@@ -82,12 +84,15 @@ class MetalAllocator : public allocator::Allocator {
|
||||
// Caching allocator
|
||||
BufferCache buffer_cache_;
|
||||
|
||||
ResidencySet residency_set_;
|
||||
|
||||
// Allocation stats
|
||||
size_t block_limit_;
|
||||
size_t gc_limit_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
size_t max_pool_size_;
|
||||
size_t wired_limit_{0};
|
||||
bool relaxed_{true};
|
||||
|
||||
std::mutex mutex_;
|
||||
|
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
@@ -23,37 +22,37 @@ std::string get_kernel_name(
|
||||
BinaryOpType bopt,
|
||||
const std::string& op,
|
||||
const array& a,
|
||||
bool use_2d,
|
||||
bool large,
|
||||
int ndim,
|
||||
int work_per_thread) {
|
||||
std::ostringstream kname;
|
||||
std::string kname;
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
kname = "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << (use_2d ? "sv2" : "sv");
|
||||
kname = (large ? "sv2" : "sv");
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << (use_2d ? "vs2" : "vs");
|
||||
kname = (large ? "vs2" : "vs");
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << (use_2d ? "vv2" : "vv");
|
||||
kname = (large ? "vv2" : "vv");
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
kname = "g";
|
||||
if (ndim <= 3) {
|
||||
kname << ndim;
|
||||
kname += std::to_string(ndim);
|
||||
} else {
|
||||
kname << "n";
|
||||
if (work_per_thread > 1) {
|
||||
kname << work_per_thread;
|
||||
}
|
||||
concatenate(kname, "n", std::to_string(work_per_thread));
|
||||
}
|
||||
if (large) {
|
||||
kname += "large";
|
||||
}
|
||||
break;
|
||||
}
|
||||
kname << "_" << op << type_to_name(a);
|
||||
return kname.str();
|
||||
concatenate(kname, "_", op, type_to_name(a));
|
||||
return kname;
|
||||
}
|
||||
|
||||
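// For reference, the names assembled above have the shape: a layout prefix
// ("ss", "sv"/"sv2", "vs"/"vs2", "vv"/"vv2", "g" + ndim for ndim <= 3, or
// "gn" + work_per_thread), then "large" appended for big general kernels, then
// "_" + op + type_to_name(a). For instance, if op were "add" and
// type_to_name(a) returned "float32" (both hypothetical values here), a 2-d
// general kernel would be looked up as "g2_addfloat32".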
void binary_op_gpu_inplace(
|
||||
@@ -82,19 +81,23 @@ void binary_op_gpu_inplace(
|
||||
};
|
||||
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
|
||||
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
bool large = out.data_size() > UINT32_MAX;
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread =
|
||||
(bopt == BinaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
|
||||
int work_per_thread;
|
||||
if (bopt == BinaryOpType::General) {
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
work_per_thread = 1;
|
||||
}
|
||||
std::string kernel_name =
|
||||
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
|
||||
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = outputs.size() == 2
|
||||
? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
|
||||
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// - If a is donated it goes to the first output
|
||||
// - If b is donated it goes to the first output if a was not donated
|
||||
@@ -111,6 +114,7 @@ void binary_op_gpu_inplace(
|
||||
compute_encoder.set_output_array(outputs[1], arg_idx++);
|
||||
}
|
||||
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (bopt == BinaryOpType::General) {
|
||||
// Launch up to 3D grid of threads
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
@@ -118,39 +122,33 @@ void binary_op_gpu_inplace(
|
||||
size_t rest = out.size() / (dim0 * dim1);
|
||||
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
|
||||
compute_encoder->setBytes(
|
||||
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
|
||||
compute_encoder->setBytes(
|
||||
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
|
||||
compute_encoder.set_vector_bytes(shape, arg_idx++);
|
||||
compute_encoder.set_vector_bytes(strides_a, arg_idx++);
|
||||
compute_encoder.set_vector_bytes(strides_b, arg_idx++);
|
||||
compute_encoder.set_bytes<int>(ndim, arg_idx++);
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
} else {
|
||||
// The shape is implicit in the grid for <= 3D
|
||||
compute_encoder->setBytes(
|
||||
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
|
||||
compute_encoder->setBytes(
|
||||
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
|
||||
compute_encoder.set_vector_bytes(strides_a, arg_idx++);
|
||||
compute_encoder.set_vector_bytes(strides_b, arg_idx++);
|
||||
}
|
||||
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D or 2D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <iostream> //TODO
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
@@ -11,10 +12,12 @@
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
using namespace fmt::literals;
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline void build_kernel(
|
||||
std::ostream& os,
|
||||
std::string& os,
|
||||
const std::string& kernel_name,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
@@ -23,7 +26,8 @@ inline void build_kernel(
|
||||
bool contiguous,
|
||||
int ndim,
|
||||
bool dynamic_dims,
|
||||
bool use_big_index = false) {
|
||||
bool use_big_index = false,
|
||||
int work_per_thread = 1) {
|
||||
// All outputs should have the exact same shape and will be row contiguous
|
||||
auto output_shape = outputs[0].shape();
|
||||
auto output_strides = outputs[0].strides();
|
||||
@@ -38,8 +42,8 @@ inline void build_kernel(
|
||||
int cnt = 0;
|
||||
|
||||
// Start the kernel
|
||||
os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl
|
||||
<< "[[kernel]] void " << kernel_name << "(" << std::endl;
|
||||
os += fmt::format(
|
||||
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
|
||||
|
||||
// Add the input arguments
|
||||
for (auto& x : inputs) {
|
||||
@@ -51,135 +55,203 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Scalars and contiguous need no strides
|
||||
if (is_scalar(x) || contiguous) {
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
|
||||
} else {
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
add_indices = true;
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
|
||||
}
|
||||
os += fmt::format(
|
||||
" device const {0}* {1} [[buffer({2})]],\n",
|
||||
get_type_string(x.dtype()),
|
||||
xname,
|
||||
cnt++);
|
||||
}
|
||||
|
||||
if (add_indices) {
|
||||
os << " constant const size_t* in_strides [[buffer(" << cnt++
|
||||
<< ")]],\n";
|
||||
os += fmt::format(
|
||||
" constant const size_t* in_strides [[buffer({0})]],\n", cnt++);
|
||||
}
|
||||
|
||||
// Add the output arguments
|
||||
for (auto& x : outputs) {
|
||||
os << " device " << get_type_string(x.dtype()) << "* "
|
||||
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl;
|
||||
os += fmt::format(
|
||||
" device {0}* {1} [[buffer({2})]],\n",
|
||||
get_type_string(x.dtype()),
|
||||
namer.get_name(x),
|
||||
cnt++);
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os << " constant const size_t* output_strides [[buffer(" << cnt++
|
||||
<< ")]]," << std::endl
|
||||
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],"
|
||||
<< std::endl;
|
||||
os += fmt::format(
|
||||
" constant const size_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||
os += fmt::format(
|
||||
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
||||
}
|
||||
if (dynamic_dims) {
|
||||
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],"
|
||||
<< std::endl;
|
||||
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
|
||||
}
|
||||
|
||||
// The thread index in the whole grid
|
||||
os << " uint3 pos [[thread_position_in_grid]]," << std::endl
|
||||
<< " uint3 grid [[threads_per_grid]]) {" << std::endl;
|
||||
if (use_big_index) {
|
||||
os += " uint3 pos [[thread_position_in_grid]],\n";
|
||||
os += " uint3 grid [[threads_per_grid]]) {\n";
|
||||
|
||||
std::string idx_type = use_big_index ? "size_t" : "uint";
|
||||
if (contiguous && use_big_index) {
|
||||
// This is only used for contiguous kernels which don't have
|
||||
// a third grid dimension
|
||||
os << " size_t index = pos.x + grid.x * size_t(pos.y);";
|
||||
os += " size_t index = pos.x + grid.x * size_t(pos.y);\n";
|
||||
} else if (work_per_thread > 1) {
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
os += fmt::format(
|
||||
" int xshape = output_shape[{0}];\n",
|
||||
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
|
||||
os += fmt::format(
|
||||
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||
idx_type);
|
||||
} else {
|
||||
os << " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);";
|
||||
}
|
||||
os << std::endl;
|
||||
|
||||
// Extract the indices per axis to individual uints if we have arrays that
|
||||
// are broadcasted or transposed
|
||||
if (add_indices) {
|
||||
if (!dynamic_dims) {
|
||||
if (ndim == 1) {
|
||||
os << " uint index_0 = pos.x;" << std::endl;
|
||||
} else if (ndim == 2) {
|
||||
os << " uint index_0 = pos.y;" << std::endl
|
||||
<< " uint index_1 = pos.x;" << std::endl;
|
||||
} else if (ndim == 3) {
|
||||
os << " uint index_0 = pos.z;" << std::endl
|
||||
<< " uint index_1 = pos.y;" << std::endl
|
||||
<< " uint index_2 = pos.x;" << std::endl;
|
||||
} else {
|
||||
for (int i = 0; i < ndim - 2; i++) {
|
||||
os << " uint index_" << i << " = (index / uint(output_strides[" << i
|
||||
<< "])) % output_shape[" << i << "];" << std::endl;
|
||||
}
|
||||
os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl
|
||||
<< " uint index_" << ndim - 1 << " = pos.x;" << std::endl;
|
||||
}
|
||||
}
|
||||
os += fmt::format(
|
||||
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||
idx_type);
|
||||
}
|
||||
|
||||
// Read the inputs in tmps
|
||||
int nc_in_count = 0;
|
||||
// Read constant / contiguous inputs in tmps
|
||||
std::vector<array> nc_inputs;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(x)) {
|
||||
auto type_str = get_type_string(x.dtype());
|
||||
os << " auto tmp_" << xname << " = static_cast<"
|
||||
<< get_type_string(x.dtype()) << ">(";
|
||||
print_constant(os, x);
|
||||
os << ");" << std::endl;
|
||||
std::ostringstream ss;
|
||||
print_constant(ss, x);
|
||||
os += fmt::format(
|
||||
" auto tmp_{0} = static_cast<{1}>({2});\n",
|
||||
xname,
|
||||
get_type_string(x.dtype()),
|
||||
ss.str());
|
||||
} else if (is_scalar(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[0];" << std::endl;
|
||||
os += fmt::format(
|
||||
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
|
||||
} else if (contiguous) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[index];" << std::endl;
|
||||
} else if (!dynamic_dims) {
|
||||
int offset = nc_in_count * ndim;
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[";
|
||||
os << "index_0 * " << "in_strides[" << offset << "]";
|
||||
for (int i = 1; i < ndim; i++) {
|
||||
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
|
||||
}
|
||||
os << "];" << std::endl;
|
||||
nc_in_count++;
|
||||
os += fmt::format(
|
||||
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
|
||||
} else {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
|
||||
<< nc_in_count * ndim << ", ndim)];" << std::endl;
|
||||
nc_in_count++;
|
||||
nc_inputs.push_back(x);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the indices for non-contiguous inputs
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
|
||||
if (ndim == 1) {
|
||||
int offset = i * ndim;
|
||||
os += fmt::format(
|
||||
"elem_to_loc_1<size_t, uint>(pos.x, in_strides[{0}]);\n", offset);
|
||||
} else if (ndim == 2) {
|
||||
int offset = i * ndim;
|
||||
os += fmt::format(
|
||||
"elem_to_loc_2<size_t, {0}>({{pos.x, pos.y}}, in_strides + {1});\n",
|
||||
idx_type,
|
||||
offset);
|
||||
} else if (ndim == 3) {
|
||||
int offset = i * ndim;
|
||||
os += fmt::format(
|
||||
"elem_to_loc_3<size_t, {0}>(pos, in_strides + {1});\n",
|
||||
idx_type,
|
||||
offset);
|
||||
} else if (!dynamic_dims) {
|
||||
int offset = (i + 1) * ndim;
|
||||
os += fmt::format(
|
||||
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
|
||||
idx_type,
|
||||
offset - 1,
|
||||
offset - 2);
|
||||
} else {
|
||||
os += fmt::format(
|
||||
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
|
||||
idx_type,
|
||||
i);
|
||||
}
|
||||
}
|
||||
|
||||
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
|
||||
os += " uint zpos = pos.z;\n";
|
||||
if (dynamic_dims) {
|
||||
os += " for (int d = ndim - 3; d >= 0; --d) {\n";
|
||||
} else {
|
||||
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
|
||||
}
|
||||
os += " uint l = zpos % output_shape[d];\n";
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
os += fmt::format(" index_{0} += ", xname);
|
||||
if (dynamic_dims) {
|
||||
os +=
|
||||
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
|
||||
} else {
|
||||
os +=
|
||||
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
|
||||
}
|
||||
}
|
||||
os += " zpos /= output_shape[d];\n }\n";
|
||||
}
|
||||
|
||||
// Open per-thread loop
|
||||
if (work_per_thread > 1) {
|
||||
os +=
|
||||
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||
}
|
||||
|
||||
// Read non-contiguous inputs into tmps
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
os += fmt::format(
|
||||
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
|
||||
}
|
||||
|
||||
// Actually write the computation
|
||||
for (auto& x : tape) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
|
||||
<< " = ";
|
||||
os += fmt::format(
|
||||
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
|
||||
if (is_static_cast(x.primitive())) {
|
||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
||||
os += fmt::format(
|
||||
"static_cast<{0}>(tmp_{1});\n",
|
||||
get_type_string(x.dtype()),
|
||||
namer.get_name(x.inputs()[0]));
|
||||
} else {
|
||||
x.primitive().print(os);
|
||||
os << "()(";
|
||||
std::ostringstream ss;
|
||||
x.primitive().print(ss);
|
||||
os += ss.str();
|
||||
os += "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
|
||||
}
|
||||
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
|
||||
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
|
||||
}
|
||||
}
|
||||
|
||||
// Write the outputs from tmps
|
||||
for (auto& x : outputs) {
|
||||
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
|
||||
<< ";" << std::endl;
|
||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||
}
|
||||
// Increment indices and close per thread loop
|
||||
if (work_per_thread > 1) {
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
if (!dynamic_dims) {
|
||||
os += fmt::format(
|
||||
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
|
||||
} else {
|
||||
os += fmt::format(
|
||||
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
|
||||
}
|
||||
}
|
||||
os += " index++;\n }\n";
|
||||
}
|
||||
|
||||
// Finish the kernel
|
||||
os << "}" << std::endl;
|
||||
os += "}\n";
|
||||
|
||||
if (cnt > 31) {
|
||||
std::ostringstream msg;
|
||||
@@ -202,13 +274,10 @@ void Compiled::eval_gpu(
|
||||
// Get the kernel if someone else built it already
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto lib = d.get_library(kernel_lib_);
|
||||
|
||||
// If not we have to build it ourselves
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel;
|
||||
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
|
||||
<< metal::ternary_ops();
|
||||
auto lib = d.get_library(kernel_lib_, [&]() {
|
||||
std::string kernel = metal::utils();
|
||||
concatenate(
|
||||
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous",
|
||||
@@ -221,7 +290,7 @@ void Compiled::eval_gpu(
|
||||
/* dynamic_dims = */ false);
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous_big",
|
||||
kernel_lib_ + "_contiguous_large",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
@@ -240,7 +309,23 @@ void Compiled::eval_gpu(
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ i,
|
||||
/* dynamic_dims = */ false);
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ i > 3 ? 2 : 1);
|
||||
if (i > 1) {
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_strided_" + std::to_string(i) + "_large",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ i,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ true,
|
||||
/* work_per_thread = */ i > 3 ? 4 : 1);
|
||||
}
|
||||
}
|
||||
build_kernel(
|
||||
kernel,
|
||||
@@ -251,14 +336,27 @@ void Compiled::eval_gpu(
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true);
|
||||
|
||||
lib = d.get_library(kernel_lib_, kernel.str());
|
||||
}
|
||||
/* dynamic_dims = */ true,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ 2);
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_strided_dynamic_large",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true,
|
||||
/* use_big_index = */ true,
|
||||
/* work_per_thread = */ 4);
|
||||
return kernel;
|
||||
});
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& output_shape = outputs[0].shape();
|
||||
bool contiguous = compiled_check_contiguity(inputs, output_shape);
|
||||
auto contiguous = compiled_check_contiguity(inputs, output_shape);
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
@@ -306,13 +404,19 @@ void Compiled::eval_gpu(
|
||||
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
|
||||
}
|
||||
|
||||
bool use_2d = false;
|
||||
bool large;
|
||||
if (contiguous) {
|
||||
size_t max_size = 0;
|
||||
for (auto& in : inputs) {
|
||||
max_size = std::max(max_size, in.data_size());
|
||||
}
|
||||
use_2d = (max_size > UINT32_MAX);
|
||||
large = (max_size > UINT32_MAX);
|
||||
} else {
|
||||
size_t max_size = 0;
|
||||
for (auto& o : outputs) {
|
||||
max_size = std::max(max_size, o.size());
|
||||
}
|
||||
large = (max_size > UINT32_MAX);
|
||||
}
|
||||
|
||||
// Get the kernel from the lib
|
||||
@@ -325,12 +429,13 @@ void Compiled::eval_gpu(
|
||||
} else {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
}
|
||||
} else if (use_2d) {
|
||||
kernel_name += "_big";
|
||||
}
|
||||
if (large) {
|
||||
kernel_name += "_large";
|
||||
}
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Put the inputs in
|
||||
int cnt = 0;
|
||||
@@ -351,8 +456,7 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
}
|
||||
if (!in_strides.empty()) {
|
||||
compute_encoder->setBytes(
|
||||
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
|
||||
compute_encoder.set_vector_bytes(in_strides, cnt++);
|
||||
}
|
||||
|
||||
compiled_allocate_outputs(
|
||||
@@ -365,36 +469,43 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Put the output shape and strides in
|
||||
if (!contiguous) {
|
||||
compute_encoder->setBytes(
|
||||
strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
|
||||
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
|
||||
compute_encoder.set_vector_bytes(strides[0], cnt++);
|
||||
compute_encoder.set_vector_bytes(shape, cnt++);
|
||||
}
|
||||
|
||||
// Put the number of dims in if it is dynamic
|
||||
if (dynamic) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
|
||||
compute_encoder.set_bytes(ndim, cnt++);
|
||||
}
|
||||
|
||||
// Launch the kernel
|
||||
if (contiguous) {
|
||||
size_t nthreads = outputs[0].data_size();
|
||||
MTL::Size grid_dims = use_2d
|
||||
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size group_dims(
|
||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
MTL::Size grid_dims = large
|
||||
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = outputs[0].size() / (dim0 * dim1);
|
||||
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
int pow2;
|
||||
if (thread_group_size == 1024) {
|
||||
pow2 = 10;
|
||||
} else if (thread_group_size > 512) {
|
||||
pow2 = 9;
|
||||
} else {
|
||||
throw std::runtime_error("[Metal::compiled] Must use > 512 sized block");
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -44,23 +44,24 @@ void explicit_gemm_conv_ND_gpu(
|
||||
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(in_unfolded, 1);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
|
||||
compute_encoder.set_bytes(conv_params, 2);
|
||||
|
||||
// Launch unfolding kernel
|
||||
int tgp_x = std::min(conv_params.C, 64);
|
||||
size_t tgp_x = std::min(conv_params.C, 64);
|
||||
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
||||
int tgp_y = 256 / tgp_x;
|
||||
size_t tgp_y = 256 / tgp_x;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
MTL::Size group_dims = MTL::Size(
|
||||
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
|
||||
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
// Reshape weight
|
||||
std::vector<int> wt_reshape{implicit_K, implicit_N};
|
||||
@@ -72,7 +73,7 @@ void explicit_gemm_conv_ND_gpu(
|
||||
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_unfolded, wt_reshaped};
|
||||
std::vector<array> copies = {in_unfolded};
|
||||
return steel_matmul(
|
||||
s,
|
||||
d,
|
||||
@@ -122,23 +123,24 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
<< N;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(in_unfolded, 1);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
|
||||
compute_encoder.set_bytes(conv_params, 2);
|
||||
|
||||
// Launch unfolding kernel
|
||||
int tgp_x = std::min(conv_params.C, 64);
|
||||
size_t tgp_x = std::min(conv_params.C, 64);
|
||||
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
||||
int tgp_y = 256 / tgp_x;
|
||||
size_t tgp_y = 256 / tgp_x;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
MTL::Size group_dims = MTL::Size(
|
||||
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
|
||||
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
// Transpose kernel weights so that we can slice them by contiguous chunks
|
||||
// of channel groups.
|
||||
@@ -155,22 +157,27 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
|
||||
return steel_matmul_conv_groups(
|
||||
std::vector<array> copies = {in_unfolded, wt_transpose};
|
||||
return steel_matmul_regular(
|
||||
s,
|
||||
d,
|
||||
/*a = */ in_unfolded,
|
||||
/*b = */ wt_transpose,
|
||||
/*c = */ out,
|
||||
/*M = */ implicit_M,
|
||||
/*N = */ implicit_N,
|
||||
/*K = */ implicit_K,
|
||||
/*a_cols = */ implicit_K * groups,
|
||||
/*b_cols = */ implicit_K,
|
||||
/*out_cols = */ implicit_N * groups,
|
||||
/*a_transposed = */ false,
|
||||
/*b_transposed = */ true,
|
||||
/* groups = */ groups,
|
||||
/* a = */ in_unfolded,
|
||||
/* b = */ wt_transpose,
|
||||
/* c = */ out,
|
||||
/* M = */ implicit_M,
|
||||
/* N = */ implicit_N,
|
||||
/* K = */ implicit_K,
|
||||
/* batch_size_out = */ groups,
|
||||
/* a_cols = */ implicit_K * groups,
|
||||
/* b_cols = */ implicit_K,
|
||||
/* out_cols = */ implicit_N * groups,
|
||||
/* a_transposed = */ false,
|
||||
/* b_transposed = */ true,
|
||||
/* batch_shape = */ {1},
|
||||
/* batch_strides = */ {0},
|
||||
/* A_batch_strides = */ size_t(implicit_K),
|
||||
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
|
||||
/* matrix_stride_out = */ size_t(implicit_N),
|
||||
/*copies = */ copies);
|
||||
}
|
||||
|
||||
@@ -232,7 +239,7 @@ void slow_conv_2D_gpu(
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
|
||||
|
||||
@@ -247,8 +254,8 @@ void slow_conv_2D_gpu(
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void implicit_gemm_conv_2D_gpu(
|
||||
@@ -347,7 +354,7 @@ void implicit_gemm_conv_2D_gpu(
|
||||
wn,
|
||||
n_channel_specialization,
|
||||
small_filter);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
int tile = 1 << swizzle_log;
|
||||
@@ -363,11 +370,11 @@ void implicit_gemm_conv_2D_gpu(
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode params
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
compute_encoder.set_bytes(gemm_params, 4);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void implicit_gemm_conv_2D_general_gpu(
|
||||
@@ -501,7 +508,7 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel =
|
||||
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
int tile = 1 << swizzle_log;
|
||||
@@ -518,17 +525,15 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode params
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
|
||||
compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
compute_encoder.set_bytes(gemm_params, 4);
|
||||
compute_encoder.set_bytes(jump_params, 5);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
|
||||
compute_encoder->setBytes(
|
||||
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
|
||||
compute_encoder.set_vector_bytes(base_h, 6);
|
||||
compute_encoder.set_vector_bytes(base_w, 7);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void winograd_conv_2D_gpu(
|
||||
@@ -617,18 +622,18 @@ void winograd_conv_2D_gpu(
|
||||
<< bc;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(wt, 0);
|
||||
compute_encoder.set_output_array(filt_wg, 1);
|
||||
|
||||
compute_encoder->setBytes(&C_c, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&O_c, sizeof(int), 3);
|
||||
compute_encoder.set_bytes(C_c, 2);
|
||||
compute_encoder.set_bytes(O_c, 3);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Do input transform
|
||||
@@ -645,18 +650,17 @@ void winograd_conv_2D_gpu(
|
||||
<< bc;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in_padded, 0);
|
||||
compute_encoder.set_output_array(inp_wg, 1);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
compute_encoder.set_bytes(conv_params_updated, 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Do batched gemm
|
||||
@@ -693,18 +697,17 @@ void winograd_conv_2D_gpu(
|
||||
<< bc;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(out_wg, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
compute_encoder.set_bytes(conv_params_updated, 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -747,10 +750,6 @@ void conv_2D_gpu(
|
||||
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
|
||||
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
|
||||
|
||||
bool inp_large = (conv_params.in_strides[0] >= 1ul << 18);
|
||||
bool channels_large = (conv_params.C + conv_params.O) >= 512;
|
||||
bool channels_med = (conv_params.C + conv_params.O) >= 256;
|
||||
|
||||
if (groups > 1) {
|
||||
const int C_per_group = conv_params.C / groups;
|
||||
const int O_per_group = conv_params.O / groups;
|
||||
@@ -764,10 +763,13 @@ void conv_2D_gpu(
|
||||
}
|
||||
|
||||
// Direct to winograd conv
|
||||
bool inp_large =
|
||||
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
|
||||
bool channels_large = (conv_params.C + conv_params.O) >= 256;
|
||||
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
|
||||
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
||||
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
|
||||
(channels_large || (channels_med && inp_large))) {
|
||||
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
|
||||
channels_large) {
|
||||
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
||||
}
|
||||
|
||||
@@ -913,12 +915,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
        "[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
  }

  // Clear copies
  if (copies.size() > 0) {
    auto command_buffer = d.get_command_buffer(s.index);
    command_buffer->addCompletedHandler(
        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
  }
  // Record copies
  d.add_temporaries(std::move(copies), s.index);
}

} // namespace mlx::core

@@ -74,44 +74,46 @@ void copy_gpu_inplace(
|
||||
};
|
||||
auto [shape, strides_in_, strides_out_] = maybe_collapse();
|
||||
int ndim = shape.size();
|
||||
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
bool large;
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
// Allow for negative strides
|
||||
large = out.data_size() > INT32_MAX;
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
}
|
||||
auto& d = metal::device(s.device);
|
||||
int work_per_thread = 1;
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
kname << (use_2d ? "s2" : "s");
|
||||
break;
|
||||
case CopyType::Vector:
|
||||
kname << (use_2d ? "v2" : "v");
|
||||
break;
|
||||
case CopyType::General:
|
||||
kname << "g";
|
||||
break;
|
||||
case CopyType::GeneralGeneral:
|
||||
kname << "gg";
|
||||
break;
|
||||
}
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
} else if (shape[ndim - 1] >= 4) {
|
||||
work_per_thread = 4;
|
||||
kname << "n4";
|
||||
}
|
||||
}
|
||||
kname << "_copy";
|
||||
kname << type_to_name(in) << type_to_name(out);
|
||||
kernel_name = kname.str();
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
kernel_name = (large ? "s2" : "s");
|
||||
break;
|
||||
case CopyType::Vector:
|
||||
kernel_name = (large ? "v2" : "v");
|
||||
break;
|
||||
case CopyType::General:
|
||||
kernel_name = "g";
|
||||
break;
|
||||
case CopyType::GeneralGeneral:
|
||||
kernel_name = "gg";
|
||||
break;
|
||||
}
|
||||
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
} else {
|
||||
work_per_thread = large ? 4 : 2;
|
||||
concatenate(kernel_name, "n", std::to_string(work_per_thread));
|
||||
}
|
||||
if (large) {
|
||||
kernel_name += "large";
|
||||
}
|
||||
}
|
||||
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
|
||||
auto kernel = get_copy_kernel(d, kernel_name, in, out);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
bool donate_in = in.data_shared_ptr() == nullptr;
|
||||
|
||||
inp_offset *= size_of(in.dtype());
|
||||
@@ -120,15 +122,16 @@ void copy_gpu_inplace(
|
||||
compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
|
||||
compute_encoder.set_output_array(out, 1, out_offset);
|
||||
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
|
||||
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
|
||||
if (ndim > 3) {
|
||||
set_vector_bytes(compute_encoder, shape, ndim, 2);
|
||||
compute_encoder.set_vector_bytes(shape, ndim, 2);
|
||||
}
|
||||
set_vector_bytes(compute_encoder, strides_in, ndim, 3);
|
||||
compute_encoder.set_vector_bytes(strides_in, ndim, 3);
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
|
||||
compute_encoder.set_vector_bytes(strides_out, ndim, 4);
|
||||
}
|
||||
|
||||
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
@@ -140,29 +143,27 @@ void copy_gpu_inplace(
|
||||
int rest = data_size / (dim0 * dim1);
|
||||
|
||||
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 5);
|
||||
compute_encoder.set_bytes(ndim, 5);
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
}
|
||||
|
||||
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
|
||||
}
|
||||
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,26 +195,26 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
bool large = out.data_size() > UINT32_MAX;
|
||||
auto& d = metal::device(s.device);
|
||||
std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" +
|
||||
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
|
||||
type_to_name(val) + type_to_name(out);
|
||||
auto kernel = get_copy_kernel(d, kernel_name, val, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(val, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -32,20 +32,18 @@ void CustomKernel::eval_gpu(
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
std::vector<const array> checked_inputs;
|
||||
std::vector<array> checked_inputs;
|
||||
for (const array& in : inputs) {
|
||||
checked_inputs.push_back(check_input(in));
|
||||
}
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = name_;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
lib = d.get_library(lib_name, metal::utils() + source_);
|
||||
}
|
||||
auto lib =
|
||||
d.get_library(lib_name, [this] { return metal::utils() + source_; });
|
||||
auto kernel = d.get_kernel(name_, lib);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
int index = 0;
|
||||
for (int i = 0; i < checked_inputs.size(); i++) {
|
||||
const array& in = checked_inputs[i];
|
||||
@@ -55,15 +53,15 @@ void CustomKernel::eval_gpu(
|
||||
if (in.ndim() > 0) {
|
||||
int ndim = in.ndim();
|
||||
if (shape_info.shape) {
|
||||
set_vector_bytes(compute_encoder, in.shape(), ndim, index);
|
||||
compute_encoder.set_vector_bytes(in.shape(), ndim, index);
|
||||
index++;
|
||||
}
|
||||
if (shape_info.strides) {
|
||||
set_vector_bytes(compute_encoder, in.strides(), ndim, index);
|
||||
compute_encoder.set_vector_bytes(in.strides(), ndim, index);
|
||||
index++;
|
||||
}
|
||||
if (shape_info.ndim) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), index);
|
||||
compute_encoder.set_bytes(ndim, index);
|
||||
index++;
|
||||
}
|
||||
}
|
||||
@@ -74,15 +72,13 @@ void CustomKernel::eval_gpu(
|
||||
}
|
||||
|
||||
const auto [tx, ty, tz] = threadgroup_;
|
||||
MTL::Size group_dims = MTL::Size(tx, ty, tz);
|
||||
const auto [gx, gy, gz] = grid_;
|
||||
MTL::Size group_dims =
|
||||
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
|
||||
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
@@ -20,18 +20,21 @@ namespace {

// TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;

constexpr const char* default_mtllib_path = METAL_PATH;

constexpr auto get_metal_version() {
#if (MLX_METAL_VERSION >= 320)
  return MTL::LanguageVersion3_2;
#elif (MLX_METAL_VERSION >= 310)
  return MTL::LanguageVersion3_1;
#else
  return MTL::LanguageVersion3_0;
#endif
auto get_metal_version() {
  auto get_metal_version_ = []() {
    if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
      return MTL::LanguageVersion3_2;
    } else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) {
      return MTL::LanguageVersion3_1;
    } else {
      return MTL::LanguageVersion3_0;
    }
  };
  static auto metal_version_ = get_metal_version_();
  return metal_version_;
}

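// Sketch only (not from this diff): the runtime-detected language version would
// typically feed metal-cpp compile options when compiling a library from source.
inline MTL::CompileOptions* compile_options_example() {
  auto options = MTL::CompileOptions::alloc()->init();
  options->setLanguageVersion(get_metal_version());
  return options; // caller releases; error handling omitted in this sketch
}
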
auto load_device() {
|
||||
@@ -121,33 +124,29 @@ MTL::Library* load_library(
|
||||
|
||||
} // namespace
|
||||
|
||||
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
|
||||
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
enc->retain();
|
||||
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
|
||||
enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
enc_->retain();
|
||||
}
|
||||
|
||||
CommandEncoder::~CommandEncoder() {
|
||||
enc->endEncoding();
|
||||
enc->release();
|
||||
enc_->endEncoding();
|
||||
enc_->release();
|
||||
}
|
||||
|
||||
void CommandEncoder::set_input_array(
|
||||
const array& a,
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
all_inputs_.insert(a.buffer().ptr());
|
||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (auto it = outputs.find(r_buf); it != outputs.end()) {
|
||||
// Insert a barrier
|
||||
enc->memoryBarrier(&r_buf, 1);
|
||||
|
||||
// Remove the output
|
||||
outputs.erase(it);
|
||||
}
|
||||
needs_barrier_ =
|
||||
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc->setBuffer(a_buf, base_offset, idx);
|
||||
enc_->setBuffer(a_buf, base_offset, idx);
|
||||
}
|
||||
|
||||
void CommandEncoder::set_output_array(
|
||||
@@ -156,55 +155,49 @@ void CommandEncoder::set_output_array(
|
||||
int64_t offset /* = 0 */) {
|
||||
// Add barriers before adding the output to the output set
|
||||
set_input_array(a, idx, offset);
|
||||
all_outputs_.insert(a.buffer().ptr());
|
||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||
if (concurrent) {
|
||||
concurrent_outputs.insert(buf);
|
||||
if (concurrent_) {
|
||||
concurrent_outputs_.insert(buf);
|
||||
} else {
|
||||
outputs.insert(buf);
|
||||
next_outputs_.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatchThreadgroups(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
num_dispatches++;
|
||||
enc->dispatchThreadgroups(grid_dims, group_dims);
|
||||
maybe_split();
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatchThreads(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
num_dispatches++;
|
||||
enc->dispatchThreads(grid_dims, group_dims);
|
||||
maybe_split();
|
||||
}
|
||||
|
||||
void CommandEncoder::maybe_split() {
|
||||
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
|
||||
enc->endEncoding();
|
||||
enc->release();
|
||||
num_dispatches = 0;
|
||||
outputs.clear();
|
||||
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
enc->retain();
|
||||
void CommandEncoder::maybeInsertBarrier() {
|
||||
if (needs_barrier_) {
|
||||
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
|
||||
needs_barrier_ = false;
|
||||
prev_outputs_ = std::move(next_outputs_);
|
||||
} else {
|
||||
prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end());
|
||||
}
|
||||
next_outputs_.clear();
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatch_threadgroups(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
maybeInsertBarrier();
|
||||
enc_->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatch_threads(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
maybeInsertBarrier();
|
||||
enc_->dispatchThreads(grid_dims, group_dims);
|
||||
}
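
The encoder above defers synchronization: `set_input_array` only flips `needs_barrier_` when an input buffer appears in `prev_outputs_`, and `maybeInsertBarrier` emits a single buffer-scope barrier immediately before the next dispatch. Below is a minimal standalone sketch of that bookkeeping, with the Metal calls stubbed out; `SketchEncoder` is illustrative only, not MLX code.

```cpp
// Illustrative sketch of the deferred-barrier bookkeeping (not MLX code).
#include <cstdio>
#include <unordered_set>

struct SketchEncoder {
  std::unordered_set<const void*> prev_outputs;  // written by earlier dispatches
  std::unordered_set<const void*> next_outputs;  // written by the pending dispatch
  bool needs_barrier = false;

  void set_input(const void* buf) {
    // A barrier is only needed if this input was produced earlier in the encoder.
    needs_barrier |= (prev_outputs.count(buf) > 0);
  }
  void set_output(const void* buf) {
    next_outputs.insert(buf);
  }
  void dispatch() {
    if (needs_barrier) {
      std::puts("memoryBarrier(BarrierScopeBuffers)");  // enc_->memoryBarrier(...) in the real code
      needs_barrier = false;
      prev_outputs = std::move(next_outputs);
    } else {
      prev_outputs.insert(next_outputs.begin(), next_outputs.end());
    }
    next_outputs.clear();
    std::puts("dispatchThreads(...)");
  }
};

int main() {
  int a, b, c;
  SketchEncoder enc;
  enc.set_input(&a); enc.set_output(&b); enc.dispatch();  // no barrier
  enc.set_input(&b); enc.set_output(&c); enc.dispatch();  // barrier: b was just written
}
```

In the example, the second dispatch reads `b`, which the first dispatch wrote, so only that dispatch is preceded by a barrier.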
|
||||
|
||||
Device::Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
device_ = load_device();
|
||||
library_map_ = {{"mlx", load_library(device_)}};
|
||||
arch_ = std::string(device_->architecture()->name()->utf8String());
|
||||
}
|
||||
|
||||
Device::~Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& q : queue_map_) {
|
||||
q.second->release();
|
||||
}
|
||||
for (auto& b : buffer_map_) {
|
||||
b.second.second->release();
|
||||
}
|
||||
for (auto& k : kernel_map_) {
|
||||
k.second->release();
|
||||
}
|
||||
@@ -219,69 +212,134 @@ void Device::new_queue(int index) {
|
||||
|
||||
// Multiple threads can ask the device for queues
|
||||
// We lock this as a critical section for safety
|
||||
const std::lock_guard<std::mutex> lock(mtx_);
|
||||
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
||||
debug_set_stream_queue_label(q, index);
|
||||
if (!q) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Failed to make new command queue.");
|
||||
}
|
||||
queue_map_.insert({index, q});
|
||||
stream_map_.emplace(index, q);
|
||||
if (residency_set_ != nullptr) {
|
||||
q->addResidencySet(residency_set_);
|
||||
}
|
||||
}
|
||||
|
||||
int Device::get_command_buffer_ops(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
return bit->second.first;
|
||||
return get_stream_(index).buffer_ops;
|
||||
}
|
||||
|
||||
void Device::increment_command_buffer_ops(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
bit->second.first++;
|
||||
get_stream_(index).buffer_ops++;
|
||||
}
|
||||
|
||||
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
if (bit == buffer_map_.end()) {
|
||||
auto qit = queue_map_.find(index);
|
||||
if (qit == queue_map_.end()) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Attempting to get command buffer for invalid queue.");
|
||||
}
|
||||
|
||||
auto cb = qit->second->commandBufferWithUnretainedReferences();
|
||||
|
||||
if (!cb) {
|
||||
auto& stream = get_stream_(index);
|
||||
if (stream.buffer == nullptr) {
|
||||
stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
|
||||
if (!stream.buffer) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Unable to create new command buffer");
|
||||
}
|
||||
|
||||
// Increment ref count so the buffer is not garbage collected
|
||||
cb->retain();
|
||||
|
||||
bit = buffer_map_.insert({index, {0, cb}}).first;
|
||||
stream.buffer->retain();
|
||||
}
|
||||
return bit->second.second;
|
||||
return stream.buffer;
|
||||
}
|
||||
|
||||
void Device::commit_command_buffer(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
bit->second.second->commit();
|
||||
bit->second.second->release();
|
||||
buffer_map_.erase(bit);
|
||||
auto& stream = get_stream_(index);
|
||||
stream.buffer->commit();
|
||||
stream.buffer->release();
|
||||
stream.buffer = nullptr;
|
||||
stream.buffer_ops = 0;
|
||||
}
|
||||
|
||||
void Device::add_temporary(array arr, int index) {
|
||||
get_stream_(index).temporaries.push_back(std::move(arr));
|
||||
}
|
||||
|
||||
void Device::add_temporaries(std::vector<array> arrays, int index) {
|
||||
if (arrays.empty()) {
|
||||
return;
|
||||
}
|
||||
auto& stream = get_stream_(index);
|
||||
stream.temporaries.insert(
|
||||
stream.temporaries.end(),
|
||||
std::make_move_iterator(arrays.begin()),
|
||||
std::make_move_iterator(arrays.end()));
|
||||
}
|
||||
|
||||
void Device::end_encoding(int index) {
|
||||
encoder_map_.erase(index);
|
||||
auto& stream = get_stream_(index);
|
||||
if (stream.encoder != nullptr) {
|
||||
// Each command encoder has a unique fence. We also store a map of
|
||||
// all previous outputs of command encoders to their corresponding fence.
|
||||
// - The command encoder records its inputs and outputs.
|
||||
// - Wait on a fence if any inputs in the encoder are outputs of a previous
|
||||
// encoder.
|
||||
// - Update the map of outputs to include this command encoder's outputs.
|
||||
    // - Always signal this command encoder's fence.
|
||||
// - Add a completion handler for this command encoder that removes outputs
|
||||
    // from the map to limit the growth of the map and avoid unnecessary waits
|
||||
// - Temporaries are a special case as they do not cross command encoder
|
||||
    // boundaries. These can be removed early from the encoder's inputs and
|
||||
// outputs since they don't need synchronization.
|
||||
auto& enc = *stream.encoder;
|
||||
// Remove temporaries from inputs and outputs
|
||||
for (auto& t : stream.temporaries) {
|
||||
if (t.data<void>() != nullptr) {
|
||||
enc.outputs().erase(t.buffer().ptr());
|
||||
enc.inputs().erase(t.buffer().ptr());
|
||||
}
|
||||
}
|
||||
|
||||
// Keep references to the fences we waited on and put them
|
||||
// in the completion handler so they are not prematurely released
|
||||
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(stream.fence_mtx);
|
||||
for (auto in : enc.inputs()) {
|
||||
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
|
||||
// If we've already waited on a fence, don't wait on it again.
|
||||
if (waiting_on.find(it->second) == waiting_on.end()) {
|
||||
enc.wait_for_fence(it->second->fence);
|
||||
waiting_on.insert(it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto out : enc.outputs()) {
|
||||
stream.outputs[out] = stream.fence;
|
||||
}
|
||||
}
|
||||
enc.update_fence(stream.fence->fence);
|
||||
stream.buffer->addCompletedHandler(
|
||||
[&stream,
|
||||
waiting_on = std::move(waiting_on),
|
||||
fence = std::move(stream.fence),
|
||||
outputs = std::move(enc.outputs()),
|
||||
temporaries =
|
||||
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
|
||||
temporaries.clear();
|
||||
std::lock_guard<std::mutex> lk(stream.fence_mtx);
|
||||
for (auto o : outputs) {
|
||||
if (auto it = stream.outputs.find(o); it != stream.outputs.end()) {
|
||||
if (it->second == fence) {
|
||||
stream.outputs.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
stream.encoder = nullptr;
|
||||
}
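
The comment block in `end_encoding` describes the cross-encoder scheme: wait on the fence of any prior encoder whose outputs this encoder reads, publish this encoder's fence for every output it writes, and prune the map when the command buffer completes. A hedged sketch of just the map bookkeeping follows, with hypothetical `Sketch*` types and the Metal calls reduced to comments.

```cpp
// Sketch of the per-encoder fence bookkeeping (hypothetical, not MLX code).
#include <memory>
#include <unordered_map>
#include <unordered_set>

struct SketchFence { int id; };

struct SketchStream {
  // output buffer -> fence signaled by the encoder that last wrote it
  std::unordered_map<const void*, std::shared_ptr<SketchFence>> outputs;
};

// Called when an encoder is finished: wait on producers, then publish our fence.
void sketch_end_encoding(SketchStream& stream,
                         const std::unordered_set<const void*>& inputs,
                         const std::unordered_set<const void*>& enc_outputs,
                         std::shared_ptr<SketchFence> fence) {
  std::unordered_set<std::shared_ptr<SketchFence>> waiting_on;
  for (auto in : inputs) {
    if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
      if (waiting_on.insert(it->second).second) {
        // enc.wait_for_fence(it->second->fence);  // wait once per producer fence
      }
    }
  }
  for (auto out : enc_outputs) {
    stream.outputs[out] = fence;  // later readers will wait on `fence`
  }
  // enc.update_fence(fence->fence);
  // A command-buffer completion handler later erases the `enc_outputs` entries
  // that still map to `fence`, keeping the map from growing without bound.
}
```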
|
||||
|
||||
CommandEncoder& Device::get_command_encoder(int index) {
|
||||
auto eit = encoder_map_.find(index);
|
||||
if (eit == encoder_map_.end()) {
|
||||
auto cb = get_command_buffer(index);
|
||||
eit =
|
||||
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
|
||||
auto& stream = get_stream_(index);
|
||||
if (stream.encoder == nullptr) {
|
||||
stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
|
||||
stream.fence = std::make_shared<Fence>(device_->newFence());
|
||||
}
|
||||
return *(eit->second);
|
||||
return *stream.encoder;
|
||||
}
|
||||
|
||||
void Device::register_library(
|
||||
@@ -293,20 +351,7 @@ void Device::register_library(
|
||||
}
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib;
|
||||
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
|
||||
mtl_lib = it->second;
|
||||
} else { // Look for metallib alongside library
|
||||
register_library(lib_name, get_colocated_mtllib_path(lib_name));
|
||||
mtl_lib = library_map_[lib_name];
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library_(const std::string& source_string) {
|
||||
MTL::Library* Device::build_library_(const std::string& source_string) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
auto ns_code =
|
||||
@@ -322,26 +367,7 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to build metal library from source" << "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
NS::Error* error = nullptr;
|
||||
auto mtl_lib = device_->newLibrary(desc, &error);
|
||||
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to build stitched metal library" << "\n";
|
||||
msg << "[metal::Device] Unable to build metal library from source\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
@@ -465,68 +491,32 @@ MTL::ComputePipelineState* Device::get_kernel_(
|
||||
return kernel;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library(const std::string& name) {
|
||||
MTL::Library* Device::get_library_(const std::string& name) {
|
||||
std::shared_lock lock(library_mtx_);
|
||||
auto it = library_map_.find(name);
|
||||
return (it != library_map_.end()) ? it->second : nullptr;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library(
|
||||
const std::string& name,
|
||||
const std::string& source,
|
||||
bool cache /* = true */) {
|
||||
if (cache) {
|
||||
const std::function<std::string(void)>& builder) {
|
||||
{
|
||||
std::shared_lock rlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
auto mtl_lib = get_library_(source);
|
||||
|
||||
if (cache) {
|
||||
library_map_.insert({name, mtl_lib});
|
||||
std::unique_lock wlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto mtl_lib = build_library_(builder());
|
||||
library_map_.insert({name, mtl_lib});
|
||||
return mtl_lib;
|
||||
}
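
`get_library` now takes a builder callback and uses a double-checked pattern over `library_mtx_`: a shared lock for the common cache hit, then a unique lock that re-checks before invoking the builder, so kernel source is generated and compiled at most once even with concurrent callers. The same pattern in a generic, standalone form (the `BuildCache` name is hypothetical, not MLX code):

```cpp
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

// Generic "build once, cache forever" helper mirroring the new get_library.
template <typename Value>
class BuildCache {
 public:
  Value get(const std::string& key, const std::function<Value()>& builder) {
    {
      std::shared_lock rlock(mtx_);  // many concurrent readers
      if (auto it = map_.find(key); it != map_.end()) {
        return it->second;
      }
    }
    std::unique_lock wlock(mtx_);    // single writer
    if (auto it = map_.find(key); it != map_.end()) {
      return it->second;             // someone else built it first
    }
    auto value = builder();          // expensive: e.g. JIT-compile kernel source
    map_.emplace(key, value);
    return value;
  }

 private:
  std::shared_mutex mtx_;
  std::unordered_map<std::string, Value> map_;
};
```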
|
||||
|
||||
MTL::Library* Device::get_library(
|
||||
const std::string& name,
|
||||
const MTL::StitchedLibraryDescriptor* desc,
|
||||
bool cache /* = true */) {
|
||||
if (cache) {
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
auto mtl_lib = get_library_(desc);
|
||||
|
||||
if (cache) {
|
||||
library_map_.insert({name, mtl_lib});
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Function* Device::get_function(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& specialized_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */) {
|
||||
return get_function_(base_name, specialized_name, func_consts, mtl_lib);
|
||||
}
|
||||
|
||||
MTL::Function* Device::get_function(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name /* = "mlx" */,
|
||||
const std::string& specialized_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */) {
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib = get_library_cache_(lib_name);
|
||||
|
||||
return get_function(base_name, mtl_lib, specialized_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::LinkedFunctions* Device::get_linked_functions_(
|
||||
const std::vector<MTL::Function*>& funcs) {
|
||||
if (funcs.empty()) {
|
||||
@@ -547,34 +537,55 @@ MTL::LinkedFunctions* Device::get_linked_functions_(
|
||||
return lfuncs;
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel_(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& hash_name,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
// Single writer allowed
|
||||
std::unique_lock wlock(kernel_mtx_);
|
||||
|
||||
// Try loading again to avoid loading twice
|
||||
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
// Pull kernel from library
|
||||
auto mtl_function = get_function_(base_name, hash_name, func_consts, mtl_lib);
|
||||
|
||||
// Compile kernel to compute pipeline
|
||||
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
|
||||
auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs);
|
||||
|
||||
mtl_function->release();
|
||||
mtl_linked_funcs->release();
|
||||
|
||||
// Add kernel to cache
|
||||
auto inserted = kernel_map_.insert({hash_name, kernel});
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& hash_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
// Look for cached kernel
|
||||
const auto& kname = hash_name.empty() ? base_name : hash_name;
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
{
|
||||
// Multiple readers allowed
|
||||
std::shared_lock lock(kernel_mtx_);
|
||||
|
||||
// Look for cached kernel
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
// Pull kernel from library
|
||||
auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
|
||||
|
||||
// Compile kernel to compute pipeline
|
||||
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
|
||||
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
|
||||
|
||||
mtl_function->release();
|
||||
mtl_linked_funcs->release();
|
||||
|
||||
// Add kernel to cache
|
||||
kernel_map_.insert({kname, kernel});
|
||||
|
||||
return kernel;
|
||||
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
@@ -583,16 +594,34 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& hash_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
// Look for cached kernel
|
||||
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
{
|
||||
// Multiple readers allowed
|
||||
std::shared_lock lock(kernel_mtx_);
|
||||
|
||||
// Look for cached kernel
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib = get_library_cache_(lib_name);
|
||||
MTL::Library* mtl_lib = get_library_(lib_name);
|
||||
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||
}
|
||||
|
||||
return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
|
||||
if (residency_set_ != nullptr) {
|
||||
throw std::runtime_error(
|
||||
"[Device::set_residency_set] Can only be set once.");
|
||||
}
|
||||
if (residency_set == nullptr) {
|
||||
return;
|
||||
}
|
||||
residency_set_ = residency_set;
|
||||
// Attach residency set to existing command queues
|
||||
for (auto& [_, stream] : stream_map_) {
|
||||
stream.queue->addResidencySet(residency_set_);
|
||||
}
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device) {
|
||||
|
@@ -7,6 +7,7 @@
|
||||
#include <filesystem>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
@@ -44,43 +45,114 @@ struct CommandEncoder {
|
||||
|
||||
struct ConcurrentContext {
|
||||
ConcurrentContext(CommandEncoder& enc) : enc(enc) {
|
||||
enc.concurrent = true;
|
||||
enc.concurrent_ = true;
|
||||
}
|
||||
~ConcurrentContext() {
|
||||
enc.concurrent = false;
|
||||
enc.outputs.insert(
|
||||
enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
|
||||
enc.concurrent_outputs.clear();
|
||||
enc.concurrent_ = false;
|
||||
enc.prev_outputs_.insert(
|
||||
enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
|
||||
enc.concurrent_outputs_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
CommandEncoder& enc;
|
||||
};
|
||||
|
||||
MTL::ComputeCommandEncoder* operator->() {
|
||||
return enc;
|
||||
}
|
||||
|
||||
void set_input_array(const array& a, int idx, int64_t offset = 0);
|
||||
void set_output_array(array& a, int idx, int64_t offset = 0);
|
||||
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void maybeInsertBarrier();
|
||||
|
||||
void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
|
||||
enc_->setComputePipelineState(kernel);
|
||||
}
|
||||
|
||||
void wait_for_fence(MTL::Fence* fence) {
|
||||
enc_->waitForFence(fence);
|
||||
}
|
||||
|
||||
void update_fence(MTL::Fence* fence) {
|
||||
enc_->updateFence(fence);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
|
||||
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
|
||||
}
|
||||
template <typename T>
|
||||
void set_vector_bytes(const std::vector<T>& vec, int idx) {
|
||||
return set_vector_bytes(vec, vec.size(), idx);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set_bytes(const T* v, int n, int idx) {
|
||||
return enc_->setBytes(v, n * sizeof(T), idx);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set_bytes(const T& v, int idx) {
|
||||
return enc_->setBytes(&v, sizeof(T), idx);
|
||||
}
|
||||
|
||||
ConcurrentContext start_concurrent() {
|
||||
return ConcurrentContext(*this);
|
||||
}
|
||||
|
||||
~CommandEncoder();
|
||||
|
||||
private:
|
||||
void maybe_split();
|
||||
// Inputs to all kernels in the encoder including temporaries
|
||||
std::unordered_set<const void*>& inputs() {
|
||||
return all_inputs_;
|
||||
};
|
||||
|
||||
int num_dispatches{0};
|
||||
MTL::CommandBuffer* cbuf;
|
||||
MTL::ComputeCommandEncoder* enc;
|
||||
bool concurrent{false};
|
||||
std::unordered_set<MTL::Resource*> outputs;
|
||||
std::unordered_set<MTL::Resource*> concurrent_outputs;
|
||||
// Outputs of all kernels in the encoder including temporaries
|
||||
std::unordered_set<const void*> outputs() {
|
||||
return all_outputs_;
|
||||
};
|
||||
|
||||
private:
|
||||
MTL::ComputeCommandEncoder* enc_;
|
||||
bool needs_barrier_{false};
|
||||
bool concurrent_{false};
|
||||
std::unordered_set<MTL::Resource*> prev_outputs_;
|
||||
std::unordered_set<MTL::Resource*> next_outputs_;
|
||||
std::unordered_set<MTL::Resource*> concurrent_outputs_;
|
||||
std::unordered_set<const void*> all_inputs_;
|
||||
std::unordered_set<const void*> all_outputs_;
|
||||
};
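
The typed `set_bytes`/`set_vector_bytes` helpers replace raw `setBytes(ptr, n * sizeof(T), idx)` calls throughout the backend; the byte count is derived from the template argument instead of being spelled out at every call site. A toy stand-in showing the idea (illustrative only, not the real encoder):

```cpp
#include <cstring>
#include <vector>

// Minimal stand-in: the byte count comes from T, not from the caller.
struct SketchEncoder {
  std::vector<char> bound;  // pretend argument buffer

  template <typename T>
  void set_bytes(const T& v, int /*idx*/) {
    bound.resize(sizeof(T));
    std::memcpy(bound.data(), &v, sizeof(T));
  }

  template <typename T>
  void set_vector_bytes(const std::vector<T>& v, int /*idx*/) {
    bound.resize(v.size() * sizeof(T));
    std::memcpy(bound.data(), v.data(), bound.size());
  }
};

int main() {
  SketchEncoder enc;
  int n = 128;
  std::vector<size_t> strides = {64, 8, 1};
  enc.set_bytes(n, 4);               // was: setBytes(&n, sizeof(int), 4)
  enc.set_vector_bytes(strides, 3);  // was: setBytes(strides.data(), 3 * sizeof(size_t), 3)
}
```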
|
||||
|
||||
struct Fence {
|
||||
Fence(MTL::Fence* fence) : fence(fence) {}
|
||||
~Fence() {
|
||||
fence->release();
|
||||
}
|
||||
MTL::Fence* fence;
|
||||
};
|
||||
|
||||
struct DeviceStream {
|
||||
DeviceStream(MTL::CommandQueue* queue) : queue(queue) {};
|
||||
~DeviceStream() {
|
||||
queue->release();
|
||||
if (buffer != nullptr) {
|
||||
buffer->release();
|
||||
}
|
||||
};
|
||||
MTL::CommandQueue* queue;
|
||||
// A map of prior command encoder outputs to their corresponding fence
|
||||
std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
|
||||
// Used to allow thread-safe access to the outputs map
|
||||
std::mutex fence_mtx;
|
||||
|
||||
// The buffer and buffer op count are updated
|
||||
// between command buffers
|
||||
MTL::CommandBuffer* buffer{nullptr};
|
||||
int buffer_ops{0};
|
||||
|
||||
// The command encoder, fence, and temporaries are updated between command
|
||||
// encoders
|
||||
std::unique_ptr<CommandEncoder> encoder{nullptr};
|
||||
std::shared_ptr<Fence> fence;
|
||||
std::vector<array> temporaries;
|
||||
};
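
`DeviceStream` gathers what previously lived in `queue_map_`, `buffer_map_`, and `encoder_map_` into a single per-index record, so the buffer, encoder, fence, and temporaries all share one lifetime owner. The block below is a hypothetical ordering sketch built from the methods visible in this diff, written as comments rather than runnable code:

```cpp
// Hypothetical ordering of one stream's lifecycle with the new layout
// (pseudo-usage of the Device methods shown in this diff, not a program):
//
//   d.new_queue(s.index);              // creates DeviceStream{queue}
//   d.get_command_buffer(s.index);     // lazily fills stream.buffer
//   d.get_command_encoder(s.index);    // lazily fills stream.encoder + stream.fence
//   d.add_temporary(arr, s.index);     // kept alive until the buffer completes
//   d.end_encoding(s.index);           // fence wait/update, encoder reset
//   d.commit_command_buffer(s.index);  // commits stream.buffer, resets buffer_ops
```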
|
||||
|
||||
class Device {
|
||||
@@ -94,6 +166,10 @@ class Device {
|
||||
return device_;
|
||||
};
|
||||
|
||||
const std::string& get_architecture() {
|
||||
return arch_;
|
||||
}
|
||||
|
||||
void new_queue(int index);
|
||||
MTL::CommandBuffer* get_command_buffer(int index);
|
||||
int get_command_buffer_ops(int index);
|
||||
@@ -114,29 +190,9 @@ class Device {
|
||||
}
|
||||
}
|
||||
|
||||
MTL::Library* get_library(const std::string& name);
|
||||
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const std::string& source_string,
|
||||
bool cache = true);
|
||||
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const MTL::StitchedLibraryDescriptor* desc,
|
||||
bool cache = true);
|
||||
|
||||
MTL::Function* get_function(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& specialized_name = "",
|
||||
const MTLFCList& func_consts = {});
|
||||
|
||||
MTL::Function* get_function(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name = "mlx",
|
||||
const std::string& specialized_name = "",
|
||||
const MTLFCList& func_consts = {});
|
||||
const std::function<std::string(void)>& builder);
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& base_name,
|
||||
@@ -155,11 +211,20 @@ class Device {
|
||||
MTL::ArgumentEncoder* argument_encoder(
|
||||
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
||||
|
||||
// Record temporary arrays for the given stream index
|
||||
void add_temporary(array arr, int index);
|
||||
void add_temporaries(std::vector<array> arrays, int index);
|
||||
|
||||
void set_residency_set(const MTL::ResidencySet* residency_set);
|
||||
|
||||
private:
|
||||
DeviceStream& get_stream_(int index) {
|
||||
return stream_map_.find(index)->second;
|
||||
}
|
||||
MTL::Library* get_library_cache_(const std::string& name);
|
||||
|
||||
MTL::Library* get_library_(const std::string& source_string);
|
||||
MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
|
||||
MTL::Library* get_library_(const std::string& name);
|
||||
MTL::Library* build_library_(const std::string& source_string);
|
||||
|
||||
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
|
||||
|
||||
@@ -181,13 +246,23 @@ class Device {
|
||||
const MTL::Function* mtl_function,
|
||||
const MTL::LinkedFunctions* linked_functions);
|
||||
|
||||
MTL::ComputePipelineState* get_kernel_(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& hash_name,
|
||||
const MTLFCList& func_consts = {},
|
||||
const std::vector<MTL::Function*>& linked_functions = {});
|
||||
|
||||
MTL::Device* device_;
|
||||
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
||||
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
||||
std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
|
||||
std::unordered_map<int32_t, DeviceStream> stream_map_;
|
||||
|
||||
std::shared_mutex kernel_mtx_;
|
||||
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
||||
|
||||
std::shared_mutex library_mtx_;
|
||||
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||
std::mutex mtx_;
|
||||
const MTL::ResidencySet* residency_set_{nullptr};
|
||||
std::string arch_;
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device);
|
||||
|
@@ -27,4 +27,9 @@ void Event::signal() {
|
||||
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
|
||||
}
|
||||
|
||||
bool Event::is_signaled() const {
|
||||
return static_cast<MTL::SharedEvent*>(raw_event().get())->signaledValue() >=
|
||||
value();
|
||||
}
|
||||
|
||||
} // namespace mlx::core
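
`is_signaled` compares the shared event's current `signaledValue()` against the event's target `value()`, which makes non-blocking polling possible. The counter semantics, restated with a plain atomic (an illustrative stand-in, not the Metal event):

```cpp
#include <atomic>
#include <cstdint>

// Counter-style event semantics mirrored with a plain atomic (illustrative).
struct SketchEvent {
  std::atomic<uint64_t> signaled{0};  // stands in for MTL::SharedEvent::signaledValue()
  uint64_t value{1};                  // target value for this signal/wait pair

  void signal() { signaled.store(value); }
  bool is_signaled() const { return signaled.load() >= value; }
};
```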
|
||||
|
@@ -575,8 +575,7 @@ void fft_op(
|
||||
auto plan = plan_fft(n);
|
||||
if (plan.four_step) {
|
||||
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -700,7 +699,7 @@ void fft_op(
|
||||
auto kernel =
|
||||
get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in_contiguous, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
@@ -712,9 +711,9 @@ void fft_op(
|
||||
|
||||
compute_encoder.set_input_array(w_q, 2); // w_q
|
||||
compute_encoder.set_input_array(w_k, 3); // w_k
|
||||
compute_encoder->setBytes(&n, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
compute_encoder.set_bytes(n, 4);
|
||||
compute_encoder.set_bytes(plan.bluestein_n, 5);
|
||||
compute_encoder.set_bytes(total_batch_size, 6);
|
||||
} else if (plan.rader_n > 1) {
|
||||
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
|
||||
copies.push_back(b_q);
|
||||
@@ -724,25 +723,25 @@ void fft_op(
|
||||
compute_encoder.set_input_array(b_q, 2);
|
||||
compute_encoder.set_input_array(g_q, 3);
|
||||
compute_encoder.set_input_array(g_minus_q, 4);
|
||||
compute_encoder->setBytes(&n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
|
||||
compute_encoder.set_bytes(n, 5);
|
||||
compute_encoder.set_bytes(total_batch_size, 6);
|
||||
compute_encoder.set_bytes(plan.rader_n, 7);
|
||||
} else if (four_step_params.required) {
|
||||
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
|
||||
compute_encoder.set_bytes(four_step_params.n1, 2);
|
||||
compute_encoder.set_bytes(four_step_params.n2, 3);
|
||||
compute_encoder.set_bytes(total_batch_size, 4);
|
||||
} else {
|
||||
compute_encoder->setBytes(&n, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
|
||||
compute_encoder.set_bytes(n, 2);
|
||||
compute_encoder.set_bytes(total_batch_size, 3);
|
||||
}
|
||||
|
||||
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
|
||||
auto grid_dims =
|
||||
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void fft_op(
|
||||
@@ -785,8 +784,7 @@ void nd_fft_op(
|
||||
}
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
|
||||
d.add_temporaries(std::move(temp_arrs), s.index);
|
||||
}
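
Both FFT paths now register their scratch arrays with `add_temporaries` instead of capturing them in a per-call `addCompletedHandler`; the stream keeps them alive until its command buffer completes and releases everything at one point. A minimal sketch of that ownership model with `shared_ptr` stand-ins for arrays (hypothetical, not MLX code):

```cpp
#include <memory>
#include <vector>

using TempArray = std::shared_ptr<int>;  // stand-in for mlx::core::array

struct SketchStream {
  std::vector<TempArray> temporaries;    // kept alive until the buffer completes
};

void add_temporaries(SketchStream& stream, std::vector<TempArray> arrays) {
  if (arrays.empty()) return;
  stream.temporaries.insert(
      stream.temporaries.end(),
      std::make_move_iterator(arrays.begin()),
      std::make_move_iterator(arrays.end()));
}

void on_command_buffer_completed(SketchStream& stream) {
  stream.temporaries.clear();            // single release point per command buffer
}
```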
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
@@ -60,32 +60,6 @@ std::string gen_hadamard_codelet(int m) {
|
||||
return source.str();
|
||||
}
|
||||
|
||||
void launch_hadamard(
|
||||
const array& in,
|
||||
array& out,
|
||||
int batch_size,
|
||||
int threads_per,
|
||||
const std::string kernel_name,
|
||||
float scale,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
const auto& lib_name = kernel_name.substr(1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&scale, sizeof(float), 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
||||
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
|
||||
@@ -113,7 +87,8 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
auto [n, m] = decompose_hadamard(in.shape(axis));
|
||||
int n, m;
|
||||
std::tie(n, m) = decompose_hadamard(in.shape(axis));
|
||||
|
||||
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
|
||||
throw std::invalid_argument(
|
||||
@@ -129,8 +104,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto kernel_name = kname.str();
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
auto codelet = gen_hadamard_codelet(m);
|
||||
kernel_source << metal::utils() << codelet << metal::hadamard();
|
||||
@@ -148,12 +122,31 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
n,
|
||||
m,
|
||||
read_width);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
|
||||
int batch_size = in.size() / n;
|
||||
int threads_per = n / max_radix;
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
auto launch_hadamard = [&](const array& in,
|
||||
array& out,
|
||||
const std::string& kernel_name,
|
||||
float scale) {
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(scale, 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
||||
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
};
|
||||
|
||||
if (m > 1) {
|
||||
// When m is greater than 1, we decompose the
|
||||
// computation into two uploads to the GPU:
|
||||
@@ -171,33 +164,17 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
|
||||
copies.push_back(temp);
|
||||
|
||||
launch_hadamard(
|
||||
in_contiguous,
|
||||
temp,
|
||||
batch_size,
|
||||
threads_per,
|
||||
"n" + kernel_name,
|
||||
1.0,
|
||||
s);
|
||||
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
|
||||
|
||||
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
|
||||
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
|
||||
batch_size = in.size() / m / read_width / threads_per;
|
||||
launch_hadamard(
|
||||
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
|
||||
launch_hadamard(temp, out, "m" + kernel_name, scale_);
|
||||
} else {
|
||||
launch_hadamard(
|
||||
in_contiguous,
|
||||
out,
|
||||
batch_size,
|
||||
threads_per,
|
||||
"n" + kernel_name,
|
||||
scale_,
|
||||
s);
|
||||
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
|
||||
}
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -53,28 +53,31 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
size_t ndim = src.ndim();
|
||||
|
||||
std::string lib_name;
|
||||
std::string kernel_name;
|
||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
{
|
||||
std::ostringstream kname;
|
||||
kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
|
||||
<< "_" << idx_ndim;
|
||||
lib_name = kname.str();
|
||||
kernel_name = lib_name;
|
||||
}
|
||||
bool large_index = nidx && inputs[1].size() > UINT32_MAX;
|
||||
bool large_src = src.size() > UINT32_MAX;
|
||||
bool large_out = out.size() > UINT32_MAX;
|
||||
bool large = large_index || large_src || large_out;
|
||||
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gather();
|
||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
std::string kernel_name = fmt::format(
|
||||
"gather{0}{1}_{2}_{3}_{4}",
|
||||
type_to_name(out),
|
||||
idx_type_name,
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "size_t" : "uint");
|
||||
std::string lib_name = kernel_name;
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::gather();
|
||||
std::string out_type_str = get_type_string(out.dtype());
|
||||
std::string idx_type_str =
|
||||
nidx ? get_type_string(inputs[1].dtype()) : "bool";
|
||||
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
||||
|
||||
// Index dimension specializations
|
||||
kernel_source << fmt::format(
|
||||
kernel_source += fmt::format(
|
||||
gather_kernels,
|
||||
type_to_name(out) + idx_type_name,
|
||||
out_type_str,
|
||||
@@ -82,13 +85,14 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nidx,
|
||||
idx_args,
|
||||
idx_arr,
|
||||
idx_ndim);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
idx_ndim,
|
||||
large ? "size_t" : "uint");
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
size_t slice_size = 1;
|
||||
for (auto s : slice_sizes_) {
|
||||
@@ -114,17 +118,17 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
|
||||
std::vector<char> idx_contigs;
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_shapes.insert(
|
||||
idx_shapes.end(),
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end());
|
||||
|
||||
idx_strides.insert(
|
||||
idx_strides.end(),
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end());
|
||||
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
|
||||
}
|
||||
|
||||
// Set all the buffers
|
||||
@@ -132,21 +136,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
// Set source info
|
||||
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
|
||||
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
|
||||
compute_encoder.set_vector_bytes(src.shape(), 2);
|
||||
compute_encoder.set_vector_bytes(src.strides(), 3);
|
||||
compute_encoder.set_bytes(ndim, 4);
|
||||
compute_encoder.set_vector_bytes(slice_sizes_, 5);
|
||||
compute_encoder.set_vector_bytes(axes_, 6);
|
||||
|
||||
// Set index info
|
||||
//
|
||||
  // We don't need to check for empty idx_shapes because gather has an
|
||||
// idx_ndim == 0 specialization
|
||||
compute_encoder->setBytes(
|
||||
idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
|
||||
compute_encoder->setBytes(
|
||||
idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
|
||||
compute_encoder.set_vector_bytes(idx_shapes, 7);
|
||||
compute_encoder.set_vector_bytes(idx_strides, 8);
|
||||
compute_encoder.set_vector_bytes(idx_contigs, 9);
|
||||
compute_encoder.set_bytes(idx_ndim, 10);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
@@ -154,7 +157,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
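
Gather kernels are now additionally specialized on index width: if the index, source, or output element count exceeds `UINT32_MAX`, the JIT kernel name ends in `size_t` and the kernel uses 64-bit offsets, otherwise `uint`. The selection, restated as a standalone helper (the function name is hypothetical):

```cpp
#include <cstdint>
#include <string>

// Pick the offset-type suffix encoded in the JIT kernel name ("size_t" vs "uint").
std::string index_type_suffix(size_t idx_size, size_t src_size, size_t out_size) {
  bool large = idx_size > UINT32_MAX || src_size > UINT32_MAX || out_size > UINT32_MAX;
  return large ? "size_t" : "uint";
}
```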
|
||||
|
||||
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -173,12 +176,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
// Copy src into out
|
||||
auto copy_type =
|
||||
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
|
||||
CopyType copy_type;
|
||||
if (inputs[0].data_size() == 1) {
|
||||
copy_type = CopyType::Scalar;
|
||||
} else if (inputs[0].flags().row_contiguous) {
|
||||
copy_type = CopyType::Vector;
|
||||
} else {
|
||||
copy_type = CopyType::General;
|
||||
}
|
||||
copy_gpu(inputs[0], out, copy_type);
|
||||
|
||||
auto& upd = inputs.back();
|
||||
|
||||
// Empty update
|
||||
if (inputs.back().size() == 0) {
|
||||
if (upd.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -187,23 +198,22 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
bool index_nd1_specialization = (idx_ndim == 1);
|
||||
size_t idx_size = nidx ? inputs[1].size() : 1;
|
||||
|
||||
// Bail from fast path (1d index specialization) if scatter dims aren't
|
||||
// the outermost dims and contiguous since update access won't be raster
|
||||
// order.
|
||||
for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
|
||||
index_nd1_specialization &= (axes_[i] == i);
|
||||
auto idx_to_out = idx_size / out.size();
|
||||
int nwork;
|
||||
if (idx_ndim <= 1 || idx_to_out < 1) {
|
||||
nwork = 1;
|
||||
} else if (idx_to_out <= 4) {
|
||||
nwork = 4;
|
||||
} else if (idx_to_out < 16) {
|
||||
nwork = 8;
|
||||
} else if (idx_to_out < 32) {
|
||||
nwork = 16;
|
||||
} else {
|
||||
nwork = 32;
|
||||
}
|
||||
|
||||
// Bail from fast path (1d index specialization) if any of the dims are
|
||||
// broadcasted, since we can't rely on linear indexing in that case.
|
||||
for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
|
||||
index_nd1_specialization &= inputs[i].flags().row_contiguous;
|
||||
}
|
||||
|
||||
std::string lib_name;
|
||||
std::string kernel_name;
|
||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
@@ -223,24 +233,25 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op_name = "min";
|
||||
break;
|
||||
}
|
||||
auto upd_contig = upd.flags().row_contiguous;
|
||||
bool large_out = out.size() > UINT32_MAX;
|
||||
bool large_idx = nidx && (inputs[1].size() > UINT32_MAX);
|
||||
bool large_upd = upd.size() > UINT32_MAX;
|
||||
bool large = large_out || large_idx || large_upd;
|
||||
std::string kernel_name = fmt::format(
|
||||
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
|
||||
type_to_name(out),
|
||||
idx_type_name,
|
||||
op_name,
|
||||
nidx,
|
||||
upd_contig ? "updc_true" : "updc_false",
|
||||
nwork,
|
||||
large ? "size_t" : "uint");
|
||||
std::string lib_name = kernel_name;
|
||||
|
||||
{
|
||||
std::ostringstream kname;
|
||||
if (index_nd1_specialization) {
|
||||
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
|
||||
} else {
|
||||
kname << "scatter" << type_to_name(out) << idx_type_name;
|
||||
}
|
||||
kname << "_" << op_name << "_" << nidx;
|
||||
lib_name = kname.str();
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::reduce_utils()
|
||||
<< metal::scatter();
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::reduce_utils(), metal::scatter());
|
||||
|
||||
std::string out_type_str = get_type_string(out.dtype());
|
||||
std::string idx_type_str =
|
||||
@@ -264,11 +275,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
break;
|
||||
}
|
||||
if (reduce_type_ != Scatter::None) {
|
||||
op_type = fmt::format(op_type, out_type_str);
|
||||
op_type = fmt::format(fmt::runtime(op_type), out_type_str);
|
||||
}
|
||||
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
||||
|
||||
kernel_source << fmt::format(
|
||||
kernel_source += fmt::format(
|
||||
scatter_kernels,
|
||||
type_to_name(out) + idx_type_name + "_" + op_name,
|
||||
out_type_str,
|
||||
@@ -276,126 +287,105 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op_type,
|
||||
nidx,
|
||||
idx_args,
|
||||
idx_arr);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
idx_arr,
|
||||
upd_contig,
|
||||
nwork,
|
||||
large ? "size_t" : "uint");
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
|
||||
auto& upd = inputs.back();
|
||||
size_t nthreads = upd.size();
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set all the buffers
|
||||
compute_encoder.set_input_array(upd, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set update info
|
||||
uint upd_ndim = upd.ndim();
|
||||
size_t upd_ndim = upd.ndim();
|
||||
size_t upd_size = 1;
|
||||
for (int i = idx_ndim; i < upd.ndim(); ++i) {
|
||||
upd_size *= upd.shape(i);
|
||||
}
|
||||
if (index_nd1_specialization) {
|
||||
compute_encoder->setBytes(
|
||||
out.shape().data(), out.shape().size() * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
|
||||
|
||||
size_t out_ndim = out.ndim();
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
|
||||
if (upd_ndim <= 1) {
|
||||
// Placeholder so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 6);
|
||||
} else {
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
|
||||
}
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
|
||||
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
} else {
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_shapes.insert(
|
||||
idx_shapes.end(),
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end());
|
||||
|
||||
idx_strides.insert(
|
||||
idx_strides.end(),
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end());
|
||||
}
|
||||
|
||||
if (upd_ndim == 0) {
|
||||
      // Need placeholders so Metal doesn't complain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
|
||||
} else {
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
|
||||
}
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
|
||||
|
||||
// Set output info
|
||||
size_t out_ndim = out.ndim();
|
||||
if (out_ndim == 0) {
|
||||
      // Need placeholders so Metal doesn't complain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 7);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
|
||||
} else {
|
||||
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out_ndim * sizeof(size_t), 8);
|
||||
}
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
|
||||
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
|
||||
|
||||
// Set index info
|
||||
if (idx_ndim == 0) {
|
||||
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
|
||||
// error in the metal API.
|
||||
idx_shapes.push_back(0);
|
||||
idx_strides.push_back(0);
|
||||
}
|
||||
compute_encoder->setBytes(
|
||||
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
|
||||
compute_encoder->setBytes(
|
||||
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
|
||||
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
// To access .data() use char instead of bool
|
||||
// bool is 1 byte in Metal so this is safe
|
||||
std::vector<char> idx_contigs;
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_shapes.insert(
|
||||
idx_shapes.end(),
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end());
|
||||
idx_strides.insert(
|
||||
idx_strides.end(),
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end());
|
||||
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
|
||||
}
|
||||
|
||||
if (upd_ndim == 0) {
|
||||
    // Need placeholders so Metal doesn't complain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
compute_encoder.set_bytes(shape_, 3);
|
||||
compute_encoder.set_bytes(stride_, 4);
|
||||
} else {
|
||||
compute_encoder.set_vector_bytes(upd.shape(), 3);
|
||||
compute_encoder.set_vector_bytes(upd.strides(), 4);
|
||||
}
|
||||
compute_encoder.set_bytes(upd_ndim, 5);
|
||||
compute_encoder.set_bytes(upd_size, 6);
|
||||
|
||||
// Set output info
|
||||
size_t out_ndim = out.ndim();
|
||||
if (out_ndim == 0) {
|
||||
    // Need placeholders so Metal doesn't complain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
compute_encoder.set_bytes(shape_, 7);
|
||||
compute_encoder.set_bytes(stride_, 8);
|
||||
} else {
|
||||
compute_encoder.set_vector_bytes(out.shape(), 7);
|
||||
compute_encoder.set_vector_bytes(out.strides(), 8);
|
||||
}
|
||||
compute_encoder.set_bytes(out_ndim, 9);
|
||||
compute_encoder.set_vector_bytes(axes_, 10);
|
||||
|
||||
// Set index info
|
||||
if (idx_ndim == 0) {
|
||||
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
|
||||
// error in the metal API.
|
||||
idx_shapes.push_back(0);
|
||||
idx_strides.push_back(0);
|
||||
idx_contigs.push_back(false);
|
||||
}
|
||||
compute_encoder.set_vector_bytes(idx_shapes, 11);
|
||||
compute_encoder.set_vector_bytes(idx_strides, 12);
|
||||
compute_encoder.set_vector_bytes(idx_contigs, 13);
|
||||
compute_encoder.set_bytes(idx_ndim, 14);
|
||||
compute_encoder.set_bytes(idx_size, 15);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
auto grid_y = (nthreads / upd_size);
|
||||
grid_y = (grid_y + nwork - 1) / nwork;
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
|
||||
}
|
||||
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
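
The scatter rewrite drops the 1-D index fast path and instead amortizes updates per thread: `nwork` grows with the ratio of index elements to output elements, so scatters that hit the same outputs repeatedly do more work per thread and launch a smaller grid. The ladder from `Scatter::eval_gpu`, restated as a standalone helper (hypothetical name, thresholds copied from the diff):

```cpp
#include <cstddef>

// Work items handled per thread, as a function of how many index entries
// map onto each output element (mirrors the ladder in Scatter::eval_gpu).
int scatter_nwork(size_t idx_size, size_t out_size, int idx_ndim) {
  size_t idx_to_out = idx_size / out_size;
  if (idx_ndim <= 1 || idx_to_out < 1) return 1;
  if (idx_to_out <= 4) return 4;
  if (idx_to_out < 16) return 8;
  if (idx_to_out < 32) return 16;
  return 32;
}
```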
|
||||
|
@@ -1,7 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view gather_kernels = R"(
|
||||
[[kernel]] void gather{0}_{3}_{6}(
|
||||
[[kernel]] void gather{0}_{3}_{6}_{7}(
|
||||
const device {1}* src [[buffer(0)]],
|
||||
device {1}* out [[buffer(1)]],
|
||||
const constant int* src_shape [[buffer(2)]],
|
||||
@@ -11,14 +11,15 @@ constexpr std::string_view gather_kernels = R"(
|
||||
const constant int* axes [[buffer(6)]],
|
||||
const constant int* idx_shapes [[buffer(7)]],
|
||||
const constant size_t* idx_strides [[buffer(8)]],
|
||||
const constant int& idx_ndim [[buffer(9)]],
|
||||
const constant bool* idx_contigs [[buffer(9)]],
|
||||
const constant int& idx_ndim [[buffer(10)]],
|
||||
{4}
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {{
|
||||
Indices<{2}, {3}> idxs{{
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
|
||||
|
||||
return gather_impl<{1}, {2}, {3}, {6}>(
|
||||
return gather_impl<{1}, {2}, {3}, {6}, {7}>(
|
||||
src,
|
||||
out,
|
||||
src_shape,
|
||||
@@ -33,32 +34,7 @@ constexpr std::string_view gather_kernels = R"(
|
||||
)";
|
||||
|
||||
constexpr std::string_view scatter_kernels = R"(
|
||||
[[kernel]] void scatter_1d_index{0}_{4}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& out_ndim [[buffer(5)]],
|
||||
const constant int* upd_shape [[buffer(6)]],
|
||||
const constant size_t& upd_ndim [[buffer(7)]],
|
||||
const constant size_t& upd_size [[buffer(8)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
|
||||
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
|
||||
updates,
|
||||
out,
|
||||
out_shape,
|
||||
out_strides,
|
||||
out_ndim,
|
||||
upd_shape,
|
||||
upd_ndim,
|
||||
upd_size,
|
||||
idx_buffers,
|
||||
gid);
|
||||
}}
|
||||
|
||||
[[kernel]] void scatter{0}_{4}(
|
||||
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
@@ -71,12 +47,14 @@ constexpr std::string_view scatter_kernels = R"(
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const constant int* idx_shapes [[buffer(11)]],
|
||||
const constant size_t* idx_strides [[buffer(12)]],
|
||||
const constant int& idx_ndim [[buffer(13)]],
|
||||
const constant bool* idx_contigs [[buffer(13)]],
|
||||
const constant int& idx_ndim [[buffer(14)]],
|
||||
const constant size_t& idx_size [[buffer(15)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
|
||||
|
||||
return scatter_impl<{1}, {2}, {3}, {4}>(
|
||||
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>(
|
||||
updates,
|
||||
out,
|
||||
upd_shape,
|
||||
@@ -87,6 +65,7 @@ constexpr std::string_view scatter_kernels = R"(
|
||||
out_strides,
|
||||
out_ndim,
|
||||
axes,
|
||||
idx_size,
|
||||
idxs,
|
||||
gid);
|
||||
}}
|
||||
|
@@ -1,26 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view scan_kernels = R"(
|
||||
template [[host_name("contig_{0}")]] [[kernel]] void
|
||||
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& axis_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_size [[threads_per_simdgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
template [[host_name("strided_{0}")]] [[kernel]] void
|
||||
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& axis_size [[buffer(2)]],
|
||||
const constant size_t& stride [[buffer(3)]],
|
||||
uint2 gid [[thread_position_in_grid]],
|
||||
uint2 lid [[thread_position_in_threadgroup]],
|
||||
uint2 lsize [[threads_per_threadgroup]],
|
||||
uint simd_size [[threads_per_simdgroup]]);
|
||||
)";
|
@@ -1,10 +1,8 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
@@ -25,48 +23,50 @@ MTL::ComputePipelineState* get_arange_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
  auto lib = d.get_library(kernel_name, [&]() {
    std::ostringstream kernel_source;
    kernel_source
        << metal::utils() << metal::arange()
        << fmt::format(arange_kernels, lib_name, get_type_string(out.dtype()));
    lib = d.get_library(lib_name, kernel_source.str());
  }
    kernel_source << metal::utils() << metal::arange()
                  << fmt::format(
                         arange_kernels,
                         kernel_name,
                         get_type_string(out.dtype()));
    return kernel_source.str();
  });
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_unary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype in_type,
    Dtype out_type,
    const std::string op) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
    kernel_source << get_template_definition(
        "v_" + lib_name, "unary_v", get_type_string(out_type), op);
    kernel_source << get_template_definition(
        "v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
    kernel_source << get_template_definition(
        "g_" + lib_name, "unary_g", get_type_string(out_type), op);
    kernel_source << get_template_definition(
        "gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  auto lib = d.get_library(lib_name, [&]() {
    auto in_t = get_type_string(in_type);
    auto out_t = get_type_string(out_type);
    std::string kernel_source = metal::utils();
    concatenate(kernel_source, metal::unary_ops(), metal::unary());
    kernel_source +=
        get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
    kernel_source +=
        get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
    kernel_source += get_template_definition(
        "gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
    kernel_source += get_template_definition(
        "gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
    return kernel_source;
  });
  return d.get_kernel(kernel_name, lib);
}

void add_binary_kernels(
void append_binary_kernels(
    const std::string lib_name,
    Dtype in_type,
    Dtype out_type,
    const std::string op,
    std::ostringstream& kernel_source) {
  const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
    std::string& kernel_source) {
  const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
      {"ss", "binary_ss"},
      {"vs", "binary_vs"},
      {"sv", "binary_sv"},
@@ -75,27 +75,24 @@ void add_binary_kernels(
      {"sv2", "binary_sv2"},
      {"vv2", "binary_vv2"},
      {"g1", "binary_g_nd1"},
      {"g2", "binary_g_nd2"},
      {"g3", "binary_g_nd3"},
      {"gn", "binary_g"},
      {"g2large", "binary_g_nd2"},
      {"g3large", "binary_g_nd3"},
  }};
  auto in_t = get_type_string(in_type);
  auto out_t = get_type_string(out_type);

  for (auto& [name, func] : kernel_types) {
    std::string template_def;
    template_def = get_template_definition(
        name + "_" + lib_name,
        func,
        get_type_string(in_type),
        get_type_string(out_type),
        op);
    kernel_source << template_def;
    kernel_source +=
        get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
  }
  kernel_source << get_template_definition(
      "gn4_" + lib_name,
      "binary_g",
      get_type_string(in_type),
      get_type_string(out_type),
      op,
      4);
  kernel_source += get_template_definition(
      "g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
  kernel_source += get_template_definition(
      "g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
  kernel_source += get_template_definition(
      "gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
  kernel_source += get_template_definition(
      "gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
}

MTL::ComputePipelineState* get_binary_kernel(
@@ -105,13 +102,13 @@ MTL::ComputePipelineState* get_binary_kernel(
    Dtype out_type,
    const std::string op) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
    add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  auto lib = d.get_library(lib_name, [&]() {
    std::string kernel_source;
    kernel_source = metal::utils();
    concatenate(kernel_source, metal::binary_ops(), metal::binary());
    append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
    return kernel_source;
  });
  return d.get_kernel(kernel_name, lib);
}
|
||||
|
||||
@@ -122,14 +119,12 @@ MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
Dtype out_type,
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::binary_ops()
|
||||
<< metal::binary_two();
|
||||
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::binary_ops(), metal::binary_two());
|
||||
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
@@ -139,28 +134,31 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
Dtype type,
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
auto t_str = get_type_string(type);
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
|
||||
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
|
||||
{"v", "ternary_v"},
|
||||
{"v2", "ternary_v2"},
|
||||
{"g", "ternary_g"},
|
||||
{"g1", "ternary_g_nd1"},
|
||||
{"g2", "ternary_g_nd2"},
|
||||
{"g3", "ternary_g_nd3"},
|
||||
{"g2large", "ternary_g_nd2"},
|
||||
{"g3large", "ternary_g_nd3"},
|
||||
}};
|
||||
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
|
||||
for (auto& [name, func] : kernel_types) {
|
||||
std::string template_def;
|
||||
template_def = get_template_definition(
|
||||
name + "_" + lib_name, func, get_type_string(type), op);
|
||||
kernel_source << template_def;
|
||||
kernel_source +=
|
||||
get_template_definition(name + "_" + lib_name, func, t_str, op);
|
||||
}
|
||||
kernel_source << get_template_definition(
|
||||
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
@@ -170,36 +168,45 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::copy();
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
kernel_source
|
||||
<< metal::utils() << metal::copy()
|
||||
<< get_template_definition("s_" + lib_name, "copy_s", in_type, out_type)
|
||||
<< get_template_definition("v_" + lib_name, "copy_v", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
|
||||
<< get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
|
||||
<< get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg_" + lib_name, "copy_gg", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
kernel_source +=
|
||||
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "copy_g_nd3", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
|
||||
kernel_source += get_template_definition(
|
||||
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg3large_" + lib_name, "copy_gg_nd3", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"ggn4large_" + lib_name, "copy_gg", in_type, out_type, 4);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
@@ -209,8 +216,7 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
||||
bool precise,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(lib_name, [&] {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::softmax()
|
||||
<< fmt::format(
|
||||
@@ -218,8 +224,8 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
||||
lib_name,
|
||||
get_type_string(out.dtype()),
|
||||
get_type_string(precise ? float32 : out.dtype()));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
@@ -232,22 +238,29 @@ MTL::ComputePipelineState* get_scan_kernel(
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::string op_name = "Cum" + reduce_type;
|
||||
op_name[3] = toupper(op_name[3]);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
std::string op = "Cum" + reduce_type + "<" + out_type + ">";
|
||||
op[3] = toupper(op[3]);
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::scan()
|
||||
<< fmt::format(
|
||||
scan_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op_name,
|
||||
inclusive,
|
||||
reverse);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
kernel_source << metal::utils() << metal::scan();
|
||||
const std::array<std::pair<std::string, std::string>, 2> scan_kernels = {{
|
||||
{"contig_", "contiguous_scan"},
|
||||
{"strided_", "strided_scan"},
|
||||
}};
|
||||
for (auto& [prefix, kernel] : scan_kernels) {
|
||||
kernel_source << get_template_definition(
|
||||
prefix + lib_name,
|
||||
kernel,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op,
|
||||
in.itemsize() <= 4 ? 4 : 2,
|
||||
inclusive,
|
||||
reverse);
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
@@ -259,8 +272,7 @@ MTL::ComputePipelineState* get_sort_kernel(
|
||||
int bn,
|
||||
int tn) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
@@ -285,8 +297,8 @@ MTL::ComputePipelineState* get_sort_kernel(
|
||||
bn,
|
||||
tn);
|
||||
}
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
@@ -298,8 +310,7 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
|
||||
int bn,
|
||||
int tn) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::sort();
|
||||
std::array<std::pair<std::string, std::string>, 3> kernel_types = {
|
||||
@@ -316,27 +327,28 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
|
||||
bn,
|
||||
tn);
|
||||
}
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_reduce_init_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
auto lib = d.get_library(kernel_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
std::string op_type = op_name(out);
|
||||
op_type[0] = std::toupper(op_name(out)[0]);
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
std::string op = op_type + "<" + out_type + ">";
|
||||
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, "init_reduce", out_type, op);
|
||||
lib = d.get_library(kernel_name, kernel_source.str());
|
||||
}
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const Dtype& out_type) {
|
||||
auto lib = d.get_library(kernel_name, [&]() {
|
||||
std::string op_type = op_name;
|
||||
op_type[0] = std::toupper(op_name[0]);
|
||||
auto out_t = get_type_string(out_type);
|
||||
std::string op = op_type + "<" + out_t + ">";
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::reduce_utils();
|
||||
kernel_source += metal::reduce();
|
||||
kernel_source += get_template_definition(kernel_name, func_name, out_t, op);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
@@ -345,32 +357,32 @@ MTL::ComputePipelineState* get_reduce_kernel(
|
||||
const std::string& kernel_name,
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const array& in,
|
||||
const array& out,
|
||||
const Dtype& in_type,
|
||||
const Dtype& out_type,
|
||||
const std::string& idx_t,
|
||||
int ndim /* = -1 */,
|
||||
int bm /* = -1 */,
|
||||
int bn /* = -1 */) {
|
||||
auto lib = d.get_library(kernel_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(kernel_name, [&]() {
|
||||
std::string op_type = op_name;
|
||||
op_type[0] = std::toupper(op_name[0]);
|
||||
std::ostringstream kernel_source;
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
std::string op = op_type + "<" + out_type + ">";
|
||||
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
|
||||
auto in_t = get_type_string(in_type);
|
||||
auto out_t = get_type_string(out_type);
|
||||
std::string op = op_type + "<" + out_t + ">";
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::reduce_utils(), metal::reduce());
|
||||
if (bm >= 0) {
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op, ndim, bm, bn);
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn);
|
||||
} else if (ndim >= 0) {
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op, ndim);
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, func_name, in_t, out_t, op, idx_t, ndim);
|
||||
} else {
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op);
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, func_name, in_t, out_t, op, idx_t);
|
||||
}
|
||||
lib = d.get_library(kernel_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source;
|
||||
});
|
||||
auto st = d.get_kernel(kernel_name, lib);
|
||||
return st;
|
||||
}
|
||||
@@ -389,8 +401,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
||||
int wm,
|
||||
int wn) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_fused()
|
||||
@@ -405,8 +416,8 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
||||
"wn"_a = wn,
|
||||
"trans_a"_a = transpose_a,
|
||||
"trans_b"_a = transpose_b);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
@@ -425,8 +436,7 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
||||
bool mn_aligned,
|
||||
bool k_aligned) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_splitk()
|
||||
@@ -444,8 +454,8 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
||||
"trans_b"_a = transpose_b,
|
||||
"mn_aligned"_a = mn_aligned,
|
||||
"k_aligned"_a = k_aligned);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
@@ -456,19 +466,19 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
|
||||
const array& out,
|
||||
bool axbpy) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_splitk()
|
||||
<< fmt::format(
|
||||
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
|
||||
: steel_gemm_splitk_accum_kernels,
|
||||
fmt::runtime(
|
||||
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
|
||||
: steel_gemm_splitk_accum_kernels),
|
||||
"name"_a = lib_name,
|
||||
"atype"_a = get_type_string(in.dtype()),
|
||||
"otype"_a = get_type_string(out.dtype()));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
@@ -488,8 +498,7 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
bool mn_aligned,
|
||||
bool k_aligned) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
auto out_mask_type = mask_out.has_value()
|
||||
? get_type_string((*mask_out).dtype())
|
||||
@@ -513,8 +522,8 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
"trans_b"_a = transpose_b,
|
||||
"mn_aligned"_a = mn_aligned,
|
||||
"k_aligned"_a = k_aligned);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
@@ -533,8 +542,7 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
int tn,
|
||||
bool contiguous) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
auto out_mask_type = mask_out.has_value()
|
||||
? get_type_string((*mask_out).dtype())
|
||||
@@ -556,8 +564,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
"tn"_a = tn,
|
||||
"trans"_a = transpose_mat ? "t_" : "",
|
||||
"nc"_a = contiguous ? "0" : "1");
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
@@ -573,8 +581,7 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
int n_channel_specialization,
|
||||
bool small_filter) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
|
||||
<< fmt::format(
|
||||
@@ -588,8 +595,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
"wn"_a = wn,
|
||||
"n_channels"_a = n_channel_specialization,
|
||||
"small_filter"_a = small_filter);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
@@ -603,8 +610,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
int wm,
|
||||
int wn) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::conv()
|
||||
<< metal::steel_conv_general()
|
||||
@@ -617,8 +623,8 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
@@ -629,13 +635,12 @@ MTL::ComputePipelineState* get_fft_kernel(
|
||||
const metal::MTLFCList& func_consts,
|
||||
const std::string& template_def) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
std::string kernel_string;
|
||||
kernel_source << metal::fft() << template_def;
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
@@ -644,13 +649,12 @@ MTL::ComputePipelineState* get_quantized_kernel(
|
||||
const std::string& kernel_name,
|
||||
const std::string& template_def) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm() << metal::quantized()
|
||||
<< template_def;
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
|
@@ -15,6 +15,7 @@ MTL::ComputePipelineState* get_arange_kernel(
MTL::ComputePipelineState* get_unary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype in_type,
    Dtype out_type,
    const std::string op);

@@ -78,15 +79,18 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out);
    const std::string& func_name,
    const std::string& op_name,
    const Dtype& out_type);

MTL::ComputePipelineState* get_reduce_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& func_name,
    const std::string& op_name,
    const array& in,
    const array& out,
    const Dtype& in_type,
    const Dtype& out_type,
    const std::string& idx_t,
    int ndim = -1,
    int bm = -1,
    int bn = -1);
@@ -208,10 +212,10 @@ get_template_definition(std::string name, std::string func, Args... args) {
  };
  (add_arg(args), ...);
  s << ">";
  std::string base_string = R"(
template [[host_name("{0}")]] [[kernel]] decltype({1}) {1};
)";
  return fmt::format(base_string, name, s.str());
  return fmt::format(
      "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
      name,
      s.str());
}

} // namespace mlx::core
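To illustrate what the helper above emits (an example added here for context, not part of the change): with the format string shown, a call such as the following would be expected to produce one explicit-instantiation line carrying the host name and the fully templated kernel, assuming the arguments are joined into the <...> list in order. The kernel and type names below are placeholders.

// Hypothetical call:
std::string def = get_template_definition(
    "v2_abs_float32", "unary_v2", "float", "float", "Abs");
// Expected result (a single line):
// template [[host_name("v2_abs_float32")]] [[kernel]]
//     decltype(unary_v2<float, float, Abs>) unary_v2<float, float, Abs>;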
@@ -1,13 +1,27 @@
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
set(BASE_HEADERS
    metal_3_1/bf16.h
    metal_3_0/bf16.h
    bf16_math.h
    complex.h
    defines.h
    expm1f.h
    utils.h)

function(build_kernel_base TARGET SRCFILE DEPS)
  set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
  if(MLX_METAL_DEBUG)
    set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
  endif()
  if(MLX_METAL_VERSION GREATER_EQUAL 310)
    set(VERSION_INCLUDES
        ${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
  else()
    set(VERSION_INCLUDES
        ${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0)
  endif()
  add_custom_command(
    COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
            -I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
            -I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air
    DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
    OUTPUT ${TARGET}.air
    COMMENT "Building ${TARGET}.air"
@@ -30,8 +44,7 @@ build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
build_kernel(rope)
build_kernel(scaled_dot_product_attention scaled_dot_product_attention_params.h
             steel/defines.h steel/gemm/transforms.h steel/utils.h)
build_kernel(scaled_dot_product_attention sdpa_vector.h)

set(STEEL_HEADERS
    steel/defines.h
@@ -49,7 +62,27 @@ set(STEEL_HEADERS
    steel/gemm/transforms.h
    steel/gemm/kernels/steel_gemm_fused.h
    steel/gemm/kernels/steel_gemm_masked.h
    steel/gemm/kernels/steel_gemm_splitk.h)
    steel/gemm/kernels/steel_gemm_splitk.h
    steel/utils/type_traits.h
    steel/utils/integral_constant.h)

set(STEEL_ATTN_HEADERS
    steel/defines.h
    steel/utils.h
    steel/gemm/gemm.h
    steel/gemm/mma.h
    steel/gemm/loader.h
    steel/gemm/transforms.h
    steel/utils/type_traits.h
    steel/utils/integral_constant.h
    steel/attn/attn.h
    steel/attn/loader.h
    steel/attn/mma.h
    steel/attn/params.h
    steel/attn/transforms.h
    steel/attn/kernels/steel_attention.h)

build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})

if(NOT MLX_METAL_JIT)
  build_kernel(arange arange.h)
||||
|
@@ -1,7 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/arange.h"
|
||||
|
||||
#define instantiate_arange(tname, type) \
|
||||
|
@@ -2,8 +2,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Metal math for bfloat16
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -369,18 +367,6 @@ instantiate_metal_math_funcs(
|
||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
||||
}
|
||||
|
||||
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
|
||||
|
||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
||||
|
||||
#else
|
||||
|
||||
#define bfloat16_to_uint16(x) x.bits_
|
||||
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
|
||||
|
||||
#endif
|
||||
|
||||
namespace metal {
|
||||
|
||||
instantiate_metal_simd_comm_funcs(
|
||||
|
@@ -77,12 +77,12 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
|
||||
c[index] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -91,13 +91,13 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
|
||||
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -106,14 +106,18 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N = 1>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N = 1,
|
||||
typename IdxT = size_t>
|
||||
[[kernel]] void binary_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -124,13 +128,12 @@ template <typename T, typename U, typename Op, int N = 1>
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(
|
||||
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
|
||||
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
|
||||
auto xshape = shape[ndim - 1];
|
||||
size_t out_idx =
|
||||
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_xstride = a_strides[ndim - 1];
|
||||
auto b_xstride = b_strides[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
IdxT a_xstride = a_strides[ndim - 1];
|
||||
IdxT b_xstride = b_strides[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
|
||||
idx.x += a_xstride;
|
||||
|
@@ -9,19 +9,21 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
|
@@ -217,7 +217,7 @@ struct Power {
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
auto x_theta = metal::atan(x.imag / x.real);
|
||||
auto x_theta = metal::atan2(x.imag, x.real);
|
||||
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
|
||||
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
|
||||
auto phase = y.imag * x_ln_r + y.real * x_theta;
|
||||
|
@@ -99,14 +99,14 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -116,15 +116,15 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
|
||||
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[out_idx] = out[0];
|
||||
d[out_idx] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -134,16 +134,20 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[out_idx] = out[0];
|
||||
d[out_idx] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N = 1>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N = 1,
|
||||
typename IdxT = size_t>
|
||||
[[kernel]] void binary_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -155,13 +159,12 @@ template <typename T, typename U, typename Op, int N = 1>
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(
|
||||
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
|
||||
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
|
||||
auto xshape = shape[ndim - 1];
|
||||
size_t out_idx =
|
||||
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_xstride = a_strides[ndim - 1];
|
||||
auto b_xstride = b_strides[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
IdxT a_xstride = a_strides[ndim - 1];
|
||||
IdxT b_xstride = b_strides[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
auto out = Op()(a[idx.x], b[idx.y]);
|
||||
c[out_idx] = out[0];
|
||||
|
@@ -7,19 +7,21 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary_two.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_float(op) \
|
||||
instantiate_binary_all(op, float16, half, half) \
|
||||
|
@@ -4,8 +4,8 @@
|
||||
#include <metal_simdgroup_matrix>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
@@ -113,6 +113,7 @@ template <typename T, int N>
|
||||
for (int i = N - 1; i >= 0; --i) {
|
||||
int os_ = (oS % params->oS[i]);
|
||||
int ws_ = (wS % params->wS[i]);
|
||||
out += ws_ * kernel_stride;
|
||||
|
||||
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
|
||||
|
||||
@@ -126,7 +127,6 @@ template <typename T, int N>
|
||||
oS /= params->oS[i];
|
||||
wS /= params->wS[i];
|
||||
|
||||
out += ws_ * kernel_stride;
|
||||
kernel_stride *= params->wS[i];
|
||||
}
|
||||
|
||||
|
@@ -42,36 +42,36 @@ template <typename T, typename U>
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
|
||||
dst[index] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_g_nd2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
|
||||
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
|
||||
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_g_nd3(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
int64_t dst_idx =
|
||||
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
|
||||
IdxT dst_idx =
|
||||
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N = 1>
|
||||
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_g(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
@@ -80,17 +80,16 @@ template <typename T, typename U, int N = 1>
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc(
|
||||
auto src_idx = elem_to_loc<int64_t, IdxT>(
|
||||
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
|
||||
if (N == 1) {
|
||||
int64_t dst_idx =
|
||||
index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
|
||||
IdxT dst_idx =
|
||||
index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
return;
|
||||
}
|
||||
auto xshape = src_shape[ndim - 1];
|
||||
int64_t dst_idx =
|
||||
N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
|
||||
IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
auto src_xstride = src_strides[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
dst[dst_idx + i] = static_cast<U>(src[src_idx]);
|
||||
@@ -105,36 +104,36 @@ template <typename T, typename U>
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1(index, dst_stride);
|
||||
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1<int64_t, int>(index, dst_stride);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg_nd2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2(index, dst_strides);
|
||||
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2<int64_t, IdxT>(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg_nd3(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3(index, dst_strides);
|
||||
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3<int64_t, IdxT>(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N = 1>
|
||||
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
@@ -143,7 +142,7 @@ template <typename T, typename U, int N = 1>
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(
|
||||
auto idx = elem_to_loc_2_nd<int64_t, IdxT>(
|
||||
{N * index.x, index.y, index.z},
|
||||
src_shape,
|
||||
src_strides,
|
||||
@@ -153,8 +152,8 @@ template <typename T, typename U, int N = 1>
|
||||
dst[idx.y] = static_cast<U>(src[idx.x]);
|
||||
return;
|
||||
}
|
||||
auto src_xstride = src_strides[ndim - 1];
|
||||
auto dst_xstride = dst_strides[ndim - 1];
|
||||
IdxT src_xstride = src_strides[ndim - 1];
|
||||
IdxT dst_xstride = dst_strides[ndim - 1];
|
||||
auto xshape = src_shape[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
dst[idx.y] = static_cast<U>(src[idx.x]);
|
||||
|
@@ -2,24 +2,27 @@
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/copy.h"
|
||||
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
|
||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
|
||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
|
||||
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
|
||||
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype) \
|
||||
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
|
||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
|
||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
|
||||
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
|
||||
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
|
||||
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
|
||||
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
|
||||
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
|
||||
|
||||
#define instantiate_copy_itype(itname, itype) \
|
||||
instantiate_copy_all(itname ##bool_, itype, bool) \
|
||||
|
@@ -4,7 +4,7 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/indexing.h"
|
||||
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
|
||||
METAL_FUNC void gather_impl(
|
||||
const device T* src [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
@@ -16,34 +16,36 @@ METAL_FUNC void gather_impl(
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
size_t src_idx = 0;
|
||||
LocT src_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
size_t idx_loc;
|
||||
LocT idx_loc;
|
||||
if (IDX_NDIM == 0) {
|
||||
idx_loc = 0;
|
||||
} else if (IDX_NDIM == 1) {
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
|
||||
} else {
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
idx_loc += elem_to_loc(
|
||||
index.y,
|
||||
&indices.shapes[indices.ndim * i + 1],
|
||||
&indices.strides[indices.ndim * i + 1],
|
||||
indices.ndim - 1);
|
||||
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
|
||||
idx_loc += indices.row_contiguous[i]
|
||||
? index.y
|
||||
: elem_to_loc<size_t, LocT>(
|
||||
index.y,
|
||||
&indices.shapes[indices.ndim * i + 1],
|
||||
&indices.strides[indices.ndim * i + 1],
|
||||
indices.ndim - 1);
|
||||
}
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
|
||||
src_idx += idx_val * src_strides[ax];
|
||||
src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
|
||||
}
|
||||
|
||||
auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
|
||||
auto src_offset =
|
||||
elem_to_loc<size_t, LocT>(index.z, slice_sizes, src_strides, src_ndim);
|
||||
|
||||
size_t out_idx = index.z;
|
||||
LocT out_idx = index.z;
|
||||
if (IDX_NDIM == 1) {
|
||||
out_idx += static_cast<size_t>(grid_dim.z) * index.x;
|
||||
out_idx += static_cast<LocT>(grid_dim.z) * index.x;
|
||||
} else if (IDX_NDIM >= 2) {
|
||||
out_idx +=
|
||||
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
|
||||
out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
|
||||
}
|
||||
out[out_idx] = src[src_offset + src_idx];
|
||||
}
|
||||
|
@@ -3,8 +3,6 @@
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||
@@ -912,4 +910,4 @@ template <
|
||||
// clang-format off
|
||||
instantiate_gemv_t_bs_blocks(float32, float);
|
||||
instantiate_gemv_t_bs_blocks(float16, half);
|
||||
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
|
||||
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
|
||||
|
@@ -4,8 +4,6 @@
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/gemv_masked.h"
|
||||
|
@@ -9,11 +9,12 @@ struct Indices {
|
||||
const array<const device IdxT*, NIDX> buffers;
|
||||
const constant int* shapes;
|
||||
const constant size_t* strides;
|
||||
const constant bool* row_contiguous;
|
||||
const int ndim;
|
||||
};
|
||||
|
||||
template <typename IdxT>
|
||||
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
|
||||
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
|
||||
if (is_unsigned_v<IdxT>) {
|
||||
return idx;
|
||||
} else {
|
||||
|
mlx/backend/metal/kernels/jit/bf16.h (new file, 16 lines)
@@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.

// clang-format off
#define jit_if #if
#define jit_else #else
#define jit_endif #endif

jit_if (__METAL_VERSION__ >= 310)

#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"

jit_else

#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"

jit_endif // clang-format on
@@ -3,8 +3,6 @@
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
@@ -6,12 +6,6 @@

using namespace metal;

-#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
-
-typedef bfloat bfloat16_t;
-
-#else
-
/////////////////////////////////////////////////////////////////////////////
// Helpers
/////////////////////////////////////////////////////////////////////////////
@@ -311,7 +305,10 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
} // namespace metal

#pragma METAL internals : disable
+inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
+  return x.bits_;
+}

-#endif
-
-#include "mlx/backend/metal/kernels/bf16_math.h"
+inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
+  return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat());
+}
mlx/backend/metal/kernels/metal_3_1/bf16.h (new file, 16 lines)
@@ -0,0 +1,16 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <metal_stdlib>
+
+using namespace metal;
+
+typedef bfloat bfloat16_t;
+
+inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
+  return as_type<uint16_t>(x);
+}
+
+inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
+  return as_type<bfloat16_t>(x);
+}
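Both bf16 headers now expose bfloat16_to_uint16 and uint16_to_bfloat16 so kernels can move bfloat16 values through 16-bit integer storage: the native Metal 3.1 path reinterprets the bits with as_type, while the emulated path reads and writes the struct's bits_ field. The underlying representation is just the top 16 bits of a float32; a portable C++ sketch of that round trip (bf16, float_to_bf16, and bf16_to_float are illustrative names, not MLX APIs):

    #include <cstdint>
    #include <cstring>

    struct bf16 { uint16_t bits; }; // bfloat16 = high 16 bits of a float32

    inline bf16 float_to_bf16(float f) {
      uint32_t u;
      std::memcpy(&u, &f, sizeof(u));
      return bf16{static_cast<uint16_t>(u >> 16)}; // truncate, no rounding
    }

    inline float bf16_to_float(bf16 x) {
      uint32_t u = static_cast<uint32_t>(x.bits) << 16;
      float f;
      std::memcpy(&f, &u, sizeof(f));
      return f;
    }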
File diff suppressed because it is too large
@@ -5,67 +5,119 @@
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized.h"

#define instantiate_quantized(name, type, group_size, bits) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits, \
      name, \
      type, \
      group_size, \
      bits)

#define instantiate_quantized_types(name, group_size, bits) \
  instantiate_quantized(name, float, group_size, bits) \
  instantiate_quantized(name, float16_t, group_size, bits) \
  instantiate_quantized(name, bfloat16_t, group_size, bits)

#define instantiate_quantized_batched(name, type, group_size, bits, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits "_batch_" #batched, \
      name, \
      type, \
      group_size, \
      bits, \
      batched)

#define instantiate_quantized_groups(name, bits) \
  instantiate_quantized_types(name, 128, bits) \
  instantiate_quantized_types(name, 64, bits) \
  instantiate_quantized_types(name, 32, bits)

#define instantiate_quantized_all(name) \
  instantiate_quantized_groups(name, 2) \
  instantiate_quantized_groups(name, 4) \
  instantiate_quantized_groups(name, 8)

instantiate_quantized_all(qmv_fast)
instantiate_quantized_all(qmv)
instantiate_quantized_all(qvm)
instantiate_quantized_all(qmm_n)
instantiate_quantized_all(bs_qmv_fast)
instantiate_quantized_all(bs_qmv)
instantiate_quantized_all(bs_qvm)
instantiate_quantized_all(bs_qmm_n)
instantiate_quantized_all(affine_quantize)
instantiate_quantized_all(affine_quantize_scales_biases)
instantiate_quantized_all(affine_dequantize)
#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
      name, \
      type, \
      group_size, \
      bits, \
      aligned)

#define instantiate_quantized_types_aligned(name, group_size, bits) \
  instantiate_quantized_aligned(name, float, group_size, bits, true) \
  instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \
  instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \
  instantiate_quantized_aligned(name, float, group_size, bits, false) \
  instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \
  instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false)

#define instantiate_quantized_aligned_batched(name, type, group_size, bits, aligned, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned "_batch_" #batched, \
      name, \
      type, \
      group_size, \
      bits, \
      aligned, \
      batched)

#define instantiate_quantized_groups_aligned(name, bits) \
  instantiate_quantized_types_aligned(name, 128, bits) \
  instantiate_quantized_types_aligned(name, 64, bits) \
  instantiate_quantized_types_aligned(name, 32, bits)

#define instantiate_quantized_quad(name, type, group_size, bits, D, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits "_d_" #D "_batch_" #batched, \
      name, \
      type, \
      group_size, \
      bits, \
      D, \
      batched)

#define instantiate_quantized_all_aligned(name) \
  instantiate_quantized_groups_aligned(name, 2) \
  instantiate_quantized_groups_aligned(name, 4) \
  instantiate_quantized_groups_aligned(name, 8) \

#define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k, \
      name, \
      type, \
      group_size, \
      bits, \
      split_k)

instantiate_quantized_all_aligned(qmm_t)
instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
  instantiate_quantized_batched(name, type, group_size, bits, 1) \
  instantiate_quantized_batched(name, type, group_size, bits, 0)

#define instantiate_quantized_all_batched(type, group_size, bits) \
  instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
  instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
  instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
  instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)

#define instantiate_quantized_all_single(type, group_size, bits) \
  instantiate_quantized(affine_quantize, type, group_size, bits) \
  instantiate_quantized(affine_dequantize, type, group_size, bits) \
  instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
  instantiate_quantized(bs_qmv, type, group_size, bits) \
  instantiate_quantized(bs_qvm, type, group_size, bits) \
  instantiate_quantized(bs_qmm_n, type, group_size, bits)

#define instantiate_quantized_all_aligned(type, group_size, bits) \
  instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
  instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
  instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
  instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
  instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
  instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0)

#define instantiate_quantized_all_quad(type, group_size, bits) \
  instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \
  instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \
  instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
  instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)

#define instantiate_quantized_all_splitk(type, group_size, bits) \
  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)

#define instantiate_quantized_funcs(type, group_size, bits) \
  instantiate_quantized_all_single(type, group_size, bits) \
  instantiate_quantized_all_batched(type, group_size, bits) \
  instantiate_quantized_all_aligned(type, group_size, bits) \
  instantiate_quantized_all_quad(type, group_size, bits) \
  instantiate_quantized_all_splitk(type, group_size, bits)

#define instantiate_quantized_types(group_size, bits) \
  instantiate_quantized_funcs(float, group_size, bits) \
  instantiate_quantized_funcs(float16_t, group_size, bits) \
  instantiate_quantized_funcs(bfloat16_t, group_size, bits)

#define instantiate_quantized_groups(bits) \
  instantiate_quantized_types(128, bits) \
  instantiate_quantized_types(64, bits) \
  instantiate_quantized_types(32, bits)

#define instantiate_quantized_all() \
  instantiate_quantized_groups(2) \
  instantiate_quantized_groups(3) \
  instantiate_quantized_groups(4) \
  instantiate_quantized_groups(6) \
  instantiate_quantized_groups(8)

instantiate_quantized_all() // clang-format on
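The reorganized macros still rely on the same preprocessor mechanics to emit one kernel instantiation per (name, type, group_size, bits, ...) combination: the # operator stringizes each argument into the externally visible kernel name (for example "qmv_fast_float16_t_gs_64_b_4_batch_1" for the batched variant), while the bare arguments become template parameters. A small self-contained C++ illustration of the stringizing part (make_kernel_name is a hypothetical macro, not part of MLX):

    #include <cstdio>

    // Paste stringized macro arguments into a kernel-name literal, the same
    // way instantiate_quantized builds names like "qmv_fast_float16_t_gs_64_b_4".
    #define make_kernel_name(name, type, group_size, bits) \
      (#name "_" #type "_gs_" #group_size "_b_" #bits)

    int main() {
      // Adjacent string literals are concatenated at compile time.
      std::puts(make_kernel_name(qmv_fast, float16_t, 64, 4));
      // prints: qmv_fast_float16_t_gs_64_b_4
      return 0;
    }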
@@ -34,8 +34,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
[[kernel]] void rbitsc(
    device const uint32_t* keys,
    device char* out,
-    device const bool& odd,
-    device const uint& bytes_per_key,
+    constant const bool& odd,
+    constant const uint& bytes_per_key,
    uint2 grid_dim [[threads_per_grid]],
    uint2 index [[thread_position_in_grid]]) {
  auto kidx = 2 * index.x;
@@ -67,8 +67,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
[[kernel]] void rbits(
    device const uint32_t* keys,
    device char* out,
-    device const bool& odd,
-    device const uint& bytes_per_key,
+    constant const bool& odd,
+    constant const uint& bytes_per_key,
    constant const int& ndim,
    constant const int* key_shape,
    constant const size_t* key_strides,
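In both random-bit kernels the small per-launch scalars odd and bytes_per_key move from the device to the constant address space, the usual Metal choice for read-only values that every thread loads: they can be broadcast through the constant path instead of being fetched like ordinary device buffer data. A trimmed signature in the same style, with the body omitted (rbits_sketch is a hypothetical name, not one of the kernels above):

    [[kernel]] void rbits_sketch(
        device const uint32_t* keys,         // large buffer, indexed per thread
        device char* out,                    // written per thread
        constant const bool& odd,            // small uniform scalar
        constant const uint& bytes_per_key,  // small uniform scalar
        uint2 index [[thread_position_in_grid]]) {
      // ...
    }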
@@ -10,178 +10,156 @@
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduce.h"

#define instantiate_reduce_helper_floats(inst_f, name, op) \
  inst_f(name, float16, half, op) \
  inst_f(name, float32, float, op) \
  inst_f(name, bfloat16, bfloat16_t, op)

#define instantiate_init_reduce(name, tname, type, op) \
  instantiate_kernel("init_reduce_" #name #tname, init_reduce, type, op<type>)

#define instantiate_reduce_helper_uints(inst_f, name, op) \
  inst_f(name, uint8, uint8_t, op) \
  inst_f(name, uint16, uint16_t, op) \
  inst_f(name, uint32, uint32_t, op)

instantiate_init_reduce(and, bool_, bool, And)
instantiate_init_reduce(or, bool_, bool, Or)

#define instantiate_reduce_helper_ints(inst_f, name, op) \
  inst_f(name, int8, int8_t, op) \
  inst_f(name, int16, int16_t, op) \
  inst_f(name, int32, int32_t, op)

#define instantiate_init_sum_prod(name, op) \
  instantiate_init_reduce(name, int32, int32_t, op) \
  instantiate_init_reduce(name, int64, int64_t, op) \
  instantiate_init_reduce(name, float16, float16_t, op) \
  instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
  instantiate_init_reduce(name, float32, float, op) \
  instantiate_init_reduce(name, complex64, complex64_t, op)

#define instantiate_reduce_helper_64b(inst_f, name, op) \
  inst_f(name, int64, int64_t, op) \
  inst_f(name, uint64, uint64_t, op) \
  inst_f(name, complex64, complex64_t, op)

instantiate_init_sum_prod(sum, Sum)
instantiate_init_sum_prod(prod, Prod)

#define instantiate_reduce_helper_types(inst_f, name, op) \
  instantiate_reduce_helper_floats(inst_f, name, op) \
  instantiate_reduce_helper_uints(inst_f, name, op) \
  instantiate_reduce_helper_ints(inst_f, name, op)

#define instantiate_init_min_max(name, op) \
  instantiate_init_reduce(name, bool_, bool, op) \
  instantiate_init_reduce(name, int8, int8_t, op) \
  instantiate_init_reduce(name, int16, int16_t, op) \
  instantiate_init_reduce(name, int32, int32_t, op) \
  instantiate_init_reduce(name, int64, int64_t, op) \
  instantiate_init_reduce(name, uint8, uint8_t, op) \
  instantiate_init_reduce(name, uint16, uint16_t, op) \
  instantiate_init_reduce(name, uint32, uint32_t, op) \
  instantiate_init_reduce(name, uint64, uint64_t, op) \
  instantiate_init_reduce(name, float16, float16_t, op) \
  instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
  instantiate_init_reduce(name, float32, float, op) \
  instantiate_init_reduce(name, complex64, complex64_t, op)

#define instantiate_reduce_ops(inst_f, type_f) \
  type_f(inst_f, sum, Sum) \
  type_f(inst_f, prod, Prod) \
  type_f(inst_f, min, Min) \
  type_f(inst_f, max, Max)

// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
    inst_f, name, tname, itype, otype, op) \
  inst_f(name##tname, itype, otype, op)

#define instantiate_reduce_from_types(inst_f, name, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, bool_, bool, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint8, uint8_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint16, uint16_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint32, uint32_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint64, uint64_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int8, int8_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int16, int16_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int32, int32_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int64, int64_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, float16, half, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, \
      name, \
      float32, \
      float, \
      otype, \
      op) \
  instantiate_reduce_from_types_helper( \
      inst_f, \
      name, \
      bfloat16, \
      bfloat16_t, \
      otype, \
      op)

#define instantiate_init_reduce(name, otype, op) \
  instantiate_kernel("init_reduce_" #name, \
                     init_reduce, \
                     otype, op)

#define instantiate_init_reduce_helper(name, tname, type, op) \
  instantiate_init_reduce(name##tname, type, op<type>)

instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)

instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)
instantiate_init_min_max(min, Min)
instantiate_init_min_max(max, Max)

#define instantiate_all_reduce(name, itype, otype, op) \
  instantiate_kernel("all_reduce_" #name, \
                     all_reduce, \
                     itype, otype, op)

#define instantiate_same_all_reduce_helper(name, tname, type, op) \
  instantiate_all_reduce(name##tname, type, type, op<type>)

#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
  instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
                     col_reduce_small, \
                     itype, otype, op, uint, dim) \
  instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
                     col_reduce_longcolumn, \
                     itype, otype, op, uint, dim) \
  instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
                     col_reduce_small, \
                     itype, otype, op, size_t, dim) \
  instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \
                     col_reduce_longcolumn, \
                     itype, otype, op, size_t, dim)
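A pattern worth noting in the reworked instantiations: each column and row reducer is generated twice, once indexed with 32-bit uint and once with size_t under a "_large_" kernel name, so 64-bit indexing is reserved for reductions whose extents actually exceed 32 bits. A small dispatch-side C++ sketch of how such a name could be chosen (col_reduce_small_name is hypothetical, and the exact threshold MLX uses may differ):

    #include <cstdint>
    #include <limits>
    #include <string>

    // Pick the size_t-indexed "_large_" variant only when a 32-bit index
    // could overflow for this reduction.
    std::string col_reduce_small_name(int dim, const std::string& op_and_type,
                                      uint64_t reduction_size) {
      const bool large = reduction_size > std::numeric_limits<uint32_t>::max();
      return std::string("col_reduce_small_") + (large ? "large_" : "") +
             std::to_string(dim) + "_reduce_" + op_and_type;
    }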
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b)

#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
  instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
                     col_reduce_looped, \
                     itype, otype, op, uint, dim, bm, bn) \
  instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
                     col_reduce_looped, \
                     itype, otype, op, size_t, dim, bm, bn)

instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)

// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)

#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
  instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
                     col_reduce_small, \
                     itype, otype, op, dim)

#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
  instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
                     col_reduce_looped, \
                     itype, otype, op, dim, bm, bn)

#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
  instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
                     col_reduce_2pass, \
                     itype, otype, op, uint, dim, bm, bn) \
  instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
                     col_reduce_2pass, \
                     itype, otype, op, size_t, dim, bm, bn)

#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
  instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)

#define instantiate_col_reduce_general(name, itype, otype, op) \
  instantiate_col_reduce_small(name, itype, otype, op, 0) \
  instantiate_col_reduce_small(name, itype, otype, op, 1) \
  instantiate_col_reduce_small(name, itype, otype, op, 2) \
  instantiate_col_reduce_small(name, itype, otype, op, 3) \
  instantiate_col_reduce_small(name, itype, otype, op, 4) \
  instantiate_col_reduce_looped(name, itype, otype, op, 0) \
  instantiate_col_reduce_small(name, itype, otype, op, 5) \
  instantiate_col_reduce_looped(name, itype, otype, op, 1) \
  instantiate_col_reduce_looped(name, itype, otype, op, 2) \
  instantiate_col_reduce_looped(name, itype, otype, op, 3) \
  instantiate_col_reduce_looped(name, itype, otype, op, 4)
  instantiate_col_reduce_looped(name, itype, otype, op, 5)

#define instantiate_same_col_reduce_helper(name, tname, type, op) \
  instantiate_col_reduce_general(name##tname, type, type, op<type>)

#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
  instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
                     row_reduce_small, \
                     itype, otype, op, uint, dim) \
  instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
                     row_reduce_small, \
                     itype, otype, op, size_t, dim)

instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b)

instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)

#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
  instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
                     row_reduce_small, \
                     itype, otype, op, dim)

#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
  instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
                     row_reduce_looped, \
                     itype, otype, op, dim)

#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
  instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
                     row_reduce_looped, \
                     itype, otype, op, uint, dim) \
  instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
                     row_reduce_looped, \
                     itype, otype, op, size_t, dim)

#define instantiate_row_reduce_general(name, itype, otype, op) \
  instantiate_row_reduce_small(name, itype, otype, op, 0) \
  instantiate_row_reduce_small(name, itype, otype, op, 1) \
  instantiate_row_reduce_small(name, itype, otype, op, 2) \
  instantiate_row_reduce_small(name, itype, otype, op, 3) \
  instantiate_row_reduce_small(name, itype, otype, op, 4) \
  instantiate_row_reduce_looped(name, itype, otype, op, 0) \
  instantiate_row_reduce_small(name, itype, otype, op, 5) \
  instantiate_row_reduce_looped(name, itype, otype, op, 1) \
  instantiate_row_reduce_looped(name, itype, otype, op, 2) \
  instantiate_row_reduce_looped(name, itype, otype, op, 3) \
  instantiate_row_reduce_looped(name, itype, otype, op, 4) \
  instantiate_row_reduce_looped(name, itype, otype, op, 5) \
  instantiate_kernel("row_reduce_simple_" #name, \
                     row_reduce_simple, \
                     itype, otype, op)

#define instantiate_same_row_reduce_helper(name, tname, type, op) \
  instantiate_row_reduce_general(name##tname, type, type, op<type>)

#define instantiate_reduce_functions(name, tname, itype, otype, op) \
  instantiate_all_reduce(name##tname, itype, otype, op<otype>) \
  instantiate_row_reduce_general(name##tname, itype, otype, op<otype>) \
  instantiate_col_reduce_general(name##tname, itype, otype, op<otype>)

instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b)

#define instantiate_and_or(name, op) \
  instantiate_reduce_functions(name, bool_, bool, bool, op) \
  instantiate_reduce_functions(name, int16, int16_t, bool, op) \
  instantiate_reduce_functions(name, int32, int32_t, bool, op) \
  instantiate_reduce_functions(name, int64, int64_t, bool, op)

instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
instantiate_and_or(and, And)
instantiate_and_or(or, Or)

instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)

#define instantiate_sum_prod(name, op) \
  instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
  instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
  instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
  instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
  instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
  instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
  instantiate_reduce_functions(name, float32, float, float, op) \
  instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)

instantiate_sum_prod(sum, Sum)
instantiate_sum_prod(prod, Prod)

#define instantiate_min_max(name, op) \
  instantiate_reduce_functions(name, int8, int8_t, int8_t, op) \
  instantiate_reduce_functions(name, int16, int16_t, int16_t, op) \
  instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
  instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
  instantiate_reduce_functions(name, uint8, uint8_t, uint8_t, op) \
  instantiate_reduce_functions(name, uint16, uint16_t, uint16_t, op) \
  instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
  instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
  instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
  instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
  instantiate_reduce_functions(name, float32, float, float, op) \
  instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)

instantiate_min_max(min, Min)
instantiate_min_max(max, Max)
// clang-format on
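The init_reduce kernels instantiated by these macros essentially fill the output with the reduction's identity element before the main reduction runs, with the identity supplied by the op functor (0 for Sum, 1 for Prod, the type's extremes for Min and Max, true/false for And/Or). A compact C++ sketch of that pattern, with functor names mirroring the ops above but written as a plain loop rather than the Metal kernel:

    #include <cstddef>
    #include <limits>

    // Each op carries its identity value, like the Sum/Prod/Min/Max functors
    // used in the instantiations above.
    template <typename T> struct Sum  { static constexpr T init = T(0); };
    template <typename T> struct Prod { static constexpr T init = T(1); };
    template <typename T> struct Min  { static constexpr T init = std::numeric_limits<T>::max(); };
    template <typename T> struct Max  { static constexpr T init = std::numeric_limits<T>::lowest(); };

    // init_reduce: fill the output buffer with the op's identity.
    template <typename T, typename Op>
    void init_reduce(T* out, size_t n) {
      for (size_t i = 0; i < n; ++i) {
        out[i] = Op::init;
      }
    }

    // e.g. init_reduce<float, Sum<float>>(out, n);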
Some files were not shown because too many files have changed in this diff.