mirror of https://github.com/ml-explore/mlx.git
synced 2025-09-07 09:14:34 +08:00

Compare commits: io-dev ... async_all_ (47 commits)
Commit SHAs:

3a1df968cf, 9f9cb7a2ef, 7e26fd8032, eab2685c67, 50dfb664db, 0189ab6ab6,
9401507336, eb8321d863, 79ef49b2c2, e110ca11e2, 226748b3e7, d568c7ee36,
e6fecbb3e1, da83f899bb, 7e5674d8be, 0a558577bf, fb71a82ada, 23406c9e9e,
b3ec792380, 6a9b584f3d, 81dd33af66, 8b76571896, e78a6518fa, 1873ffda01,
c417e42116, 358e1fd6ab, 631dfbe673, 56a4eaed72, bf925d9dc7, 1a7ed5dcb6,
5be5daa6ef, 60cb11764e, cbd5445ea7, 2c7e9b5158, 2263e4b279, 863039da4c,
7178ac0111, e7f9710499, ff4223904d, a9f80d60f6, 2e158cf6d0, 8bd6bfa4b5,
8b1906abd0, 06375e6605, b21242faf1, cc05a281c4, fe96ceee66
@@ -49,11 +49,6 @@ jobs:
           name: Run Python tests
           command: |
             python3 -m unittest discover python/tests -v
-      # TODO: Reenable when extension api becomes stable
-      # - run:
-      #     name: Build example extension
-      #     command: |
-      #       cd examples/extensions && python3 -m pip install .
       - run:
           name: Build CPP only
           command: |

@@ -69,13 +64,14 @@ jobs:
         default: "15.2.0"
     macos:
       xcode: << parameters.xcode_version >>
-    resource_class: macos.m1.large.gen1
+    resource_class: macos.m1.medium.gen1
     steps:
       - checkout
       - run:
           name: Install dependencies
           command: |
             brew install python@3.8
+            brew install openmpi
             python3.8 -m venv env
             source env/bin/activate
             pip install --upgrade pip

@@ -101,11 +97,14 @@ jobs:
             source env/bin/activate
             LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
             LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
-            # TODO: Reenable when extension api becomes stable
-            # - run:
-            #     name: Build example extension
-            #     command: |
-            #       cd examples/extensions && python3.11 -m pip install .
+            mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
+      - run:
+          name: Build example extension
+          command: |
+            source env/bin/activate
+            cd examples/extensions
+            pip install -r requirements.txt
+            python setup.py build_ext -j8
       - store_test_results:
           path: test-results
       - run:

@@ -117,7 +116,13 @@ jobs:
           name: Run CPP tests
           command: |
             DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
             DEVICE=cpu ./build/tests/tests
+      - run:
+          name: Build small binary
+          command: |
+            source env/bin/activate
+            cd build/
+            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
+            make -j

   build_release:
     parameters:

@@ -132,7 +137,7 @@ jobs:
         default: ""
     macos:
       xcode: << parameters.xcode_version >>
-    resource_class: macos.m1.large.gen1
+    resource_class: macos.m1.medium.gen1
     steps:
       - checkout
      - run:
@@ -16,6 +16,7 @@ MLX was developed with contributions from the following individuals:

 - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
 - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
 - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
+- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.

 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
CMakeLists.txt (111 lines changed)
@@ -15,12 +15,16 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
 option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
 option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
 option(MLX_BUILD_METAL "Build metal backend" ON)
+option(MLX_BUILD_CPU "Build cpu backend" ON)
 option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
 option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
+option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

 if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.12.2)
+  set(MLX_VERSION 0.14.0)
 endif()

 # --------------------- Processor tests -------------------------

@@ -84,9 +88,11 @@ elseif (MLX_BUILD_METAL)
   if (${MACOS_VERSION} GREATER_EQUAL 14.2)
     set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
     set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
     set(MLX_METAL_VERSION METAL_3_1)
   elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
     set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
     set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
     set(MLX_METAL_VERSION METAL_3_0)
   else()
     message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
   endif()

@@ -104,55 +110,66 @@ elseif (MLX_BUILD_METAL)
     $<INSTALL_INTERFACE:include/metal_cpp>
   )
   target_link_libraries(
-    mlx
+    mlx PUBLIC
     ${METAL_LIB}
     ${FOUNDATION_LIB}
     ${QUARTZ_LIB})

   add_compile_definitions(${MLX_METAL_VERSION})
 endif()

-find_library(ACCELERATE_LIBRARY Accelerate)
-if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
-  message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
-  set(MLX_BUILD_ACCELERATE ON)
-  target_link_libraries(mlx ${ACCELERATE_LIBRARY})
-  add_compile_definitions(ACCELERATE_NEW_LAPACK)
+if (MLX_BUILD_CPU)
+  find_library(ACCELERATE_LIBRARY Accelerate)
+  if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
+    message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
+    set(MLX_BUILD_ACCELERATE ON)
+    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
+    add_compile_definitions(ACCELERATE_NEW_LAPACK)
+  else()
+    message(STATUS "Accelerate or arm neon not found, using default backend.")
+    set(MLX_BUILD_ACCELERATE OFF)
+    if(${CMAKE_HOST_APPLE})
+      # The blas shipped in macOS SDK is not supported, search homebrew for
+      # openblas instead.
+      set(BLA_VENDOR OpenBLAS)
+      set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
+    endif()
+    # Search and link with lapack.
+    find_package(LAPACK REQUIRED)
+    if (NOT LAPACK_FOUND)
+      message(FATAL_ERROR "Must have LAPACK installed")
+    endif()
+    find_path(LAPACK_INCLUDE_DIRS lapacke.h
+      /usr/include
+      /usr/local/include
+      /usr/local/opt/openblas/include)
+    message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
+    message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
+    target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
+    target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
+    # List blas after lapack otherwise we may accidentally incldue an old version
+    # of lapack.h from the include dirs of blas.
+    find_package(BLAS REQUIRED)
+    if (NOT BLAS_FOUND)
+      message(FATAL_ERROR "Must have BLAS installed")
+    endif()
+    # TODO find a cleaner way to do this
+    find_path(BLAS_INCLUDE_DIRS cblas.h
+      /usr/include
+      /usr/local/include
+      $ENV{BLAS_HOME}/include)
+    message(STATUS "Blas lib " ${BLAS_LIBRARIES})
+    message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
+    target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
+    target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
+  endif()
-else()
-  message(STATUS "Accelerate or arm neon not found, using default backend.")
-  set(MLX_BUILD_ACCELERATE OFF)
-  if(${CMAKE_HOST_APPLE})
-    # The blas shipped in macOS SDK is not supported, search homebrew for
-    # openblas instead.
-    set(BLA_VENDOR OpenBLAS)
-    set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
-  endif()
-  # Search and link with lapack.
-  find_package(LAPACK REQUIRED)
-  if (NOT LAPACK_FOUND)
-    message(FATAL_ERROR "Must have LAPACK installed")
-  endif()
-  find_path(LAPACK_INCLUDE_DIRS lapacke.h
-    /usr/include
-    /usr/local/include
-    /usr/local/opt/openblas/include)
-  message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
-  message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
-  target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
-  target_link_libraries(mlx ${LAPACK_LIBRARIES})
-  # List blas after lapack otherwise we may accidentally incldue an old version
-  # of lapack.h from the include dirs of blas.
-  find_package(BLAS REQUIRED)
-  if (NOT BLAS_FOUND)
-    message(FATAL_ERROR "Must have BLAS installed")
-  endif()
-  # TODO find a cleaner way to do this
-  find_path(BLAS_INCLUDE_DIRS cblas.h
-    /usr/include
-    /usr/local/include
-    $ENV{BLAS_HOME}/include)
-  message(STATUS "Blas lib " ${BLAS_LIBRARIES})
-  message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
-  target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
-  target_link_libraries(mlx ${BLAS_LIBRARIES})
 endif()

 find_package(MPI)
 if (MPI_FOUND)
   target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
 endif()

 add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

@@ -164,6 +181,14 @@ target_include_directories(
   $<INSTALL_INTERFACE:include>
 )

+FetchContent_Declare(fmt
+  GIT_REPOSITORY https://github.com/fmtlib/fmt.git
+  GIT_TAG 10.2.1
+  EXCLUDE_FROM_ALL
+)
+FetchContent_MakeAvailable(fmt)
+target_link_libraries(mlx PRIVATE fmt::fmt-header-only)

 if (MLX_BUILD_PYTHON_BINDINGS)
   message(STATUS "Building Python bindings.")
   find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
@@ -88,13 +88,13 @@ for more information on building the C++ and Python APIs from source.

 ## Contributing

-Check out the [contribution guidelines](CONTRIBUTING.md) for more information
+Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
 on contributing to MLX. See the
 [docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
 information on building from source, and running tests.

 We are grateful for all of [our
-contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
+contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
 to MLX and wish to be acknowledged, please add your name to the list in your
 pull request.
@@ -28,11 +28,11 @@ def bench(f, a, b):
     return (e - s) * 1e-9


-def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
+def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
     def mx_conv_2D(a, b):
         ys = []
         for i in range(N_iter_func):
-            y = mx.conv2d(a, b, stride=strides, padding=padding)
+            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
             ys.append(y)
         mx.eval(ys)
         return ys

@@ -40,12 +40,12 @@ def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
     return mx_conv_2D


-def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
+def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
     @torch.no_grad()
     def pt_conv_2D(a, b):
         ys = []
         for i in range(N_iter_func):
-            y = torch.conv2d(a, b, stride=strides, padding=padding)
+            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
             ys.append(y)
         torch.mps.synchronize()
         return ys

@@ -53,11 +53,13 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
     return pt_conv_2D


-def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
+def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):

     scale = 1.0 / math.sqrt(kH * kH * C)
     a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
-    b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
+    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
+        np_dtype
+    )

     a_mx = mx.array(a_np)
     b_mx = mx.array(b_np)

@@ -67,15 +69,15 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):

     torch.mps.synchronize()

-    f_mx = make_mx_conv_2D(strides, padding)
-    f_pt = make_pt_conv_2D(strides, padding)
+    f_mx = make_mx_conv_2D(strides, padding, groups)
+    f_pt = make_pt_conv_2D(strides, padding, groups)

     time_torch = bench(f_pt, a_pt, b_pt)
     time_mlx = bench(f_mx, a_mx, b_mx)

-    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
+    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
     out_pt = torch.conv2d(
-        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
+        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
     )
     out_pt = torch.permute(out_pt, (0, 2, 3, 1))
     out_pt = out_pt.numpy(force=True)

@@ -84,7 +86,7 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):

     if not np.allclose(out_pt, out_mx, atol=atol):
         print(
-            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
+            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
         )

     return time_mlx, time_torch

@@ -95,35 +97,40 @@ if __name__ == "__main__":

     dtypes = ("float32",)
     shapes = (
-        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
-        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
-        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
-        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
-        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
-        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
-        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
-        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
-        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
-        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
-        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
-        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
-        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
-        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
-        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
-        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
+        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
+        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
+        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
+        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
+        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
+        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
+        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
+        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
+        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
+        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
+        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
+        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
+        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
+        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
+        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
+        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
+        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
+        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
+        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
     )

     for dtype in dtypes:
-        print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
-        for N, H, W, C, kH, kW, O, strides, padding in shapes:
+        print(
+            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
+        )
+        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
             np_dtype = getattr(np, dtype)
             time_mlx, time_torch = bench_shape(
-                N, H, W, C, kH, kW, O, strides, padding, np_dtype
+                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
-                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
+                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
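For readers unfamiliar with the `groups` argument this benchmark exercises: with `groups=g` the input channels are split into `g` independent slices, each convolved with its own set of filters that see only `C/g` channels. A minimal sketch of the semantics (shapes follow the NHWC input / `(O, kH, kW, C/groups)` weight convention used above):

```python
import mlx.core as mx

N, H, W, C, O, g = 4, 8, 8, 16, 16, 4
a = mx.random.uniform(shape=(N, H, W, C))       # NHWC input
b = mx.random.uniform(shape=(O, 3, 3, C // g))  # each filter sees C/g channels
y = mx.conv2d(a, b, stride=(1, 1), padding=(1, 1), groups=g)
print(y.shape)  # (4, 8, 8, 16)
```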
@@ -1,5 +1,5 @@
-Developer Documentation
-=======================
+Custom Extensions in MLX
+========================

 You can extend MLX with custom operations on the CPU or GPU. This guide
 explains how to do that with a simple example.

@@ -494,7 +494,7 @@ below.
   auto kernel = d.get_kernel(kname.str(), "mlx_ext");

   // Prepare to encode kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);

   // Kernel parameters are registered with buffer indices corresponding to

@@ -503,11 +503,11 @@ below.
   size_t nelem = out.size();

   // Encode input arrays to kernel
-  set_array_buffer(compute_encoder, x, 0);
-  set_array_buffer(compute_encoder, y, 1);
+  compute_encoder.set_input_array(x, 0);
+  compute_encoder.set_input_array(y, 1);

   // Encode output arrays to kernel
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_output_array(out, 2);

   // Encode alpha and beta
   compute_encoder->setBytes(&alpha_, sizeof(float), 3);

@@ -531,7 +531,7 @@ below.
   // Launch the grid with the given number of threads divided among
   // the given threadgroups
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

 We can now call the :meth:`axpby` operation on both the CPU and the GPU!

@@ -825,7 +825,7 @@ Let's look at a simple script and its results:

    print(f"c shape: {c.shape}")
    print(f"c dtype: {c.dtype}")
-   print(f"c correctness: {mx.all(c == 6.0).item()}")
+   print(f"c correct: {mx.all(c == 6.0).item()}")

 Output:
@@ -153,11 +153,18 @@ should point to the path to the built metal library.
     - OFF
   * - MLX_BUILD_METAL
     - ON
+  * - MLX_BUILD_CPU
+    - ON
   * - MLX_BUILD_PYTHON_BINDINGS
     - OFF
   * - MLX_METAL_DEBUG
     - OFF
+  * - MLX_BUILD_SAFETENSORS
+    - ON
+  * - MLX_BUILD_GGUF
+    - ON
+  * - MLX_METAL_JIT
+    - OFF

 .. note::

@@ -176,10 +183,37 @@ should point to the path to the built metal library.

       xcrun -sdk macosx --show-sdk-version

+Binary Size Minimization
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
+and `BUILD_SHARED_LIBS=ON`.
+
+The MLX CMake build has several additional options to make smaller binaries.
+For example, if you don't need the CPU backend or support for safetensors and
+GGUF, you can do:
+
+.. code-block:: shell
+
+   cmake .. \
+     -DCMAKE_BUILD_TYPE=MinSizeRel \
+     -DBUILD_SHARED_LIBS=ON \
+     -DMLX_BUILD_CPU=OFF \
+     -DMLX_BUILD_SAFETENSORS=OFF \
+     -DMLX_BUILD_GGUF=OFF \
+     -DMLX_METAL_JIT=ON
+
+The `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library, which
+contains pre-built GPU kernels. It substantially reduces the size of the Metal
+library by compiling kernels at run time the first time they are used in MLX
+on a given machine. Note that run-time compilation incurs a cold-start cost,
+which can be anywhere from a few hundred milliseconds to a few seconds
+depending on the application. Once a kernel is compiled, it will be cached by
+the system, and the Metal kernel cache persists across reboots.
+
 Troubleshooting
 ^^^^^^^^^^^^^^^

 Metal not found
 ~~~~~~~~~~~~~~~
@@ -8,5 +8,8 @@ Linear Algebra
 .. autosummary::
    :toctree: _autosummary

    inv
    norm
+   cholesky
+   qr
+   svd
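The three additions above are exposed in Python as `mx.linalg.cholesky`, `mx.linalg.qr`, and `mx.linalg.svd`. A minimal usage sketch, assuming the bindings mirror the primitives added later in this compare (these decompositions run on the CPU stream):

```python
import mlx.core as mx

a = mx.array([[4.0, 2.0], [2.0, 3.0]])    # symmetric positive definite
L = mx.linalg.cholesky(a, stream=mx.cpu)  # lower triangular by default
q, r = mx.linalg.qr(a, stream=mx.cpu)
u, s, vt = mx.linalg.svd(a, stream=mx.cpu)
print(mx.allclose(L @ L.T, a).item())     # True
```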
@@ -15,6 +15,7 @@ Layers

    BatchNorm
    Conv1d
    Conv2d
+   Conv3d
    Dropout
    Dropout2d
    Dropout3d
@@ -10,6 +10,7 @@ Operations

    abs
    add
    addmm
    all
    allclose
    any

@@ -19,12 +20,14 @@ Operations

    arcsin
    arcsinh
    arctan
+   arctan2
    arctanh
    argmax
    argmin
    argpartition
    argsort
    array_equal
    as_strided
    atleast_1d
    atleast_2d
    atleast_3d

@@ -32,11 +35,12 @@ Operations

    bitwise_or
    bitwise_xor
    block_masked_mm
    block_sparse_mm
    broadcast_to
    ceil
    clip
    concatenate
    conj
    conjugate
    convolve
    conv1d
    conv2d

@@ -64,6 +68,8 @@ Operations

    floor
    floor_divide
    full
+   gather_mm
+   gather_qmm
    greater
    greater_equal
    identity

@@ -73,6 +79,7 @@ Operations

    isnan
    isneginf
    isposinf
    issubdtype
    left_shift
    less
    less_equal

@@ -103,11 +110,13 @@ Operations

    outer
    partition
    pad
    power
    prod
    quantize
    quantized_matmul
    radians
    reciprocal
    remainder
    repeat
    reshape
    right_shift

@@ -141,6 +150,7 @@ Operations

    tensordot
    tile
    topk
    trace
    transpose
    tri
    tril
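Of the ops listed here, `arctan2` is new in this compare (its CPU kernels appear later in this diff). A quick usage sketch, assuming the standard element-wise `atan2(y, x)` argument order:

```python
import mlx.core as mx

y = mx.array([1.0, 1.0, -1.0])
x = mx.array([1.0, -1.0, -1.0])
print(mx.arctan2(y, x))  # angles in radians: [pi/4, 3*pi/4, -3*pi/4]
```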
@@ -9,3 +9,4 @@ build_example(tutorial.cpp)
 build_example(linear_regression.cpp)
 build_example(logistic_regression.cpp)
 build_example(metal_capture.cpp)
+build_example(distributed.cpp)
examples/cpp/distributed.cpp (new file, 22 lines)
@@ -0,0 +1,22 @@
// Copyright © 2024 Apple Inc.

#include <iostream>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  if (!distributed::is_available()) {
    std::cout << "No communication backend found" << std::endl;
    return 1;
  }

  auto global_group = distributed::init();
  std::cout << global_group.rank() << " / " << global_group.size() << std::endl;

  array x = ones({10});
  array out = distributed::all_reduce_sum(x, global_group);

  std::cout << out << std::endl;
}
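Exercising this example across processes requires an MPI launcher, the same way the CI job earlier in this compare runs the Python distributed tests. A sketch (the `./build/examples/cpp/distributed` binary path is an assumption based on the default example output directory):

```
# Launch 4 processes on the local host; point DYLD_LIBRARY_PATH at the
# Homebrew MPI libraries as the CI job above does.
mpirun -host localhost:4 -np 4 \
  -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ \
  ./build/examples/cpp/distributed
```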
@@ -89,8 +89,8 @@ void automatic_differentiation() {
   // dfdx is 2 * x

   // Get the second derivative by composing grad with grad
-  auto df2dx2 = grad(grad(fn))(x);
-  // df2dx2 is 2
+  auto d2fdx2 = grad(grad(fn))(x);
+  // d2fdx2 is 2
 }

 int main() {
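The rename fixes the math notation: the second derivative d²f/dx² is conventionally abbreviated `d2fdx2`. For reference, the same composition in the Python API (a minimal sketch):

```python
import mlx.core as mx

fn = lambda x: x * x
x = mx.array(1.5)
print(mx.grad(fn)(x))           # dfdx = 2 * x -> 3.0
print(mx.grad(mx.grad(fn))(x))  # d2fdx2 -> 2.0
```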
@@ -1,5 +1,5 @@

-## Build the extensions
+## Build

```
pip install -e .
```

@@ -16,3 +16,9 @@ And then run:

```
python setup.py build_ext -j8 --inplace
```

+## Test
+
+```
+python test.py
+```
@@ -257,7 +257,7 @@ void Axpby::eval_gpu(
   auto kernel = d.get_kernel(kname.str(), "mlx_ext");

   // Prepare to encode kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);

   // Kernel parameters are registered with buffer indices corresponding to

@@ -266,11 +266,11 @@ void Axpby::eval_gpu(
   size_t nelem = out.size();

   // Encode input arrays to kernel
-  set_array_buffer(compute_encoder, x, 0);
-  set_array_buffer(compute_encoder, y, 1);
+  compute_encoder.set_input_array(x, 0);
+  compute_encoder.set_input_array(y, 1);

   // Encode output arrays to kernel
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_output_array(out, 2);

   // Encode alpha and beta
   compute_encoder->setBytes(&alpha_, sizeof(float), 3);

@@ -296,7 +296,7 @@ void Axpby::eval_gpu(
   // Launch the grid with the given number of threads divided among
   // the given threadgroups
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

 #else // Metal is not available
@@ -2,4 +2,4 @@

 import mlx.core as mx

-from .mlx_sample_extensions import *
+from ._ext import axpby
@@ -1,4 +1,4 @@
 setuptools>=42
 cmake>=3.24
 mlx>=0.9.0
-nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
+nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
examples/extensions/test.py (new file, 10 lines)
@@ -0,0 +1,10 @@
import mlx.core as mx
from mlx_sample_extensions import axpby

a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)

print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")
@@ -19,12 +19,17 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
 )

-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/io)
+if (MLX_BUILD_CPU)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
+else()
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
+endif()

 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
 if (MLX_BUILD_ACCELERATE)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
-else()
+elseif(MLX_BUILD_CPU)
   target_sources(
     mlx
     PRIVATE
@@ -252,8 +252,9 @@ class array {
   }

   /** True indicates the arrays buffer is safe to reuse */
-  bool is_donatable() const {
-    return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
+  bool is_donatable(int known_instances = 1) const {
+    return array_desc_.use_count() == known_instances &&
+        (array_desc_->data.use_count() == 1);
   }

   /** The array's siblings. */
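The new `known_instances` parameter lets a caller donate a buffer even when it knowingly holds extra references to the same array. A hedged C++ sketch of the semantics (the surrounding helper is hypothetical, not MLX code):

```cpp
#include "mlx/mlx.h"

using namespace mlx::core;

// Hypothetical helper: `in` is a by-value copy, so the caller's variable
// plus this parameter make two live instances of the same array.
array square_reusing_buffer_if_possible(array in) {
  // is_donatable() alone would return false here (use_count is 2);
  // is_donatable(2) says: donate if these two instances are the only ones.
  if (in.is_donatable(2)) {
    // ... safe to reuse in's buffer for the output ...
  }
  return multiply(in, in);
}
```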
@@ -32,10 +32,10 @@ DEFAULT(ArgReduce)
 DEFAULT(ArgSort)
 DEFAULT(AsStrided)
 DEFAULT(BlockMaskedMM)
 DEFAULT(BlockSparseMM)
 DEFAULT(Broadcast)
 DEFAULT(Ceil)
 DEFAULT(Concatenate)
 DEFAULT(Conjugate)
 DEFAULT(Copy)
 DEFAULT_MULTI(CustomVJP)
 DEFAULT_MULTI(Depends)

@@ -47,10 +47,13 @@ DEFAULT(ErfInv)
 DEFAULT(FFT)
 DEFAULT(Floor)
 DEFAULT(Gather)
+DEFAULT(GatherMM)
+DEFAULT(GatherQMM)
 DEFAULT(Greater)
 DEFAULT(GreaterEqual)
 DEFAULT(Less)
 DEFAULT(LessEqual)
 DEFAULT(Load)
 DEFAULT(LogicalNot)
 DEFAULT(LogicalAnd)
 DEFAULT(LogicalOr)

@@ -77,6 +80,7 @@ DEFAULT(StopGradient)
 DEFAULT_MULTI(SVD)
 DEFAULT(Transpose)
 DEFAULT(Inverse)
+DEFAULT(Cholesky)

 void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);

@@ -192,6 +196,26 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
 }

+void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  if (out.dtype() == float32 && a.flags().row_contiguous &&
+      b.flags().row_contiguous) {
+    if (a.is_donatable()) {
+      out.copy_shared_buffer(a);
+    } else if (b.is_donatable()) {
+      out.copy_shared_buffer(b);
+    } else {
+      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    }
+    int size = a.data_size();
+    vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
+  } else {
+    eval(inputs, out);
+  }
+}
+
 void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
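The fast path above hands contiguous float32 buffers to Accelerate's vForce. In isolation the call looks like this (a standalone sketch, not MLX code; `vvatan2f(z, y, x, &n)` computes `z[i] = atan2f(y[i], x[i])`):

```cpp
#include <Accelerate/Accelerate.h>
#include <cstdio>

int main() {
  const int n = 2;
  float y[] = {1.0f, -1.0f};
  float x[] = {1.0f, 1.0f};
  float z[2];
  vvatan2f(z, y, x, &n);               // vectorized element-wise atan2
  std::printf("%f %f\n", z[0], z[1]);  // ~0.785398 -0.785398
  return 0;
}
```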
@@ -37,6 +37,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp

@@ -55,7 +56,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/cpu_impl.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
   ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
 )
@@ -293,4 +293,25 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
 }

+void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2);
+  const auto& a = inputs[0];
+  const auto& b = inputs[1];
+  if (out.dtype() == float32) {
+    binary_op<float>(a, b, out, detail::ArcTan2());
+  } else if (out.dtype() == float16) {
+    binary_op<float16_t>(a, b, out, detail::ArcTan2());
+  } else if (out.dtype() == bfloat16) {
+    binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
+  } else if (issubdtype(out.dtype(), inexact)) {
+    std::ostringstream err;
+    err << "[arctan2] Does not support " << out.dtype();
+    throw std::invalid_argument(err.str());
+  } else {
+    throw std::invalid_argument(
+        "[arctan2] Cannot compute inverse tangent for arrays"
+        " with non floating point type.");
+  }
+}
+
 } // namespace mlx::core
@@ -1,6 +1,8 @@
 // Copyright © 2023 Apple Inc.

+#pragma once
+#include <cassert>

 #include "mlx/allocator.h"
 #include "mlx/array.h"
 #include "mlx/backend/common/utils.h"
mlx/backend/common/cholesky.cpp (new file, 101 lines)
@@ -0,0 +1,101 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif

namespace mlx::core {

namespace {

// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
  int info;

#ifdef LAPACK_FORTRAN_STRLEN_END
  spotrf_(
      /* uplo = */ &uplo,
      /* n = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* info = */ &info,
      /* uplo_len = */ static_cast<size_t>(1));
#else
  spotrf_(
      /* uplo = */ &uplo,
      /* n = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* info = */ &info);
#endif

  return info;
}

} // namespace

void cholesky_impl(const array& a, array& factor, bool upper) {
  // Lapack uses the column-major convention. We take advantage of the fact
  // that the matrix should be symmetric:
  //   (A)ᵀ = A
  // and that a column-major lower triangular matrix is a row-major upper
  // triangular matrix, so uplo is the opposite of what we would expect from
  // upper

  char uplo = (upper) ? 'L' : 'U';

  // The decomposition is computed in place, so just copy the input to the
  // output.
  copy(
      a,
      factor,
      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);

  const int N = a.shape(-1);
  const size_t num_matrices = a.size() / (N * N);

  float* matrix = factor.data<float>();

  for (int i = 0; i < num_matrices; i++) {
    // Compute Cholesky factorization.
    int info = spotrf_wrapper(uplo, matrix, N);

    // TODO: We do nothing when the matrix is not positive semi-definite
    // because throwing an error would result in a crash. If we figure out how
    // to catch errors from the implementation we should throw.
    if (info < 0) {
      std::stringstream msg;
      msg << "[cholesky] Cholesky decomposition failed with error code "
          << info;
      throw std::runtime_error(msg.str());
    }

    // Zero out the upper/lower triangle while advancing the pointer to the
    // next matrix at the same time.
    for (int row = 0; row < N; row++) {
      if (upper) {
        std::fill(matrix, matrix + row, 0);
      } else {
        std::fill(matrix + row + 1, matrix + N, 0);
      }
      matrix += N;
    }
  }
}

void Cholesky::eval(const std::vector<array>& inputs, array& output) {
  if (inputs[0].dtype() != float32) {
    throw std::runtime_error("[Cholesky::eval] only supports float32.");
  }
  cholesky_impl(inputs[0], output, upper_);
}

} // namespace mlx::core
mlx/backend/common/common.cpp (new file, 347 lines)
@@ -0,0 +1,347 @@
// Copyright © 2024 Apple Inc.
#include <cassert>

#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

void AsStrided::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  auto& in = inputs[0];

  if (!in.flags().row_contiguous) {
    // Just ensuring that inputs[0] came from the ops which would ensure the
    // input is row contiguous.
    throw std::runtime_error(
        "AsStrided must be used with row contiguous arrays only.");
  }

  // Compute the flags given the shape and strides
  bool row_contiguous = true, col_contiguous = true;
  size_t r = 1, c = 1;
  for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
    row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
    col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
    r *= shape_[i];
    c *= shape_[j];
  }
  auto flags = in.flags();
  // TODO: Compute the contiguous flag in a better way cause now we are
  // unnecessarily strict.
  flags.contiguous = row_contiguous || col_contiguous;
  flags.row_contiguous = row_contiguous;
  flags.col_contiguous = col_contiguous;

  // There is no easy way to compute the actual data size so we use out.size().
  // The contiguous flag will almost certainly not be set so no code should
  // rely on data_size anyway.
  size_t data_size = out.size();

  return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}

void Broadcast::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }
  std::vector<size_t> strides(out.ndim(), 0);
  int diff = out.ndim() - in.ndim();
  for (int i = in.ndim() - 1; i >= 0; --i) {
    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
  }
  auto flags = in.flags();
  if (out.size() > in.size()) {
    flags.row_contiguous = flags.col_contiguous = false;
  }
  out.copy_shared_buffer(in, strides, flags, in.data_size());
}

void Copy::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  out.copy_shared_buffer(inputs[0]);
}

void CustomVJP::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() > outputs.size());
  for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
       i++, j++) {
    outputs[i].copy_shared_buffer(inputs[j]);
  }
}

void Depends::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() > outputs.size());
  for (int i = 0; i < outputs.size(); i++) {
    outputs[i].copy_shared_buffer(inputs[i]);
  }
}

void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  double numel = 1;
  for (auto ax : axes_) {
    numel *= inputs[0].shape(ax);
  }

  if (inverted_) {
    numel = 1.0 / numel;
  }

  switch (out.dtype()) {
    case bool_:
      *out.data<bool>() = static_cast<bool>(numel);
      break;
    case uint8:
      *out.data<uint8_t>() = static_cast<uint8_t>(numel);
      break;
    case uint16:
      *out.data<uint16_t>() = static_cast<uint16_t>(numel);
      break;
    case uint32:
      *out.data<uint32_t>() = static_cast<uint32_t>(numel);
      break;
    case uint64:
      *out.data<uint64_t>() = static_cast<uint64_t>(numel);
      break;
    case int8:
      *out.data<int8_t>() = static_cast<int8_t>(numel);
      break;
    case int16:
      *out.data<int16_t>() = static_cast<int16_t>(numel);
      break;
    case int32:
      *out.data<int32_t>() = static_cast<int32_t>(numel);
      break;
    case int64:
      *out.data<int64_t>() = static_cast<int64_t>(numel);
      break;
    case float16:
      *out.data<float16_t>() = static_cast<float16_t>(numel);
      break;
    case float32:
      *out.data<float>() = static_cast<float>(numel);
      break;
    case bfloat16:
      *out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
      break;
    case complex64:
      *out.data<complex64_t>() = static_cast<complex64_t>(numel);
      break;
  }
}

std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
    const array& in,
    const array& out) {
  // Special case for empty arrays or row contiguous arrays
  if (in.size() == 0 || in.flags().row_contiguous) {
    return {false, out.strides()};
  }

  // Special case for scalars
  if (in.ndim() == 0) {
    std::vector<size_t> out_strides(out.ndim(), 0);
    return {false, out_strides};
  }

  // Firstly let's collapse all the contiguous dimensions of the input
  auto [shape, _strides] = collapse_contiguous_dims(in);
  auto& strides = _strides[0];

  // If shapes fit exactly in the contiguous dims then no copy is necessary so
  // let's check.
  std::vector<size_t> out_strides;
  bool copy_necessary = false;
  int j = 0;
  for (int i = 0; i < out.ndim(); i++) {
    int N = out.shape(i);
    if (j < shape.size() && shape[j] % N == 0) {
      shape[j] /= N;
      out_strides.push_back(shape[j] * strides[j]);
      j += (shape[j] == 1);
    } else if (N == 1) {
      // i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
      out_strides.push_back(out_strides.back());
    } else {
      copy_necessary = true;
      break;
    }
  }

  return {copy_necessary, out_strides};
}

void Reshape::shared_buffer_reshape(
    const array& in,
    const std::vector<size_t>& out_strides,
    array& out) {
  auto flags = in.flags();
  if (flags.row_contiguous) {
    // For row contiguous reshapes:
    // - Shallow copy the buffer
    // - If reshaping into a vector (all singleton dimensions except one) it
    //   becomes col contiguous again.
    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
  }
  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}

void Split::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);

  auto& in = inputs[0];

  auto compute_new_flags = [](const auto& shape,
                              const auto& strides,
                              size_t in_data_size,
                              auto flags) {
    size_t data_size = 1;
    size_t f_stride = 1;
    size_t b_stride = 1;
    flags.row_contiguous = true;
    flags.col_contiguous = true;
    for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
      flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
      flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
      f_stride *= shape[i];
      b_stride *= shape[ri];
      if (strides[i] > 0) {
        data_size *= shape[i];
      }
    }

    if (data_size == 1) {
      // Broadcasted scalar array is contiguous.
      flags.contiguous = true;
    } else if (data_size == in_data_size) {
      // Means we sliced a broadcasted dimension so leave the "no holes" flag
      // alone.
    } else {
      // We sliced something. So either we are row or col contiguous or we
      // punched a hole.
      flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
    }

    return std::pair<decltype(flags), size_t>{flags, data_size};
  };

  std::vector<int> indices(1, 0);
  indices.insert(indices.end(), indices_.begin(), indices_.end());
  for (int i = 0; i < indices.size(); i++) {
    size_t offset = indices[i] * in.strides()[axis_];
    auto [new_flags, data_size] = compute_new_flags(
        outputs[i].shape(), in.strides(), in.data_size(), in.flags());
    outputs[i].copy_shared_buffer(
        in, in.strides(), new_flags, data_size, offset);
  }
}

std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
    const array& in) {
  int64_t data_offset = 0;
  bool copy_needed = false;
  std::vector<int64_t> inp_strides(in.ndim(), 0);
  for (int i = 0; i < in.ndim(); ++i) {
    data_offset += start_indices_[i] * in.strides()[i];
    inp_strides[i] = in.strides()[i] * strides_[i];

    copy_needed |= strides_[i] < 0;
  }

  return std::make_tuple(copy_needed, data_offset, inp_strides);
}

void Slice::shared_buffer_slice(
    const array& in,
    const std::vector<size_t>& out_strides,
    size_t data_offset,
    array& out) {
  // Compute row/col contiguity
  auto [data_size, is_row_contiguous, is_col_contiguous] =
      check_contiguity(out.shape(), out_strides);

  auto flags = in.flags();
  flags.row_contiguous = is_row_contiguous;
  flags.col_contiguous = is_col_contiguous;

  if (data_size == 1) {
    // Broadcasted scalar array is contiguous.
    flags.contiguous = true;
  } else if (data_size == in.data_size()) {
    // Means we sliced a broadcasted dimension so leave the "no holes" flag
    // alone.
  } else {
    // We sliced something. So either we are row or col contiguous or we
    // punched a hole.
    flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
  }

  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}

std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
    const array& in) {
  int64_t data_offset = 0;
  std::vector<int64_t> inp_strides(in.ndim(), 0);
  for (int i = 0; i < in.ndim(); ++i) {
    data_offset += start_indices_[i] * in.strides()[i];
    inp_strides[i] = in.strides()[i] * strides_[i];
  }

  return std::make_tuple(data_offset, inp_strides);
}

void StopGradient::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  out.copy_shared_buffer(inputs[0]);
}

void Transpose::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  std::vector<size_t> out_strides(out.ndim());
  auto& in = inputs[0];
  for (int ax = 0; ax < axes_.size(); ++ax) {
    out_strides[ax] = in.strides()[axes_[ax]];
  }

  // Conditions for {row/col}_contiguous
  // - array must be contiguous (no gaps)
  // - underlying buffer size should have the same size as the array
  // - cumulative product of shapes is equal to the strides (we can ignore axes
  //   with size == 1)
  //   - in the forward direction (column contiguous)
  //   - in the reverse direction (row contiguous)
  //   - vectors are both row and col contiguous (hence if both row/col are
  //     true, they stay true)
  auto flags = in.flags();
  if (flags.contiguous && in.data_size() == in.size()) {
    size_t f_stride = 1;
    size_t b_stride = 1;
    flags.col_contiguous = true;
    flags.row_contiguous = true;
    for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
      flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
      f_stride *= out.shape(i);
      flags.row_contiguous &=
          (out_strides[ri] == b_stride || out.shape(ri) == 1);
      b_stride *= out.shape(ri);
    }
  }
  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}

} // namespace mlx::core
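`Broadcast::eval` above implements broadcasting without copying: every broadcast dimension gets stride 0, so all output positions alias the same input elements. The same trick shown with NumPy's stride utilities (illustration only, not MLX code):

```python
import numpy as np

x = np.arange(3.0)  # shape (3,)
# View x as shape (4, 3) with stride 0 over the new leading axis:
# every row aliases the same three elements, no data is copied.
y = np.lib.stride_tricks.as_strided(x, shape=(4, 3), strides=(0, x.strides[0]))
assert (y == np.broadcast_to(x, (4, 3))).all()
```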
@@ -111,13 +111,17 @@ void slow_conv_2D(
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
   const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
   const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
+  const int C = in.shape(3); // In channels
   const int oH = out.shape(1); // Output spatial dim
   const int oW = out.shape(2); // Output spatial dim
   const int O = wt.shape(0); // Out channels
-  const int C = wt.shape(3); // In channels
   const int wH = wt.shape(1); // Weight spatial dim
   const int wW = wt.shape(2); // Weight spatial dim

+  const int groups = C / wt.shape(3);
+  const int C_per_group = wt.shape(3);
+  const int O_per_group = O / groups;
+
   const size_t in_stride_N = in.strides()[0];
   const size_t in_stride_H = in.strides()[1];
   const size_t in_stride_W = in.strides()[2];

@@ -141,33 +145,35 @@ void slow_conv_2D(
       int ih_base = oh * wt_strides[0] - padding[0];
       int iw_base = ow * wt_strides[1] - padding[1];

-      for (int o = 0; o < O; ++o) {
-        float r = 0.;
+      for (int g = 0; g < groups; ++g) {
+        for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
+          float r = 0.;

-        for (int wh = 0; wh < wH; ++wh) {
-          for (int ww = 0; ww < wW; ++ww) {
-            int wh_flip = flip ? wH - wh - 1 : wh;
-            int ww_flip = flip ? wW - ww - 1 : ww;
-            int ih = ih_base + wh_flip * wt_dilation[0];
-            int iw = iw_base + ww_flip * wt_dilation[1];
+          for (int wh = 0; wh < wH; ++wh) {
+            for (int ww = 0; ww < wW; ++ww) {
+              int wh_flip = flip ? wH - wh - 1 : wh;
+              int ww_flip = flip ? wW - ww - 1 : ww;
+              int ih = ih_base + wh_flip * wt_dilation[0];
+              int iw = iw_base + ww_flip * wt_dilation[1];

-            const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
-            const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
+              const T* wt_ptr_pt =
+                  wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
+              const T* in_ptr_pt =
+                  in_ptr + ih * in_stride_H + iw * in_stride_W;

-            for (int c = 0; c < C; ++c) {
-              r += static_cast<float>(in_ptr_pt[0]) *
-                  static_cast<float>(wt_ptr_pt[0]);
-              in_ptr_pt += in_stride_C;
-              wt_ptr_pt += wt_stride_C;
-            } // c
-          } // ww
-        } // wh
+              for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
+                r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
+                    static_cast<float>(
+                        wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
+              } // c
+            } // ww
+          } // wh

-        out_ptr[0] = static_cast<T>(r);
-        out_ptr += out_stride_O;
-        wt_ptr += wt_stride_O;
-      } // o
+          out_ptr[0] = static_cast<T>(r);
+          out_ptr += out_stride_O;
+          wt_ptr += wt_stride_O;
+        } // o
+      } // g
     };

     int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];

@@ -219,41 +225,43 @@ void slow_conv_2D(
       int wh_base = base_h[oh % f_out_jump_h];
       int ww_base = base_w[ow % f_out_jump_w];

-      for (int o = 0; o < O; ++o) {
-        float r = 0.;
+      for (int g = 0; g < groups; ++g) {
+        for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
+          float r = 0.;

-        for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
-          for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
-            int wh_flip = flip ? wH - wh - 1 : wh;
-            int ww_flip = flip ? wW - ww - 1 : ww;
-            int ih = ih_base + wh_flip * wt_dilation[0];
-            int iw = iw_base + ww_flip * wt_dilation[1];
+          for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
+            for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
+              int wh_flip = flip ? wH - wh - 1 : wh;
+              int ww_flip = flip ? wW - ww - 1 : ww;
+              int ih = ih_base + wh_flip * wt_dilation[0];
+              int iw = iw_base + ww_flip * wt_dilation[1];

-            if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
-              const T* wt_ptr_pt =
-                  wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
+              if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
+                const T* wt_ptr_pt =
+                    wt_ptr + wh * wt_stride_H + ww * wt_stride_W;

-              int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
-              int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
+                int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
+                int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;

-              const T* in_ptr_pt =
-                  in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
+                const T* in_ptr_pt =
+                    in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;

-              for (int c = 0; c < C; ++c) {
-                r += static_cast<float>(in_ptr_pt[0]) *
-                    static_cast<float>(wt_ptr_pt[0]);
-                in_ptr_pt += in_stride_C;
-                wt_ptr_pt += wt_stride_C;
-              } // c
-            } // ih, iw check
-          } // ww
-        } // wh
+                for (int c = g * C_per_group; c < (g + 1) * C_per_group;
+                     ++c) {
+                  r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
+                      static_cast<float>(
+                          wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
+                } // c
+              } // ih, iw check
+            } // ww
+          } // wh

-        out_ptr[0] = static_cast<T>(r);
-        out_ptr += out_stride_O;
-        wt_ptr += wt_stride_O;
-      } // o
+          out_ptr[0] = static_cast<T>(r);
+          out_ptr += out_stride_O;
+          wt_ptr += wt_stride_O;
+        } // o
+      } // g
     };

     int oH_border_0 = 0;

@@ -310,6 +318,296 @@ void slow_conv_2D(
   } // n
 }

+template <typename T>
+void slow_conv_3D(
+    const array& in,
+    const array& wt,
+    array out,
+    const std::vector<int>& padding,
+    const std::vector<int>& wt_strides,
+    const std::vector<int>& wt_dilation,
+    const std::vector<int>& in_dilation,
+    bool flip) {
+  const T* st_wt_ptr = wt.data<T>();
+  const T* st_in_ptr = in.data<T>();
+  T* st_out_ptr = out.data<T>();
+
+  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
+  const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
+  const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
+  const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim
+  const int oD = out.shape(1); // Output spatial dim
+  const int oH = out.shape(2); // Output spatial dim
+  const int oW = out.shape(3); // Output spatial dim
+  const int O = wt.shape(0); // Out channels
+  const int C = wt.shape(4); // In channels
+  const int wD = wt.shape(1); // Weight spatial dim
+  const int wH = wt.shape(2); // Weight spatial dim
+  const int wW = wt.shape(3); // Weight spatial dim
+
+  const size_t in_stride_N = in.strides()[0];
+  const size_t in_stride_D = in.strides()[1];
+  const size_t in_stride_H = in.strides()[2];
+  const size_t in_stride_W = in.strides()[3];
+  const size_t in_stride_C = in.strides()[4];
+
+  const size_t wt_stride_O = wt.strides()[0];
+  const size_t wt_stride_D = wt.strides()[1];
+  const size_t wt_stride_H = wt.strides()[2];
+  const size_t wt_stride_W = wt.strides()[3];
+  const size_t wt_stride_C = wt.strides()[4];
+
+  const size_t out_stride_N = out.strides()[0];
+  const size_t out_stride_D = out.strides()[1];
+  const size_t out_stride_H = out.strides()[2];
+  const size_t out_stride_W = out.strides()[3];
+  const size_t out_stride_O = out.strides()[4];
+
+  bool is_idil_one =
+      in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1;
+
+  auto pt_conv_no_checks = [&](const T* in_ptr,
+                               const T* wt_ptr,
+                               T* out_ptr,
+                               int od,
+                               int oh,
+                               int ow) {
+    out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
+    int id_base = od * wt_strides[0] - padding[0];
+    int ih_base = oh * wt_strides[1] - padding[1];
+    int iw_base = ow * wt_strides[2] - padding[2];
+
+    for (int o = 0; o < O; ++o) {
+      float r = 0.;
+
+      for (int wd = 0; wd < wD; ++wd) {
+        for (int wh = 0; wh < wH; ++wh) {
+          for (int ww = 0; ww < wW; ++ww) {
+            int wd_flip = flip ? wD - wd - 1 : wd;
+            int wh_flip = flip ? wH - wh - 1 : wh;
+            int ww_flip = flip ? wW - ww - 1 : ww;
+            int id = id_base + wd_flip * wt_dilation[0];
+            int ih = ih_base + wh_flip * wt_dilation[1];
+            int iw = iw_base + ww_flip * wt_dilation[2];
+
+            const T* wt_ptr_pt =
+                wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W;
+            const T* in_ptr_pt =
+                in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W;
+
+            for (int c = 0; c < C; ++c) {
+              r += static_cast<float>(in_ptr_pt[0]) *
+                  static_cast<float>(wt_ptr_pt[0]);
+              in_ptr_pt += in_stride_C;
+              wt_ptr_pt += wt_stride_C;
+            } // c
+
+          } // ww
+        } // wh
+      } // wd
+
+      out_ptr[0] = static_cast<T>(r);
+      out_ptr += out_stride_O;
+      wt_ptr += wt_stride_O;
+    } // o
+  };
+
+  int jump_d = flip ? -wt_dilation[0] : wt_dilation[0];
+  int jump_h = flip ? -wt_dilation[1] : wt_dilation[1];
+  int jump_w = flip ? -wt_dilation[2] : wt_dilation[2];
+
+  int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0);
+  int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0);
+  int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0);
+
+  int f_wgt_jump_d = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
+  int f_wgt_jump_h = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
+  int f_wgt_jump_w = std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2];
+
+  int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
+  int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
+  int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2];
+
+  std::vector<int> base_d(f_out_jump_d);
+  std::vector<int> base_h(f_out_jump_h);
+  std::vector<int> base_w(f_out_jump_w);
+
+  for (int i = 0; i < f_out_jump_d; ++i) {
+    int id_loop = i * wt_strides[0] - padding[0] + init_d;
+
+    int wd_base = 0;
+    while (wd_base < wD && id_loop % in_dilation[0] != 0) {
+      wd_base++;
+      id_loop += jump_d;
+    }
+
+    base_d[i] = wd_base;
+  }
+
+  for (int i = 0; i < f_out_jump_h; ++i) {
+    int ih_loop = i * wt_strides[1] - padding[1] + init_h;
+
+    int wh_base = 0;
+    while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
+      wh_base++;
+      ih_loop += jump_h;
+    }
+
+    base_h[i] = wh_base;
+  }
+
+  for (int j = 0; j < f_out_jump_w; ++j) {
+    int iw_loop = j * wt_strides[2] - padding[2] + init_w;
+
+    int ww_base = 0;
+    while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
+      ww_base++;
+      iw_loop += jump_w;
+    }
+
+    base_w[j] = ww_base;
+  }
+
+  auto pt_conv_all_checks = [&](const T* in_ptr,
+                                const T* wt_ptr,
+                                T* out_ptr,
+                                int od,
+                                int oh,
+                                int ow) {
+    out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
+
+    int id_base = od * wt_strides[0] - padding[0];
+    int ih_base = oh * wt_strides[1] - padding[1];
+    int iw_base = ow * wt_strides[2] - padding[2];
+
+    int wd_base = base_d[od % f_out_jump_d];
+    int wh_base = base_h[oh % f_out_jump_h];
+    int ww_base = base_w[ow % f_out_jump_w];
+
+    for (int o = 0; o < O; ++o) {
+      float r = 0.;
+
+      for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {
+        for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
+          for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
+            int wd_flip = flip ? wD - wd - 1 : wd;
+            int wh_flip = flip ? wH - wh - 1 : wh;
+            int ww_flip = flip ? wW - ww - 1 : ww;
+            int id = id_base + wd_flip * wt_dilation[0];
||||
int ih = ih_base + wh_flip * wt_dilation[1];
|
||||
int iw = iw_base + ww_flip * wt_dilation[2];
|
||||
|
||||
if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&
|
||||
iw < iW) {
|
||||
const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
|
||||
wh * wt_stride_H + ww * wt_stride_W;
|
||||
|
||||
int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;
|
||||
int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;
|
||||
int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw;
|
||||
|
||||
const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +
|
||||
ih_dil * in_stride_H + iw_dil * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
|
||||
} // iD, ih, iw check
|
||||
} // ww
|
||||
} // wh
|
||||
} // wd
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
};
|
||||
|
||||
int oD_border_0 = 0;
|
||||
int oD_border_1 =
|
||||
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
|
||||
int oD_border_2 = std::max(
|
||||
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
|
||||
int oD_border_3 = oD;
|
||||
|
||||
int oH_border_0 = 0;
|
||||
int oH_border_1 =
|
||||
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
|
||||
int oH_border_2 = std::max(
|
||||
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
|
||||
int oH_border_3 = oH;
|
||||
|
||||
int oW_border_0 = 0;
|
||||
int oW_border_1 =
|
||||
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
|
||||
int oW_border_2 = std::max(
|
||||
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
|
||||
int oW_border_3 = oW;
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
// Case 1: od might put us out of bounds
|
||||
for (int od = oD_border_0; od < oD_border_1; ++od) {
|
||||
for (int oh = 0; oh < oH; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
} // od
|
||||
|
||||
// Case 2: od in bounds
|
||||
for (int od = oD_border_1; od < oD_border_2; ++od) {
|
||||
// Case 2.1: oh might put us out of bounds
|
||||
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
|
||||
// Case 2.2: oh in bounds
|
||||
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
|
||||
// Case 2.2.1: ow might put us out of bounds
|
||||
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
|
||||
// Case 2.2.2: ow in bounds
|
||||
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
|
||||
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
|
||||
// Case 2.2.3: ow might put us out of bounds
|
||||
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
|
||||
// Case 2.3: oh might put us out of bounds
|
||||
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
} // od
|
||||
|
||||
// Case 3: od might put us out of bounds
|
||||
for (int od = oD_border_2; od < oD_border_3; ++od) {
|
||||
for (int oh = 0; oh < oH; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
} // od
|
||||
|
||||
st_in_ptr += in_stride_N;
|
||||
st_out_ptr += out_stride_N;
|
||||
|
||||
} // n
|
||||
}
|
||||
|
||||
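// A minimal sketch (illustrative, not from the diff): slow_conv_3D above
// relies on the standard convolution output arithmetic relating the border
// computations to the output shape. The helper name `conv_out_dim` is
// hypothetical and does not exist in MLX:
inline int conv_out_dim(
    int in_dim, int pad, int w_dim, int w_dil, int in_dil, int stride) {
  int i = 1 + in_dil * (in_dim - 1); // input extent after input dilation
  int w = 1 + w_dil * (w_dim - 1); // kernel extent after weight dilation
  return (i + 2 * pad - w) / stride + 1;
}
// Example: in_dim=8, pad=1, w_dim=3, dilations=1, stride=2 -> (8+2-3)/2+1 = 4.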
void dispatch_slow_conv_1D(
    const array& in,
    const array& wt,
@@ -358,6 +656,30 @@ void dispatch_slow_conv_2D(
  }
}

void dispatch_slow_conv_3D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  if (in.dtype() == float32) {
    return slow_conv_3D<float>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else if (in.dtype() == float16) {
    return slow_conv_3D<float16_t>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else if (in.dtype() == bfloat16) {
    return slow_conv_3D<bfloat16_t>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else {
    throw std::invalid_argument(
        "[Convolution::eval] got unsupported data type.");
  }
}
///////////////////////////////////////////////////////////////////////////////
// Explicit gemm conv
///////////////////////////////////////////////////////////////////////////////
@@ -582,6 +904,131 @@ void explicit_gemm_conv_2D_cpu(
  }
}

void explicit_gemm_conv_ND_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation) {
  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  const auto iDim = std::vector<int>(
      in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
  const auto oDim = std::vector<int>(
      out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
  const int O = wt.shape(0); // Out channels
  const int C = wt.shape(-1); // In channels
  const auto wDim = std::vector<int>(
      wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim

  auto conv_dtype = float32;

  // Pad input
  std::vector<int> padded_shape(in.shape().size());
  padded_shape.front() = N;
  for (size_t i = 0; i < iDim.size(); i++) {
    padded_shape[i + 1] = iDim[i] + 2 * padding[i];
  }
  padded_shape.back() = C;
  array in_padded(padded_shape, conv_dtype, nullptr, {});

  // Fill with zeros
  copy(array(0, conv_dtype), in_padded, CopyType::Scalar);

  // Pick input slice from padded
  size_t data_offset = 0;
  for (size_t i = 0; i < padding.size(); i++) {
    data_offset += padding[i] * in_padded.strides()[i + 1];
  }
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
      in_padded.strides(),
      in_padded.flags(),
      in_padded_slice.size(),
      data_offset);

  // Copy input values into the slice
  copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);

  // Make strided view
  std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
  strided_shape.front() = N;
  for (size_t i = 0; i < oDim.size(); i++) {
    strided_shape[i + 1] = oDim[i];
  }
  for (size_t i = 0; i < wDim.size(); i++) {
    strided_shape[i + 1 + oDim.size()] = wDim[i];
  }
  strided_shape.back() = C;

  std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
  strided_strides[0] = in_padded.strides()[0];
  for (size_t i = 0; i < wt_strides.size(); i++) {
    strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
  }
  for (size_t i = 1; i < in_padded.strides().size(); i++) {
    strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
  }

  auto flags = in_padded.flags();

  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
  in_strided_view.copy_shared_buffer(
      in_padded, strided_strides, flags, in_strided_view.size(), 0);

  // Materialize strided view
  std::vector<int> strided_reshape = {N, C};
  for (const auto& o : oDim) {
    strided_reshape[0] *= o;
  }
  for (const auto& w : wDim) {
    strided_reshape[1] *= w;
  }

  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
  copy(in_strided_view, in_strided, CopyType::General);

  // Check wt dtype and prepare
  auto gemm_wt = wt;
  auto gemm_out = out;

  if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
    auto ctype =
        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
    gemm_wt = array(wt.shape(), float32, nullptr, {});
    copy(wt, gemm_wt, ctype);
  }

  if (out.dtype() != float32) {
    gemm_out = array(out.shape(), float32, nullptr, {});
    gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
  }

  // Perform gemm
  cblas_sgemm(
      CblasRowMajor,
      CblasNoTrans, // no trans A
      CblasTrans, // transB
      strided_reshape[0], // M
      O, // N
      strided_reshape[1], // K
      1.0f, // alpha
      in_strided.data<float>(),
      strided_reshape[1], // lda
      gemm_wt.data<float>(),
      strided_reshape[1], // ldb
      0.0f, // beta
      gemm_out.data<float>(),
      O // ldc
  );

  // Copy results if needed
  if (out.dtype() != float32) {
    copy(gemm_out, out, CopyType::Vector);
  }
}
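// A minimal 1D sketch (illustrative, not from the diff) of the strided-view
// trick used by explicit_gemm_conv_ND_cpu above: reading the padded input
// through per-axis strides of (stride * C, C, 1) for axes (o, w, c) exposes
// the "im2col" unfolding without copying; the materializing copy into
// `in_strided` is what writes it out densely, as this loop does explicitly:
void im2col_1d_sketch(
    const float* in_padded, // [iP, C], row contiguous
    float* unfolded, // [oW, wW * C], dense output
    int oW,
    int wW,
    int C,
    int stride) {
  for (int o = 0; o < oW; ++o) {
    for (int w = 0; w < wW; ++w) {
      for (int c = 0; c < C; ++c) {
        unfolded[(o * wW + w) * C + c] = in_padded[(o * stride + w) * C + c];
      }
    }
  }
}
// A single sgemm of the unfolded matrix against the flattened weights then
// produces every output position at once, which is the whole point of the
// explicit-gemm path.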
///////////////////////////////////////////////////////////////////////////////
// Conv routing
///////////////////////////////////////////////////////////////////////////////
@@ -617,6 +1064,19 @@ void conv_2D_cpu(
      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}

void conv_3D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  return dispatch_slow_conv_3D(
      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}

} // namespace
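// Convolution::eval below routes on in.ndim(). A minimal sketch
// (illustrative, not from the diff) of the layout arithmetic it assumes:
// inputs are channels-last (NDHWC / NHWC / NWC), so an N-D convolution
// input carries N spatial axes plus a leading batch axis and a trailing
// channel axis.
inline int spatial_dims_sketch(int ndim) {
  return ndim - 2; // e.g. ndim == 5 -> 3D convolution
}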
void Convolution::eval(const std::vector<array>& inputs, array& out) {
@@ -625,8 +1085,20 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
  auto& in = inputs[0];
  auto& wt = inputs[1];

  // 3D convolution
  if (in.ndim() == (3 + 2)) {
    return conv_3D_cpu(
        in,
        wt,
        out,
        padding_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
        flip_);
  }
  // 2D convolution
  if (in.ndim() == (2 + 2)) {
  else if (in.ndim() == (2 + 2)) {
    return conv_2D_cpu(
        in,
        wt,
@@ -256,7 +256,7 @@ void copy_general_general(
  }

  int size = std::accumulate(
      data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
      data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
  for (int i = 0; i < src.size(); i += size) {
    stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
    stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
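// The fix above replaces `data_shape.begin() - 5` (an iterator walked off
// the front of the vector, which is undefined behavior) with
// `data_shape.end() - 5`, so `size` is the element count of the last five
// dimensions. A minimal sketch (illustrative, not from the diff), assuming
// data_shape.size() >= 5:
#include <functional>
#include <numeric>
#include <vector>

inline int tail_chunk_size(const std::vector<int>& data_shape) {
  // Product of the last five extents, i.e. the chunk copied per iteration.
  return std::accumulate(
      data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
}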
@@ -1,47 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/common/cpu_impl.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"

namespace mlx::core::cpu {

std::function<void()> make_task(array arr, bool signal) {
  return [arr = std::move(arr), signal]() mutable {
    auto stream = arr.primitive().stream();

    // Wait on inputs coming from different streams/devices.
    for (auto& input : arr.inputs()) {
      if (input.event().valid() && input.event().stream() != stream) {
        input.event().wait();
      }
    }

    // Task computation actually starting.
    scheduler::notify_new_task(stream);

    // Perform the computation
    auto outputs = arr.outputs();
    arr.primitive().eval_cpu(arr.inputs(), outputs);

    // Check if we need to detach and signal other arrays waiting for the
    // result to be ready.
    if (!arr.is_tracer()) {
      arr.detach();
    }
    if (signal) {
      arr.event().signal();
    }

    // Task computation done.
    scheduler::notify_task_completion(stream);
  };
}

std::function<void()> make_synchronize_task(
    Stream s,
    std::shared_ptr<std::promise<void>> p) {
  return [p = std::move(p)]() { p->set_value(); };
}

} // namespace mlx::core::cpu
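// A minimal sketch (illustrative, not from the diff) of how the closures
// built by make_task / make_synchronize_task are meant to be consumed: a
// scheduler pushes them onto a per-stream queue and a worker thread runs
// them in order. The queue below is a stand-in, not the MLX scheduler.
#include <functional>
#include <queue>
#include <utility>

void drain_sketch(std::queue<std::function<void()>>& q) {
  // make_task(arr, signal) and make_synchronize_task(s, p) both return
  // std::function<void()>, so a worker only ever sees this one shape.
  while (!q.empty()) {
    auto task = std::move(q.front());
    q.pop();
    task(); // waits on cross-stream inputs, evals, signals, notifies
  }
}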
@@ -1,18 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <functional>
#include <future>
#include <memory>

#include "mlx/array.h"

namespace mlx::core::cpu {

std::function<void()> make_task(array arr, bool signal);
std::function<void()> make_synchronize_task(
    Stream s,
    std::shared_ptr<std::promise<void>> p);

} // namespace mlx::core::cpu

@@ -34,6 +34,7 @@ DEFAULT(ArcCosh)
DEFAULT(ArcSin)
DEFAULT(ArcSinh)
DEFAULT(ArcTan)
DEFAULT(ArcTan2)
DEFAULT(ArcTanh)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
@@ -42,10 +43,12 @@ DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
@@ -68,6 +71,7 @@ DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(Log)
DEFAULT(Log1p)
DEFAULT(LogicalNot)
@@ -109,6 +113,7 @@ DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)

namespace {
@@ -2,7 +2,6 @@

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
@@ -93,12 +92,4 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
  inverse_impl(inputs[0], output);
}

std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto ax = axes[0] >= 0 ? 0 : -1;
  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
  return {{linalg::inv(a, stream())}, {ax}};
}

} // namespace mlx::core
@@ -17,24 +17,25 @@ namespace mlx::core {

namespace {

template <typename T>
template <typename T, typename mask_t>
inline void mask_matrix(
    T* data,
    const bool* mask,
    const mask_t* mask,
    int block_size,
    const int X,
    const int Y,
    const size_t X_data_str,
    const size_t Y_data_str,
    const size_t X_mask_str,
    const size_t Y_mask_str) {
    const size_t Y_mask_str,
    const size_t mask_offset) {
  int tX = (X + block_size - 1) / block_size;
  int tY = (Y + block_size - 1) / block_size;

  for (int i = 0; i < tX; i++) {
    for (int j = 0; j < tY; j++) {
      bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
      if (!do_mask) {
      mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str];
      if (do_mask != 1) {
        int loc_x = i * block_size;
        int loc_y = j * block_size;
        T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
@@ -43,7 +44,11 @@ inline void mask_matrix(
        int size_y = std::min(block_size, Y - loc_y);
        for (int ii = 0; ii < size_x; ii++) {
          for (int jj = 0; jj < size_y; jj++) {
            data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
            if constexpr (std::is_same_v<mask_t, bool>) {
              data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
            } else {
              data_block[ii * X_data_str + jj * Y_data_str] *= do_mask;
            }
          }
        }
      }
@@ -62,36 +67,39 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {

  auto& a_pre = inputs[0];
  auto& b_pre = inputs[1];
  auto& out_mask = inputs[2];

  auto check_transpose = [](const array& arr, bool do_copy) {
    auto stx = arr.strides()[arr.ndim() - 2];
    auto sty = arr.strides()[arr.ndim() - 1];
    if (stx == arr.shape(-1) && sty == 1) {
      if (do_copy) {
        array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
        copy(arr, arr_copy, CopyType::Vector);
        return std::make_tuple(false, stx, arr_copy);
      }
      return std::make_tuple(false, stx, arr);
    } else if (stx == 1 && sty == arr.shape(-2)) {
      if (do_copy) {
        array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
        copy(arr, arr_copy, CopyType::Vector);
        return std::make_tuple(true, sty, arr_copy);
      }
      return std::make_tuple(true, sty, arr);
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy(arr, arr_copy, CopyType::General);
      size_t stx = arr.shape(-1);
      return std::make_tuple(false, stx, arr_copy);
    }
  };
  auto check_transpose =
      [](const array& arr, bool do_copy, bool expand_all = false) {
        auto stx = arr.strides()[arr.ndim() - 2];
        auto sty = arr.strides()[arr.ndim() - 1];
        if (!expand_all && stx == arr.shape(-1) && sty == 1) {
          if (do_copy) {
            array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
            copy(arr, arr_copy, CopyType::Vector);
            return std::make_tuple(false, stx, arr_copy);
          }
          return std::make_tuple(false, stx, arr);
        } else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
          if (do_copy) {
            array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
            copy(arr, arr_copy, CopyType::Vector);
            return std::make_tuple(true, sty, arr_copy);
          }
          return std::make_tuple(true, sty, arr);
        } else {
          array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
          copy(arr, arr_copy, CopyType::General);
          size_t stx = arr.shape(-1);
          return std::make_tuple(false, stx, arr_copy);
        }
      };

  bool has_op_mask = inputs.size() > 3;
  auto [a_transposed, lda, a] = check_transpose(a_pre, has_op_mask);
  auto [b_transposed, ldb, b] = check_transpose(b_pre, has_op_mask);
  bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
  auto [a_transposed, lda, a] =
      check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
  auto [b_transposed, ldb, b] =
      check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);

  size_t M = a.shape(-2);
  size_t N = b.shape(-1);
@@ -114,27 +122,42 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
                        int Y,
                        size_t X_data_str,
                        size_t Y_data_str) {
    const bool* mask_ptr = mask.data<bool>() +
        elem_to_loc(mask.shape(-1) * mask.shape(-2) * batch_idx,
                    mask.shape(),
                    mask.strides());
    size_t mask_offset = elem_to_loc(
        mask.shape(-1) * mask.shape(-2) * batch_idx,
        mask.shape(),
        mask.strides());

    size_t X_mask_str = mask.strides()[mask.ndim() - 2];
    size_t Y_mask_str = mask.strides()[mask.ndim() - 1];

    return mask_matrix(
        data,
        mask_ptr,
        block_size,
        X,
        Y,
        X_data_str,
        Y_data_str,
        X_mask_str,
        Y_mask_str);
    if (mask.dtype() == bool_) {
      return mask_matrix(
          data,
          mask.data<bool>(),
          block_size,
          X,
          Y,
          X_data_str,
          Y_data_str,
          X_mask_str,
          Y_mask_str,
          mask_offset);
    } else {
      return mask_matrix(
          data,
          mask.data<float>(),
          block_size,
          X,
          Y,
          X_data_str,
          Y_data_str,
          X_mask_str,
          Y_mask_str,
          mask_offset);
    }
  };

  for (int i = 0; i < (a.size() / (M * K)); ++i) {
  for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
    // Adjust pointer
    float* ai =
        a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
@@ -144,7 +167,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {

    // Zero out blocks in a and b if needed
    if (has_op_mask) {
      auto& a_mask = inputs[3];
      auto& a_mask = inputs[inputs.size() - 2];
      mask_array(
          a_mask,
          ai,
@@ -155,7 +178,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
          a_transposed ? 1 : lda,
          a_transposed ? lda : 1);

      auto& b_mask = inputs[4];
      auto& b_mask = inputs[inputs.size() - 1];
      mask_array(
          b_mask,
          bi,
@@ -186,14 +209,16 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
    );

    // Zero out blocks in out
    mask_array(out_mask, ci, block_size_, i, M, N, N, 1);
    if (has_out_mask) {
      mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
    }
  }
}

void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
void GatherMM::eval(const std::vector<array>& inputs, array& out) {
  if (out.dtype() != float32) {
    throw std::runtime_error(
        "[BlockSparseMM::eval] Currently only supports float32.");
        "[GatherMM::eval] Currently only supports float32.");
  }
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

@@ -277,4 +302,4 @@ void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
  }
}

} // namespace mlx::core
} // namespace mlx::core
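// A minimal sketch (illustrative, not from the diff) of what the additive
// path added to mask_matrix above computes: with block_size = 2, a float
// mask entry m scales the corresponding 2x2 tile of `data`, while a bool
// mask zeroes the tile when the entry is false.
void scale_tile_sketch(float* data, int ld, float m, int loc_x, int loc_y) {
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      data[(loc_x + i) * ld + (loc_y + j)] *= m; // m == 0 zeroes the tile
    }
  }
}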
@@ -161,6 +161,13 @@ struct ArcTan {
  };
};

struct ArcTan2 {
  template <typename T>
  T operator()(T y, T x) {
    return std::atan2(y, x);
  };
};

struct ArcTanh {
  template <typename T>
  T operator()(T x) {
@@ -202,6 +209,12 @@ struct Ceil {
  };
};

struct Conjugate {
  complex64_t operator()(complex64_t x) {
    return std::conj(x);
  }
};

struct Cos {
  template <typename T>
  T operator()(T x) {
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <algorithm>
#include <cassert>
@@ -113,61 +113,6 @@ void AsType::eval(const std::vector<array>& inputs, array& out) {
  copy(in, out, ctype);
}

void AsStrided::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  auto& in = inputs[0];

  if (!in.flags().row_contiguous) {
    // Just ensuring that inputs[0] came from the ops which would ensure the
    // input is row contiguous.
    throw std::runtime_error(
        "AsStrided must be used with row contiguous arrays only.");
  }

  // Compute the flags given the shape and strides
  bool row_contiguous = true, col_contiguous = true;
  size_t r = 1, c = 1;
  for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
    row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
    col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
    r *= shape_[i];
    c *= shape_[j];
  }
  auto flags = in.flags();
  // TODO: Compute the contiguous flag in a better way cause now we are
  // unnecessarily strict.
  flags.contiguous = row_contiguous || col_contiguous;
  flags.row_contiguous = row_contiguous;
  flags.col_contiguous = col_contiguous;

  // There is no easy way to compute the actual data size so we use out.size().
  // The contiguous flag will almost certainly not be set so no code should
  // rely on data_size anyway.
  size_t data_size = out.size();

  return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}

void Broadcast::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }
  std::vector<size_t> strides(out.ndim(), 0);
  int diff = out.ndim() - in.ndim();
  for (int i = in.ndim() - 1; i >= 0; --i) {
    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
  }
  auto flags = in.flags();
  if (out.size() > in.size()) {
    flags.row_contiguous = flags.col_contiguous = false;
  }
  out.copy_shared_buffer(in, strides, flags, in.data_size());
}

void Ceil::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
@@ -203,9 +148,15 @@ void Concatenate::eval(const std::vector<array>& inputs, array& out) {
  }
}

void Copy::eval(const std::vector<array>& inputs, array& out) {
void Conjugate::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  out.copy_shared_buffer(inputs[0]);
  const auto& in = inputs[0];
  if (out.dtype() == complex64) {
    unary_fp(in, out, detail::Conjugate());
  } else {
    throw std::invalid_argument(
        "[conjugate] conjugate must be called on complex input.");
  }
}

void Cos::eval(const std::vector<array>& inputs, array& out) {
@@ -232,81 +183,6 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
  }
}

void CustomVJP::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() > outputs.size());
  for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
       i++, j++) {
    outputs[i].copy_shared_buffer(inputs[j]);
  }
}

void Depends::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() > outputs.size());
  for (int i = 0; i < outputs.size(); i++) {
    outputs[i].copy_shared_buffer(inputs[i]);
  }
}

void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  double numel = 1;
  for (auto ax : axes_) {
    numel *= inputs[0].shape(ax);
  }

  if (inverted_) {
    numel = 1.0 / numel;
  }

  switch (out.dtype()) {
    case bool_:
      *out.data<bool>() = static_cast<bool>(numel);
      break;
    case uint8:
      *out.data<uint8_t>() = static_cast<uint8_t>(numel);
      break;
    case uint16:
      *out.data<uint16_t>() = static_cast<uint16_t>(numel);
      break;
    case uint32:
      *out.data<uint32_t>() = static_cast<uint32_t>(numel);
      break;
    case uint64:
      *out.data<uint64_t>() = static_cast<uint64_t>(numel);
      break;
    case int8:
      *out.data<int8_t>() = static_cast<int8_t>(numel);
      break;
    case int16:
      *out.data<int16_t>() = static_cast<int16_t>(numel);
      break;
    case int32:
      *out.data<int32_t>() = static_cast<int32_t>(numel);
      break;
    case int64:
      *out.data<int64_t>() = static_cast<int64_t>(numel);
      break;
    case float16:
      *out.data<float16_t>() = static_cast<float16_t>(numel);
      break;
    case float32:
      *out.data<float>() = static_cast<float>(numel);
      break;
    case bfloat16:
      *out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
      break;
    case complex64:
      *out.data<complex64_t>() = static_cast<complex64_t>(numel);
      break;
  }
}
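// A minimal sketch (illustrative, not from the diff) of the scalar that
// NumberOfElements::eval above writes: the product of the sizes of the
// reduced axes, optionally inverted (useful for turning a sum into a mean).
#include <vector>

inline double number_of_elements_sketch(
    const std::vector<int>& shape,
    const std::vector<int>& axes,
    bool inverted) {
  double numel = 1;
  for (auto ax : axes) {
    numel *= shape[ax];
  }
  return inverted ? 1.0 / numel : numel;
}
// Example: shape {4, 5, 6}, axes {0, 2}, inverted -> 1.0 / 24.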
void Erf::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
@@ -536,63 +412,6 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
  }
}

std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
    const array& in,
    const array& out) {
  // Special case for empty arrays or row contiguous arrays
  if (in.size() == 0 || in.flags().row_contiguous) {
    return {false, out.strides()};
  }

  // Special case for scalars
  if (in.ndim() == 0) {
    std::vector<size_t> out_strides(out.ndim(), 0);
    return {false, out_strides};
  }

  // Firstly let's collapse all the contiguous dimensions of the input
  auto [shape, _strides] = collapse_contiguous_dims(in);
  auto& strides = _strides[0];

  // If shapes fit exactly in the contiguous dims then no copy is necessary so
  // let's check.
  std::vector<size_t> out_strides;
  bool copy_necessary = false;
  int j = 0;
  for (int i = 0; i < out.ndim(); i++) {
    int N = out.shape(i);
    if (j < shape.size() && shape[j] % N == 0) {
      shape[j] /= N;
      out_strides.push_back(shape[j] * strides[j]);
      j += (shape[j] == 1);
    } else if (N == 1) {
      // i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
      out_strides.push_back(out_strides.back());
    } else {
      copy_necessary = true;
      break;
    }
  }

  return {copy_necessary, out_strides};
}

void Reshape::shared_buffer_reshape(
    const array& in,
    const std::vector<size_t>& out_strides,
    array& out) {
  auto flags = in.flags();
  if (flags.row_contiguous) {
    // For row contiguous reshapes:
    // - Shallow copy the buffer
    // - If reshaping into a vector (all singleton dimensions except one) it
    //   becomes col contiguous again.
    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
  }
  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}

void Reshape::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
@@ -663,49 +482,6 @@ void Sinh::eval(const std::vector<array>& inputs, array& out) {
  }
}

std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
    const array& in) {
  int64_t data_offset = 0;
  bool copy_needed = false;
  std::vector<int64_t> inp_strides(in.ndim(), 0);
  for (int i = 0; i < in.ndim(); ++i) {
    data_offset += start_indices_[i] * in.strides()[i];
    inp_strides[i] = in.strides()[i] * strides_[i];

    copy_needed |= strides_[i] < 0;
  }

  return std::make_tuple(copy_needed, data_offset, inp_strides);
}

void Slice::shared_buffer_slice(
    const array& in,
    const std::vector<size_t>& out_strides,
    size_t data_offset,
    array& out) {
  // Compute row/col contiguity
  auto [data_size, is_row_contiguous, is_col_contiguous] =
      check_contiguity(out.shape(), out_strides);

  auto flags = in.flags();
  flags.row_contiguous = is_row_contiguous;
  flags.col_contiguous = is_col_contiguous;

  if (data_size == 1) {
    // Broadcasted scalar array is contiguous.
    flags.contiguous = true;
  } else if (data_size == in.data_size()) {
    // Means we sliced a broadcasted dimension so leave the "no holes" flag
    // alone.
  } else {
    // We sliced something. So either we are row or col contiguous or we
    // punched a hole.
    flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
  }

  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}

void Slice::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  if (out.size() == 0) {
@@ -737,18 +513,6 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
  }
}

std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
    const array& in) {
  int64_t data_offset = 0;
  std::vector<int64_t> inp_strides(in.ndim(), 0);
  for (int i = 0; i < in.ndim(); ++i) {
    data_offset += start_indices_[i] * in.strides()[i];
    inp_strides[i] = in.strides()[i] * strides_[i];
  }

  return std::make_tuple(data_offset, inp_strides);
}

void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  if (out.size() == 0) {
@@ -786,58 +550,6 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
      /* CopyType ctype = */ CopyType::GeneralGeneral);
}

void Split::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);

  auto& in = inputs[0];

  auto compute_new_flags = [](const auto& shape,
                              const auto& strides,
                              size_t in_data_size,
                              auto flags) {
    size_t data_size = 1;
    size_t f_stride = 1;
    size_t b_stride = 1;
    flags.row_contiguous = true;
    flags.col_contiguous = true;
    for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
      flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
      flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
      f_stride *= shape[i];
      b_stride *= shape[ri];
      if (strides[i] > 0) {
        data_size *= shape[i];
      }
    }

    if (data_size == 1) {
      // Broadcasted scalar array is contiguous.
      flags.contiguous = true;
    } else if (data_size == in_data_size) {
      // Means we sliced a broadcasted dimension so leave the "no holes" flag
      // alone.
    } else {
      // We sliced something. So either we are row or col contiguous or we
      // punched a hole.
      flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
    }

    return std::pair<decltype(flags), size_t>{flags, data_size};
  };

  std::vector<int> indices(1, 0);
  indices.insert(indices.end(), indices_.begin(), indices_.end());
  for (int i = 0; i < indices.size(); i++) {
    size_t offset = indices[i] * in.strides()[axis_];
    auto [new_flags, data_size] = compute_new_flags(
        outputs[i].shape(), in.strides(), in.data_size(), in.flags());
    outputs[i].copy_shared_buffer(
        in, in.strides(), new_flags, data_size, offset);
  }
}

void Square::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
@@ -854,11 +566,6 @@ void Sqrt::eval(const std::vector<array>& inputs, array& out) {
  }
}

void StopGradient::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  out.copy_shared_buffer(inputs[0]);
}

void Tan::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
@@ -883,38 +590,4 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
  }
}

void Transpose::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  std::vector<size_t> out_strides(out.ndim());
  auto& in = inputs[0];
  for (int ax = 0; ax < axes_.size(); ++ax) {
    out_strides[ax] = in.strides()[axes_[ax]];
  }

  // Conditions for {row/col}_contiguous
  // - array must be contiguous (no gaps)
  // - underlying buffer size should have the same size as the array
  // - cumulative product of shapes is equal to the strides (we can ignore axes
  //   with size == 1)
  //   - in the forward direction (column contiguous)
  //   - in the reverse direction (row contiguous)
  //   - vectors are both row and col contiguous (hence if both row/col are
  //     true, they stay true)
  auto flags = in.flags();
  if (flags.contiguous && in.data_size() == in.size()) {
    size_t f_stride = 1;
    size_t b_stride = 1;
    flags.col_contiguous = true;
    flags.row_contiguous = true;
    for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
      flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
      f_stride *= out.shape(i);
      flags.row_contiguous &=
          (out_strides[ri] == b_stride || out.shape(ri) == 1);
      b_stride *= out.shape(ri);
    }
  }
  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}

} // namespace mlx::core
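// A minimal sketch (illustrative, not from the diff) of the stride
// permutation Transpose::eval above performs: transposing a row-contiguous
// (2, 3) array with axes (1, 0) yields strides (1, 3), exactly the
// col-contiguous pattern the flag loop detects.
#include <cstddef>
#include <vector>

inline std::vector<size_t> permute_strides_sketch(
    const std::vector<size_t>& in_strides,
    const std::vector<int>& axes) {
  std::vector<size_t> out(axes.size());
  for (size_t ax = 0; ax < axes.size(); ++ax) {
    out[ax] = in_strides[axes[ax]]; // no data movement, only new strides
  }
  return out;
}
// Example: in_strides {3, 1}, axes {1, 0} -> {1, 3}.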
@@ -192,7 +192,7 @@ void _qmm_dispatch_typed(
}

void _qmm_dispatch(
    array out,
    array& out,
    const array& x,
    const array& w,
    const array& scales,
@@ -253,6 +253,81 @@ void _qmm_dispatch(
  }
}

void _bs_qmm_dispatch(
    array& out,
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    const array& lhs_indices,
    const array& rhs_indices,
    int bits,
    int group_size,
    bool transposed_w) {
  int K = x.shape(-1);
  int M = x.shape(-2);
  int N = out.shape(-1);

  int w_els = w.shape(-1) * w.shape(-2);
  int g_els = scales.shape(-1) * scales.shape(-2);

  const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
  const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();

  for (int i = 0; i < lhs_indices.size(); i++) {
    int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
    int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];

    switch (x.dtype()) {
      case float32:
        _qmm_dispatch_typed<float>(
            out.data<float>() + i * M * N,
            x.data<float>() + elem_to_loc(x_idx * M * K, x),
            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
            scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
            biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
            M,
            N,
            K,
            bits,
            group_size,
            transposed_w);
        break;
      case float16:
        _qmm_dispatch_typed<float16_t>(
            out.data<float16_t>() + i * M * N,
            x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
            scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
            biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
            M,
            N,
            K,
            bits,
            group_size,
            transposed_w);
        break;
      case bfloat16:
        _qmm_dispatch_typed<bfloat16_t>(
            out.data<bfloat16_t>() + i * M * N,
            x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
            scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
            biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
            M,
            N,
            K,
            bits,
            group_size,
            transposed_w);
        break;
      default:
        throw std::invalid_argument(
            "[quantized_matmul] only floating types are supported");
    }
  }
}

} // namespace

void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
@@ -282,4 +357,45 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
  _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}

void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 6);

  auto& x_pre = inputs[0];
  auto& w_pre = inputs[1];
  auto& scales_pre = inputs[2];
  auto& biases_pre = inputs[3];
  auto& lhs_indices = inputs[4];
  auto& rhs_indices = inputs[5];

  auto ensure_row_contiguous_last_dims = [](const array& arr) {
    auto stride_0 = arr.strides()[arr.ndim() - 2];
    auto stride_1 = arr.strides()[arr.ndim() - 1];
    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
      return arr;
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy(arr, arr_copy, CopyType::General);
      return arr_copy;
    }
  };

  auto x = ensure_row_contiguous_last_dims(x_pre);
  auto w = ensure_row_contiguous_last_dims(w_pre);
  auto scales = ensure_row_contiguous_last_dims(scales_pre);
  auto biases = ensure_row_contiguous_last_dims(biases_pre);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  _bs_qmm_dispatch(
      out,
      x,
      w,
      scales,
      biases,
      lhs_indices,
      rhs_indices,
      group_size_,
      bits_,
      transpose_);
}

} // namespace mlx::core
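// A minimal sketch (illustrative, not from the diff) of the gather indexing
// in _bs_qmm_dispatch above: output matrix i is the product of the x matrix
// chosen by lhs_indices[i] and the quantized weight chosen by
// rhs_indices[i], which is how per-token expert selection is expressed.
#include <cstdint>
#include <vector>

struct GatherPick {
  uint32_t x_idx; // which input matrix to read
  uint32_t w_idx; // which quantized weight (and scales/biases) to use
};

inline std::vector<GatherPick> gather_plan_sketch(
    const std::vector<uint32_t>& lhs_indices,
    const std::vector<uint32_t>& rhs_indices) {
  std::vector<GatherPick> plan(lhs_indices.size());
  for (size_t i = 0; i < lhs_indices.size(); ++i) {
    plan[i] = {lhs_indices[i], rhs_indices[i]}; // one (x, w) pair per output
  }
  return plan;
}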
@@ -3,7 +3,6 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

namespace mlx::core {
@@ -145,12 +144,4 @@ void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
  svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}

std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto ax = axes[0] >= 0 ? 0 : -1;
  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
  return {{linalg::svd(a, stream())}, {ax, ax, ax}};
}

} // namespace mlx::core
@@ -1,7 +0,0 @@
target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/io_impl.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/thread_pool.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
)
@@ -1,72 +0,0 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/io/io_impl.h"
#include "mlx/backend/io/thread_pool.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"

namespace mlx::core::io {

namespace {

detail::ThreadPool& thread_pool() {
  static std::unique_ptr<detail::ThreadPool> pool_ptr;

  if (pool_ptr == nullptr) {
    pool_ptr = std::make_unique<detail::ThreadPool>(4);
  }

  return *pool_ptr;
}

} // namespace

std::function<void()> make_task(array arr, bool signal) {
  return [arr = std::move(arr), signal]() mutable {
    auto stream = arr.primitive().stream();

    // Wait on inputs coming from different streams/devices.
    for (auto& input : arr.inputs()) {
      if (input.event().valid() && input.event().stream() != stream) {
        input.event().wait();
      }
    }

    // Task computation actually starting.
    scheduler::notify_new_task(stream);

    // Schedule the computation
    auto inputs = arr.inputs();
    auto outputs = arr.outputs();
    thread_pool().enqueue(
        [arr = std::move(arr), inputs, outputs, signal, stream]() mutable {
          // Perform the computation
          arr.primitive().eval_io(inputs, outputs);

          if (!arr.is_tracer()) {
            arr.detach();
          }

          if (signal) {
            thread_pool().barrier(
                [arr = std::move(arr)]() { arr.event().signal(); });
          }

          // Task computation done.
          scheduler::notify_task_completion(stream);
        },
        inputs,
        outputs);
  };
}

std::function<void()> make_synchronize_task(
    Stream s,
    std::shared_ptr<std::promise<void>> p) {
  return [p = std::move(p)]() {
    thread_pool().barrier().wait();
    p->set_value();
  };
}

} // namespace mlx::core::io

@@ -1,18 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <functional>
#include <future>
#include <memory>

#include "mlx/array.h"

namespace mlx::core::io {

std::function<void()> make_task(array arr, bool signal);
std::function<void()> make_synchronize_task(
    Stream s,
    std::shared_ptr<std::promise<void>> p);

} // namespace mlx::core::io

@@ -1,60 +0,0 @@
// Copyright © 2024 Apple Inc.

#include <algorithm>
#include <cassert>
#include <utility>

#include "mlx/allocator.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <const uint8_t scalar_size>
void swap_endianness(uint8_t* data_bytes, size_t N) {
  struct Elem {
    uint8_t bytes[scalar_size];
  };

  Elem* data = reinterpret_cast<Elem*>(data_bytes);

  for (size_t i = 0; i < N; i++) {
    for (size_t j = 0; j < (scalar_size / 2); j++) {
      std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
    }
  }
}

} // namespace

void Load::eval_io(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 0);
  array& out = outputs[0];
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  {
    std::lock_guard lock(*reader_);

    reader_->seek(offset_, std::ios_base::beg);
    reader_->read(out.data<char>(), out.nbytes());
  }

  if (swap_endianness_) {
    switch (out.itemsize()) {
      case 2:
        swap_endianness<2>(out.data<uint8_t>(), out.data_size());
        break;
      case 4:
        swap_endianness<4>(out.data<uint8_t>(), out.data_size());
        break;
      case 8:
        swap_endianness<8>(out.data<uint8_t>(), out.data_size());
        break;
    }
  }
}

} // namespace mlx::core
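// A minimal sketch (illustrative, not from the diff) of the byte swap the
// Load primitive applies: for a 4-byte element, bytes (b0 b1 b2 b3) become
// (b3 b2 b1 b0), converting between big- and little-endian encodings.
#include <cstdint>
#include <cstring>
#include <utility>

inline float byteswap32_sketch(float v) {
  uint8_t b[4];
  std::memcpy(b, &v, 4); // read the raw bytes
  std::swap(b[0], b[3]);
  std::swap(b[1], b[2]);
  std::memcpy(&v, b, 4); // write them back reversed
  return v;
}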
@@ -1,216 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/io/thread_pool.h"
|
||||
|
||||
namespace mlx::core::io::detail {
|
||||
|
||||
ThreadPool::ThreadPool(int workers)
|
||||
: task_queues_(workers),
|
||||
queue_mutexes_(workers),
|
||||
queue_cvs_(workers),
|
||||
set_mutexes_(workers),
|
||||
output_sets_(workers),
|
||||
stop_(false) {
|
||||
for (int i = 0; i < workers; i++) {
|
||||
workers_.emplace_back(&ThreadPool::worker, this, i);
|
||||
}
|
||||
}
|
||||
|
||||
ThreadPool::~ThreadPool() {
|
||||
stop_ = true;
|
||||
for (auto& cv : queue_cvs_) {
|
||||
cv.notify_one();
|
||||
}
|
||||
|
||||
for (auto& t : workers_) {
|
||||
if (t.joinable()) {
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::future<void> ThreadPool::enqueue(
|
||||
std::function<void()> task,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
std::vector<int> barriers;
|
||||
if (!inputs.empty()) {
|
||||
for (int i = 0; i < output_sets_.size(); i++) {
|
||||
std::lock_guard<std::mutex> lock(set_mutexes_[i]);
|
||||
|
||||
for (auto& a : inputs) {
|
||||
if (output_sets_[i].find(a.id()) != output_sets_[i].end()) {
|
||||
barriers.push_back(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Case 1: Barriers is empty so try to add it to the smallest queue
|
||||
if (barriers.empty()) {
|
||||
auto min_queue = std::min_element(
|
||||
task_queues_.begin(),
|
||||
task_queues_.end(),
|
||||
[](const auto& left, const auto& right) {
|
||||
return left.size() < right.size();
|
||||
});
|
||||
int worker_idx = std::distance(task_queues_.begin(), min_queue);
|
||||
|
||||
add_outputs_to_worker(outputs, worker_idx);
|
||||
return enqueue(
|
||||
remove_outputs_when_done(std::move(task), outputs, worker_idx),
|
||||
worker_idx);
|
||||
}
|
||||
|
||||
// Case 2: Barriers has only one queue so put that into that queue
|
||||
if (barriers.size() == 1) {
|
||||
int worker_idx = barriers[0];
|
||||
add_outputs_to_worker(outputs, worker_idx);
|
||||
return enqueue(
|
||||
remove_outputs_when_done(std::move(task), outputs, worker_idx),
|
||||
worker_idx);
|
||||
}
|
||||
|
||||
// Case 3: We need to add a barrier before our task and add it to the
|
||||
// smallest queue of the barriers.
|
||||
auto min_queue = std::min_element(
|
||||
barriers.begin(), barriers.end(), [this](int left, int right) {
|
||||
return task_queues_[left].size() < task_queues_[right].size();
|
||||
});
|
||||
int worker_idx = *min_queue;
|
||||
barriers.erase(min_queue);
|
||||
std::shared_future<void> queue_barrier =
|
||||
barrier(barriers); // We shouldn't need shared future here
|
||||
add_outputs_to_worker(outputs, worker_idx);
|
||||
return enqueue(
|
||||
remove_outputs_when_done(
|
||||
[queue_barrier = std::move(queue_barrier),
|
||||
og_task = std::move(task)]() {
|
||||
queue_barrier.wait();
|
||||
og_task();
|
||||
},
|
||||
outputs,
|
||||
worker_idx),
|
||||
worker_idx);
|
||||
}
|
||||
|
||||
std::future<void> ThreadPool::enqueue(
|
||||
std::function<void()> task,
|
||||
int worker_idx) {
|
||||
std::packaged_task<void()> pt(std::move(task));
|
||||
std::future<void> result = pt.get_future();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(queue_mutexes_[worker_idx]);
|
||||
task_queues_[worker_idx].emplace(std::move(pt));
|
||||
}
|
||||
queue_cvs_[worker_idx].notify_one();
|
||||
return result;
|
||||
}
|
||||
|
||||
void ThreadPool::add_outputs_to_worker(
|
||||
const std::vector<array>& outputs,
|
||||
int worker_idx) {
|
||||
if (outputs.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
|
||||
for (auto& a : outputs) {
|
||||
output_sets_[worker_idx].insert(a.id());
|
||||
}
|
||||
}
|
||||
|
||||
std::function<void()> ThreadPool::remove_outputs_when_done(
|
||||
std::function<void()> task,
|
||||
const std::vector<array>& outputs,
|
||||
int worker_idx) {
|
||||
if (outputs.size() == 0) {
|
||||
return task;
|
||||
}
|
||||
|
||||
std::vector<std::uintptr_t> output_ids;
|
||||
for (auto& a : outputs) {
|
||||
output_ids.push_back(a.id());
|
||||
}
|
||||
|
||||
return [og_task = std::move(task),
|
||||
ids = std::move(output_ids),
|
||||
worker_idx,
|
||||
this]() {
|
||||
og_task();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
|
||||
for (auto id : ids) {
|
||||
output_sets_[worker_idx].erase(id);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
std::future<void> ThreadPool::barrier(
|
||||
const std::vector<int>& worker_ids,
|
||||
std::function<void()> on_barrier) {
|
||||
auto workers = std::make_shared<std::atomic<int>>(worker_ids.size());
|
||||
auto promise = std::make_shared<std::promise<void>>();
|
||||
auto future = promise->get_future();
|
||||
|
||||
for (auto idx : worker_ids) {
|
||||
enqueue(
|
||||
[workers, promise, on_barrier = std::move(on_barrier)]() {
|
||||
(*workers)--;
|
||||
if (*workers <= 0) {
|
||||
on_barrier();
|
||||
promise->set_value();
|
||||
}
|
||||
},
|
||||
idx);
|
||||
}
|
||||
|
||||
return future;
|
||||
}
|
||||
|
||||
std::future<void> ThreadPool::barrier(const std::vector<int>& worker_ids) {
|
||||
auto noop = []() {};
|
||||
return barrier(worker_ids, std::move(noop));
|
||||
}
|
||||
|
||||
std::future<void> ThreadPool::barrier(std::function<void()> on_barrier) {
|
||||
std::vector<int> worker_ids(workers_.size());
|
||||
std::iota(worker_ids.begin(), worker_ids.end(), 0);
|
||||
return barrier(worker_ids, std::move(on_barrier));
|
||||
}
|
||||
|
||||
std::future<void> ThreadPool::barrier() {
|
||||
auto noop = []() {};
|
||||
return barrier(std::move(noop));
|
||||
}

void ThreadPool::worker(int idx) {
  while (true) {
    std::packaged_task<void()> task;
    {
      std::unique_lock<std::mutex> lock(queue_mutexes_[idx]);
      queue_cvs_[idx].wait(
          lock, [this, idx]() { return stop_ || !task_queues_[idx].empty(); });
      if (task_queues_[idx].empty()) {
        if (stop_) {
          break;
        } else {
          continue;
        }
      }
      task = std::move(task_queues_[idx].front());
      task_queues_[idx].pop();
    }
    try {
      task();
    } catch (...) {
      // Swallow stray exceptions so the worker keeps running; the task's
      // own exception (if any) is stored in its packaged_task future.
    }
  }
}

} // namespace mlx::core::io::detail
@@ -1,52 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <future>
#include <queue>
#include <unordered_set>

#include "mlx/array.h"

namespace mlx::core::io::detail {

class ThreadPool {
 public:
  explicit ThreadPool(int workers);
  ~ThreadPool();

  ThreadPool(ThreadPool&&) = delete;
  ThreadPool(const ThreadPool&) = delete;
  ThreadPool& operator=(ThreadPool&&) = delete;
  ThreadPool& operator=(const ThreadPool&) = delete;

  std::future<void> enqueue(
      std::function<void()> task,
      const std::vector<array>& inputs,
      const std::vector<array>& outputs);
  std::future<void> barrier(
      const std::vector<int>& worker_ids,
      std::function<void()> on_barrier);
  std::future<void> barrier(const std::vector<int>& worker_ids);
  std::future<void> barrier(std::function<void()> on_barrier);
  std::future<void> barrier();

 private:
  std::future<void> enqueue(std::function<void()> task, int worker_idx);
  void add_outputs_to_worker(const std::vector<array>& outputs, int worker_idx);
  std::function<void()> remove_outputs_when_done(
      std::function<void()> task,
      const std::vector<array>& outputs,
      int worker_idx);
  void worker(int idx);

  std::vector<std::queue<std::packaged_task<void()>>> task_queues_;
  std::vector<std::mutex> queue_mutexes_;
  std::vector<std::condition_variable> queue_cvs_;
  std::vector<std::mutex> set_mutexes_;
  std::vector<std::unordered_set<std::uintptr_t>> output_sets_;
  bool stop_;
  std::vector<std::thread> workers_;
};

} // namespace mlx::core::io::detail
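
Illustration only, not part of the diff: a minimal sketch of how the ThreadPool API above could be driven. The include path "threadpool.h" is hypothetical, a project context providing mlx/array.h is assumed, and the printf tasks stand in for real I/O work.

#include <cstdio>
#include <future>
#include <vector>

#include "threadpool.h" // hypothetical path to the header above

using mlx::core::io::detail::ThreadPool;

int main() {
  ThreadPool pool(4); // four worker threads
  std::vector<std::future<void>> futures;
  for (int i = 0; i < 8; ++i) {
    // No array inputs or outputs, so the task may run on any worker.
    futures.push_back(
        pool.enqueue([i]() { std::printf("task %d done\n", i); }, {}, {}));
  }
  pool.barrier().wait(); // returns once every worker has drained its queue
  for (auto& f : futures) {
    f.get(); // rethrows if the corresponding task threw
  }
}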
@@ -1,27 +1,125 @@
add_custom_command(
  OUTPUT compiled_preamble.cpp
function(make_jit_source SRC_FILE)
  # This function takes a metal header file,
  # runs the C preprocessor on it, and makes
  # the processed contents available as a string in a C++ function
  # mlx::core::metal::${SRC_NAME}()
  #
  # To use the function, declare it in jit/includes.h and
  # include jit/includes.h.
  #
  # Additional arguments to this function are treated as dependencies
  # in the CMake build system.
  get_filename_component(SRC_NAME ${SRC_FILE} NAME)
  add_custom_command(
    OUTPUT jit/${SRC_NAME}.cpp
    COMMAND /bin/bash
      ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
      ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
      ${CMAKE_CURRENT_BINARY_DIR}/jit
      ${CMAKE_C_COMPILER}
      ${PROJECT_SOURCE_DIR}
      ${SRC_FILE}
      "-D${MLX_METAL_VERSION}"
    DEPENDS make_compiled_preamble.sh
      kernels/compiled_preamble.h
      kernels/unary.h
      kernels/binary.h
  )
      kernels/${SRC_FILE}.h
      ${ARGN}
  )
  add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
  add_dependencies(mlx ${SRC_NAME})
  target_sources(
    mlx
    PRIVATE
    ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
  )
endfunction(make_jit_source)
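
Illustration only, not part of the diff: a rough sketch of what a generated jit/<name>.cpp could look like, inferred from the comment above. It assumes make_compiled_preamble.sh simply wraps the preprocessed header in a string-returning function; the exact signature may differ.

// jit/utils.cpp (generated; sketch)
namespace mlx::core::metal {

const char* utils() {
  return R"preamble(
// ... preprocessed contents of kernels/utils.h ...
)preamble";
}

} // namespace mlx::core::metal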

add_custom_target(
  compiled_preamble
  DEPENDS compiled_preamble.cpp
make_jit_source(
  utils
  kernels/bf16.h
  kernels/complex.h
  kernels/defines.h
)
make_jit_source(
  unary_ops
  kernels/erf.h
  kernels/expm1f.h
)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(
  reduce_utils
  kernels/atomic.h
  kernels/reduction/ops.h
)
make_jit_source(scatter)
make_jit_source(gather)

add_dependencies(mlx compiled_preamble)
if (MLX_METAL_JIT)
  target_sources(
    mlx
    PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
  )
  make_jit_source(arange)
  make_jit_source(copy)
  make_jit_source(unary)
  make_jit_source(binary)
  make_jit_source(binary_two)
  make_jit_source(ternary)
  make_jit_source(softmax)
  make_jit_source(scan)
  make_jit_source(sort)
  make_jit_source(
    reduce
    kernels/reduction/reduce_all.h
    kernels/reduction/reduce_col.h
    kernels/reduction/reduce_row.h
  )
  make_jit_source(
    steel/gemm/gemm
    kernels/steel/utils.h
    kernels/steel/gemm/loader.h
    kernels/steel/gemm/mma.h
    kernels/steel/gemm/params.h
    kernels/steel/gemm/transforms.h
  )
  make_jit_source(steel/gemm/kernels/steel_gemm_fused)
  make_jit_source(
    steel/gemm/kernels/steel_gemm_masked
    kernels/steel/defines.h
  )
  make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
  make_jit_source(
    steel/conv/conv
    kernels/steel/utils.h
    kernels/steel/defines.h
    kernels/steel/gemm/mma.h
    kernels/steel/gemm/transforms.h
    kernels/steel/conv/params.h
    kernels/steel/conv/loader.h
    kernels/steel/conv/loaders/loader_channel_l.h
    kernels/steel/conv/loaders/loader_channel_n.h
  )
  make_jit_source(
    steel/conv/kernels/steel_conv
  )
  make_jit_source(
    steel/conv/kernels/steel_conv_general
    kernels/steel/defines.h
    kernels/steel/conv/loaders/loader_general.h
  )
else()
  target_sources(
    mlx
    PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
  )
endif()

target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
@@ -40,7 +138,8 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
  ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
)

if (NOT MLX_METAL_PATH)
322  mlx/backend/metal/binary.cpp  Normal file
@@ -0,0 +1,322 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;

void binary_op(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::string op) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, outputs[0], bopt, true);
  set_binary_op_output_data(a, b, outputs[1], bopt, true);

  auto& out = outputs[0];
  if (out.size() == 0) {
    return;
  }

  // Try to collapse contiguous dims
  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
  auto& strides_a = strides[0];
  auto& strides_b = strides[1];
  auto& strides_out = strides[2];

  std::string kernel_name;
  {
    std::ostringstream kname;
    switch (bopt) {
      case BinaryOpType::ScalarScalar:
        kname << "ss";
        break;
      case BinaryOpType::ScalarVector:
        kname << "sv";
        break;
      case BinaryOpType::VectorScalar:
        kname << "vs";
        break;
      case BinaryOpType::VectorVector:
        kname << "vv";
        break;
      case BinaryOpType::General:
        kname << "g";
        if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
          kname << shape.size();
        } else {
          kname << "n";
        }
        break;
    }
    kname << op << type_to_name(a);
    kernel_name = kname.str();
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);

  auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);

  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);

  // - If a is donated it goes to the first output
  // - If b is donated it goes to the first output if a was not donated
  //   otherwise it goes to the second output
  bool donate_a = a.data_shared_ptr() == nullptr;
  bool donate_b = b.data_shared_ptr() == nullptr;
  compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
  compute_encoder.set_input_array(
      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
  compute_encoder.set_output_array(outputs[0], 2);
  compute_encoder.set_output_array(outputs[1], 3);

  if (bopt == BinaryOpType::General) {
    auto ndim = shape.size();
    if (ndim > 3) {
      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
    } else {
      // The shape is implicit in the grid for <= 3D
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
    }

    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
      compute_encoder->setBytes(&ndim, sizeof(int), 7);
    }

    // Launch up to 3D grid of threads
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    size_t rest = out.size() / (dim0 * dim1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
    }
    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  } else {
    // Launch a 1D grid of threads
    size_t nthreads = out.data_size();
    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size > nthreads) {
      thread_group_size = nthreads;
    }
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }
}

void binary_op(
    const std::vector<array>& inputs,
    array& out,
    const std::string op) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt, true);
  if (out.size() == 0) {
    return;
  }

  // Try to collapse contiguous dims
  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
  auto& strides_a = strides[0];
  auto& strides_b = strides[1];
  auto& strides_out = strides[2];

  std::string kernel_name;
  {
    std::ostringstream kname;
    switch (bopt) {
      case BinaryOpType::ScalarScalar:
        kname << "ss";
        break;
      case BinaryOpType::ScalarVector:
        kname << "sv";
        break;
      case BinaryOpType::VectorScalar:
        kname << "vs";
        break;
      case BinaryOpType::VectorVector:
        kname << "vv";
        break;
      case BinaryOpType::General:
        kname << "g";
        if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
          kname << shape.size();
        } else {
          kname << "n";
        }
        break;
    }
    kname << op << type_to_name(a);
    kernel_name = kname.str();
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);

  auto kernel = get_binary_kernel(d, kernel_name, a, out);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  bool donate_a = a.data_shared_ptr() == nullptr;
  bool donate_b = b.data_shared_ptr() == nullptr;
  compute_encoder.set_input_array(donate_a ? out : a, 0);
  compute_encoder.set_input_array(donate_b ? out : b, 1);
  compute_encoder.set_output_array(out, 2);

  if (bopt == BinaryOpType::General) {
    auto ndim = shape.size();
    if (ndim > 3) {
      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
    } else {
      // The shape is implicit in the grid for <= 3D
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
    }

    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
      compute_encoder->setBytes(&ndim, sizeof(int), 6);
    }

    // Launch up to 3D grid of threads
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    size_t rest = out.size() / (dim0 * dim1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
    }
    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  } else {
    // Launch a 1D grid of threads
    size_t nthreads =
        bopt == BinaryOpType::General ? out.size() : out.data_size();
    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size > nthreads) {
      thread_group_size = nthreads;
    }
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }
}

void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "add");
}

void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "arctan2");
}

void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
  switch (op_) {
    case BitwiseBinary::And:
      binary_op(inputs, out, "bitwise_and");
      break;
    case BitwiseBinary::Or:
      binary_op(inputs, out, "bitwise_or");
      break;
    case BitwiseBinary::Xor:
      binary_op(inputs, out, "bitwise_xor");
      break;
    case BitwiseBinary::LeftShift:
      binary_op(inputs, out, "left_shift");
      break;
    case BitwiseBinary::RightShift:
      binary_op(inputs, out, "right_shift");
      break;
  }
}

void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "div");
}

void DivMod::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  binary_op(inputs, outputs, "divmod");
}

void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "rem");
}

void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
}

void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "ge");
}

void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "geq");
}

void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "le");
}

void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "leq");
}

void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "land");
}

void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "lor");
}

void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "lae");
}

void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "max");
}

void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "min");
}

void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "mul");
}

void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "neq");
}

void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "pow");
}

void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "sub");
}

} // namespace mlx::core
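
Illustration only, not part of the diff: the dispatch above composes kernel names from a layout prefix (ss, sv, vs, vv, or g plus a dimension count), the op string, and the input's type name. A standalone sketch, assuming type_to_name(a) would return "float32":

#include <iostream>
#include <sstream>

int main() {
  std::ostringstream kname;
  kname << "g" << 3;   // general (strided) case with 3 collapsed dims
  kname << "add";      // op string passed to binary_op
  kname << "float32";  // type_to_name(a), assumed
  std::cout << kname.str() << "\n"; // prints: g3addfloat32
}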
@@ -4,8 +4,8 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/compiled_preamble.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
@@ -190,7 +191,8 @@ void Compiled::eval_gpu(
  // If not we have to build it ourselves
  if (lib == nullptr) {
    std::ostringstream kernel;
    kernel << metal::get_kernel_preamble() << std::endl;
    kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
           << metal::ternary_ops();
    build_kernel(
        kernel,
        kernel_lib_ + "_contiguous",
@@ -336,7 +337,7 @@ void Compiled::eval_gpu(
    MTL::Size grid_dims(nthreads, 1, 1);
    MTL::Size group_dims(
        std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  } else {
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -347,7 +348,7 @@ void Compiled::eval_gpu(
    }
    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }
}
@@ -1,9 +0,0 @@
// Copyright © 2023-24 Apple Inc.

#pragma once

namespace mlx::core::metal {

const char* get_kernel_preamble();

}
@@ -7,6 +7,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/matmul.h"
@@ -59,7 +60,7 @@ void explicit_gemm_conv_ND_gpu(
  MTL::Size grid_dims = MTL::Size(
      conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);

  compute_encoder->dispatchThreads(grid_dims, group_dims);
  compute_encoder.dispatchThreads(grid_dims, group_dims);

  // Reshape weight
  std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -137,7 +138,7 @@ void explicit_gemm_conv_group_ND_gpu(
  MTL::Size grid_dims = MTL::Size(
      conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);

  compute_encoder->dispatchThreads(grid_dims, group_dims);
  compute_encoder.dispatchThreads(grid_dims, group_dims);

  // Transpose kernel weights so that we can slice them by contiguous chunks
  // of channel groups.
@@ -247,7 +248,7 @@ void slow_conv_2D_gpu(
  compute_encoder.set_output_array(out, 2);

  compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}

void implicit_gemm_conv_2D_gpu(
@@ -257,15 +258,19 @@ void implicit_gemm_conv_2D_gpu(
    const array& wt,
    array out,
    const MLXConvParams<2>& conv_params) {
  const int groups = conv_params.groups;
  const int C_per_group = conv_params.C / conv_params.groups;
  const int O_per_group = conv_params.O / conv_params.groups;

  // Deduce implicit gemm size
  int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
  int implicit_N = conv_params.O;
  int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
  const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
  const int implicit_N = O_per_group;
  const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;

  // Determine block and warp tiles
  int wm = 2, wn = 2;

  int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
  int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;
  int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
  int bk = 16;

@@ -281,15 +286,15 @@ void implicit_gemm_conv_2D_gpu(

  // Fix small channel specialization
  int n_channel_specialization = 0;
  int channel_k_iters = ((conv_params.C + bk - 1) / bk);
  int channel_k_iters = ((C_per_group + bk - 1) / bk);
  int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;

  if (conv_params.C <= 2) {
  if (C_per_group <= 2) {
    gemm_k_iters = (implicit_K + bk - 1) / bk;
    n_channel_specialization = conv_params.C;
  } else if (conv_params.C <= 4) {
    n_channel_specialization = C_per_group;
  } else if (C_per_group <= 4) {
    gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
    n_channel_specialization = conv_params.C;
    n_channel_specialization = C_per_group;
  }

  bool small_filter = (!n_channel_specialization) &&
@@ -331,7 +336,17 @@ void implicit_gemm_conv_2D_gpu(

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  auto kernel = get_steel_conv_kernel(
      d,
      kname.str(),
      out,
      bm,
      bn,
      bk,
      wm,
      wn,
      n_channel_specialization,
      small_filter);
  compute_encoder->setComputePipelineState(kernel);

  // Deduce grid launch dimensions
@@ -340,7 +355,7 @@ void implicit_gemm_conv_2D_gpu(
  size_t grid_dim_x = tn * tile;

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);

  // Encode arrays
  compute_encoder.set_input_array(in, 0);
@@ -352,7 +367,7 @@ void implicit_gemm_conv_2D_gpu(
  compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);

  // Launch kernel
  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}

void implicit_gemm_conv_2D_general_gpu(
@@ -484,7 +499,8 @@ void implicit_gemm_conv_2D_general_gpu(

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  auto kernel =
      get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
  compute_encoder->setComputePipelineState(kernel);

  // Deduce grid launch dimensions
@@ -512,7 +528,7 @@ void implicit_gemm_conv_2D_general_gpu(
      base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);

  // Launch kernel
  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}

void winograd_conv_2D_gpu(
@@ -613,7 +629,7 @@ void winograd_conv_2D_gpu(
    MTL::Size group_dims = MTL::Size(32, bo, 1);
    MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);

    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
  }

  // Do input transform
@@ -641,7 +657,7 @@ void winograd_conv_2D_gpu(
    MTL::Size group_dims = MTL::Size(32, wn, wm);
    MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);

    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
  }

  // Do batched gemm
@@ -689,7 +705,7 @@ void winograd_conv_2D_gpu(
    MTL::Size group_dims = MTL::Size(32, wn, wm);
    MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);

    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
  }
}

@@ -703,6 +719,7 @@ void conv_2D_gpu(
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    const int groups,
    bool flip,
    std::vector<array>& copies) {
  // Make conv params
@@ -718,12 +735,12 @@ void conv_2D_gpu(
      /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
      /* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
      /* const size_t in_strides[NDIM + 2] = */
      {in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
      {in.strides(0), in.strides(1), in.strides(2), in.strides(3)},
      /* const size_t wt_strides[NDIM + 2] = */
      {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
      {wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)},
      /* const size_t out_strides[NDIM + 2] = */
      {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
      /* const int groups = */ 1,
      {out.strides(0), out.strides(1), out.strides(2), out.strides(3)},
      /* const int groups = */ groups,
      /* const bool flip = */ flip,
  };

@@ -735,6 +752,18 @@ void conv_2D_gpu(
  bool channels_large = (conv_params.C + conv_params.O) >= 512;
  bool channels_med = (conv_params.C + conv_params.O) >= 256;

  if (groups > 1) {
    const int C_per_group = conv_params.C / groups;
    const int O_per_group = conv_params.O / groups;

    if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
        (O_per_group <= 16 || O_per_group % 16 == 0)) {
      return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
    } else {
      return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
    }
  }

  // Direct to winograd conv
  if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
      conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
@@ -759,6 +788,56 @@ void conv_2D_gpu(
  }
}

void conv_3D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    std::vector<array>& copies) {
  // Make conv params
  MLXConvParams<3> conv_params{
      /* const int N = */ in.shape(0),
      /* const int C = */ in.shape(4),
      /* const int O = */ wt.shape(0),
      /* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
      /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
      /* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
      /* const int kdil[NDIM] = */
      {wt_dilation[0], wt_dilation[1], wt_dilation[2]},
      /* const int idil[NDIM] = */
      {in_dilation[0], in_dilation[1], in_dilation[2]},
      /* const size_t in_strides[NDIM + 2] = */
      {in.strides()[0],
       in.strides()[1],
       in.strides()[2],
       in.strides()[3],
       in.strides()[4]},
      /* const size_t wt_strides[NDIM + 2] = */
      {wt.strides()[0],
       wt.strides()[1],
       wt.strides()[2],
       wt.strides()[3],
       wt.strides()[4]},
      /* const size_t out_strides[NDIM + 2] = */
      {out.strides()[0],
       out.strides()[1],
       out.strides()[2],
       out.strides()[3],
       out.strides()[4]},
      /* const int groups = */ 1,
      /* const bool flip = */ flip,
  };
  return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}

} // namespace

void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -783,8 +862,23 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
    wt = arr_copy;
  }

  // 3D conv
  if (out.ndim() == 5) {
    conv_3D_gpu(
        s,
        d,
        in,
        wt,
        out,
        padding_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
        flip_,
        copies);
  }
  // 2D conv
  if (out.ndim() == 4) {
  else if (out.ndim() == 4) {
    conv_2D_gpu(
        s,
        d,
@@ -795,6 +889,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
        groups_,
        flip_,
        copies);
  }
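
Illustration only, not part of the diff: the oS entries filled into MLXConvParams above are expected to follow the standard convolution output-extent arithmetic. A minimal self-contained sketch:

#include <cassert>

// Output extent along one spatial axis, given input extent, padding,
// kernel extent, kernel dilation, and stride.
int conv_out_size(int in, int pad, int kernel, int kdil, int stride) {
  int effective_k = kdil * (kernel - 1) + 1; // dilated kernel extent
  return (in + 2 * pad - effective_k) / stride + 1;
}

int main() {
  assert(conv_out_size(32, 1, 3, 1, 1) == 32); // 3x3, stride 1, pad 1
  assert(conv_out_size(32, 0, 3, 1, 2) == 15); // strided, no padding
  return 0;
}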
@@ -4,12 +4,14 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;

void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
  if (ctype == CopyType::Vector) {
    // If the input is donateable, we are doing a vector copy and the types
@@ -62,27 +64,34 @@ void copy_gpu_inplace(
  auto& strides_out_ = strides[1];

  auto& d = metal::device(s.device);
  std::ostringstream kname;
  switch (ctype) {
    case CopyType::Scalar:
      kname << "scopy";
      break;
    case CopyType::Vector:
      kname << "vcopy";
      break;
    case CopyType::General:
      kname << "gcopy";
      break;
    case CopyType::GeneralGeneral:
      kname << "ggcopy";
      break;
  std::string kernel_name;
  {
    std::ostringstream kname;
    switch (ctype) {
      case CopyType::Scalar:
        kname << "s";
        break;
      case CopyType::Vector:
        kname << "v";
        break;
      case CopyType::General:
        kname << "g";
        break;
      case CopyType::GeneralGeneral:
        kname << "gg";
        break;
    }
    if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
        shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
      kname << shape.size();
    }
    kname << "_copy";
    kname << type_to_name(in) << type_to_name(out);
    kernel_name = kname.str();
  }
  kname << type_to_name(in) << type_to_name(out);
  if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
      shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
    kname << "_" << shape.size();
  }
  auto kernel = d.get_kernel(kname.str());

  auto kernel = get_copy_kernel(d, kernel_name, in, out);

  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  bool donate_in = in.data_shared_ptr() == nullptr;
@@ -106,7 +115,7 @@ void copy_gpu_inplace(
    set_vector_bytes(compute_encoder, strides_out, ndim, 4);
  }

  if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
  if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
    compute_encoder->setBytes(&ndim, sizeof(int), 5);
  }

@@ -126,7 +135,7 @@ void copy_gpu_inplace(

    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  } else {
    size_t nthreads = out.data_size();
    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
@@ -135,7 +144,7 @@ void copy_gpu_inplace(
      thread_group_size = nthreads;
    }
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }
}
@@ -25,9 +25,18 @@ namespace {
// TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;

constexpr const char* default_mtllib_path = METAL_PATH;

constexpr auto get_metal_version() {
#if defined METAL_3_1
  return MTL::LanguageVersion3_1;
#else
  return MTL::LanguageVersion3_0;
#endif
}

auto load_device() {
  auto devices = MTL::CopyAllDevices();
  auto device = static_cast<MTL::Device*>(devices->object(0))
@@ -37,7 +46,6 @@ auto load_device() {
  }
  return device;
}

std::pair<MTL::Library*, NS::Error*> load_library_from_path(
    MTL::Device* device,
    const char* path) {
@@ -116,6 +124,33 @@ MTL::Library* load_library(

} // namespace

void CommandEncoder::dispatchThreadgroups(
    MTL::Size grid_dims,
    MTL::Size group_dims) {
  num_dispatches++;
  enc->dispatchThreadgroups(grid_dims, group_dims);
  maybe_split();
}

void CommandEncoder::dispatchThreads(
    MTL::Size grid_dims,
    MTL::Size group_dims) {
  num_dispatches++;
  enc->dispatchThreads(grid_dims, group_dims);
  maybe_split();
}

void CommandEncoder::maybe_split() {
  if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
    enc->endEncoding();
    enc->release();
    num_dispatches = 0;
    outputs.clear();
    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
    enc->retain();
  }
}

Device::Device() {
  auto pool = new_scoped_memory_pool();
  device_ = load_device();
@@ -130,9 +165,6 @@ Device::~Device() {
  for (auto& b : buffer_map_) {
    b.second.second->release();
  }
  for (auto& e : encoder_map_) {
    (*e.second)->release();
  }
  for (auto& k : kernel_map_) {
    k.second->release();
  }
@@ -169,27 +201,26 @@ void Device::increment_command_buffer_ops(int index) {

MTL::CommandBuffer* Device::get_command_buffer(int index) {
  auto bit = buffer_map_.find(index);
  return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
}
  if (bit == buffer_map_.end()) {
    auto qit = queue_map_.find(index);
    if (qit == queue_map_.end()) {
      throw std::runtime_error(
          "[metal::Device] Attempting to get command buffer for invalid queue.");
    }

MTL::CommandBuffer* Device::new_command_buffer(int index) {
  auto qit = queue_map_.find(index);
  if (qit == queue_map_.end()) {
    throw std::runtime_error(
        "[metal::Device] Attempting to get command buffer for invalid queue.");
    auto cb = qit->second->commandBufferWithUnretainedReferences();

    if (!cb) {
      throw std::runtime_error(
          "[metal::Device] Unable to create new command buffer");
    }

    // Increment ref count so the buffer is not garbage collected
    cb->retain();

    bit = buffer_map_.insert({index, {0, cb}}).first;
  }

  auto cb = qit->second->commandBufferWithUnretainedReferences();

  if (!cb) {
    throw std::runtime_error(
        "[metal::Device] Unable to create new command buffer");
  }

  // Increment ref count so the buffer is not garbage collected
  cb->retain();

  return buffer_map_.insert({index, {0, cb}}).first->second.second;
  return bit->second.second;
}

void Device::commit_command_buffer(int index) {
@@ -200,25 +231,15 @@ void Device::commit_command_buffer(int index) {
}

void Device::end_encoding(int index) {
  auto eit = encoder_map_.find(index);
  if (eit != encoder_map_.end()) {
    (*eit->second)->endEncoding();
    (*eit->second)->release();
    encoder_map_.erase(eit);
  }
  encoder_map_.erase(index);
}

CommandEncoder& Device::get_command_encoder(int index) {
  auto eit = encoder_map_.find(index);
  if (eit == encoder_map_.end()) {
    auto cb = get_command_buffer(index);
    auto compute_encoder =
        cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
    // Increment ref count so the buffer is not garbage collected
    compute_encoder->retain();
    eit = encoder_map_
              .emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
              .first;
    eit =
        encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
  }
  return *(eit->second);
}
@@ -262,7 +283,11 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
      NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);

  NS::Error* error = nullptr;
  auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
  auto options = MTL::CompileOptions::alloc()->init();
  options->setFastMathEnabled(false);
  options->setLanguageVersion(get_metal_version());
  auto mtl_lib = device_->newLibrary(ns_code, options, &error);
  options->release();

  // Throw error if unable to compile library
  if (!mtl_lib) {
@@ -344,7 +369,6 @@ MTL::Function* Device::get_function_(
  }

  mtl_func_consts->release();
  desc->release();

  return mtl_function;
}
@@ -513,11 +537,13 @@ MTL::ComputePipelineState* Device::get_kernel(
  // Compile kernel to compute pipeline
  auto mtl_linked_funcs = get_linked_functions_(linked_functions);
  auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);

  mtl_function->release();
  mtl_linked_funcs->release();

  // Add kernel to cache
  kernel_map_.insert({kname, kernel});

  return kernel;
}
@@ -37,8 +37,10 @@ using MTLFCList =
    std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;

struct CommandEncoder {
  CommandEncoder(MTL::ComputeCommandEncoder* enc)
      : enc(enc), concurrent(false) {};
  CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
    enc->retain();
  };
  CommandEncoder(const CommandEncoder&) = delete;
  CommandEncoder& operator=(const CommandEncoder&) = delete;

@@ -61,7 +63,7 @@ struct CommandEncoder {
    return enc;
  }

  void set_input_array(const array& a, int idx, int offset = 0) {
  void set_input_array(const array& a, int idx, int64_t offset = 0) {
    auto r_buf =
        static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
    if (auto it = outputs.find(r_buf); it != outputs.end()) {
@@ -78,7 +80,7 @@ struct CommandEncoder {
    enc->setBuffer(a_buf, base_offset, idx);
  }

  void set_output_array(array& a, int idx, int offset = 0) {
  void set_output_array(array& a, int idx, int64_t offset = 0) {
    // Add barriers before adding the output to the output set
    set_input_array(a, idx, offset);
    auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
@@ -89,13 +91,25 @@ struct CommandEncoder {
    }
  }

  void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
  void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);

  ConcurrentContext start_concurrent() {
    return ConcurrentContext(*this);
  }

  ~CommandEncoder() {
    enc->endEncoding();
    enc->release();
  }

 private:
  void maybe_split();

  int num_dispatches{0};
  MTL::CommandBuffer* cbuf;
  MTL::ComputeCommandEncoder* enc;
  bool concurrent;
  bool concurrent{false};
  std::unordered_set<MTL::Resource*> outputs;
  std::unordered_set<MTL::Resource*> concurrent_outputs;
};
@@ -112,7 +126,6 @@ class Device {
  };

  void new_queue(int index);
  MTL::CommandBuffer* new_command_buffer(int index);
  MTL::CommandBuffer* get_command_buffer(int index);
  int get_command_buffer_ops(int index);
  void increment_command_buffer_ops(int index);
@@ -97,7 +97,7 @@ void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
    auto group_dims = MTL::Size(1, m, 1);
    auto grid_dims = MTL::Size(batch, m, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1,24 +1,35 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include <fmt/format.h>

#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/indexing.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace {
constexpr int METAL_MAX_INDEX_ARRAYS = 20;

constexpr int METAL_MAX_INDEX_ARRAYS = 10;

} // namespace
std::pair<std::string, std::string> make_index_args(
    const std::string& idx_type,
    int nidx) {
  std::ostringstream idx_args;
  std::ostringstream idx_arr;
  for (int i = 0; i < nidx; ++i) {
    idx_args << fmt::format(
        "const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
    idx_arr << fmt::format("idx{0}", i);
    if (i < nidx - 1) {
      idx_args << "\n";
      idx_arr << ",";
    }
  }
  return {idx_args.str(), idx_arr.str()};
}
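
Illustration only, not part of the diff: a standalone sketch of what make_index_args produces for idx_type "uint32_t" and nidx == 2. Buffer indices start at 20, matching the set_input_array(inputs[i + 1], 20 + i) calls further down.

#include <fmt/format.h>
#include <iostream>

int main() {
  // Mirrors the loop body in make_index_args.
  for (int i = 0; i < 2; ++i) {
    std::cout << fmt::format(
                     "const device {0} *idx{1} [[buffer({2})]],",
                     "uint32_t", i, 20 + i)
              << "\n";
  }
  // Prints:
  //   const device uint32_t *idx0 [[buffer(20)]],
  //   const device uint32_t *idx1 [[buffer(21)]],
  // and the matching idx_arr string is "idx0,idx1".
}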

void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& src = inputs[0];
@@ -42,15 +53,41 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  size_t ndim = src.ndim();

  std::ostringstream kname;
  std::string lib_name;
  std::string kernel_name;
  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
  kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
  if (idx_ndim <= 1) {
    kname << "_" << idx_ndim;
  {
    std::ostringstream kname;
    kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
          << "_" << idx_ndim;
    lib_name = kname.str();
    kernel_name = lib_name;
  }

  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::gather();
    std::string out_type_str = get_type_string(out.dtype());
    std::string idx_type_str =
        nidx ? get_type_string(inputs[1].dtype()) : "bool";
    auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);

    // Index dimension specializations
    kernel_source << fmt::format(
        gather_kernels,
        type_to_name(out) + idx_type_name,
        out_type_str,
        idx_type_str,
        nidx,
        idx_args,
        idx_arr,
        idx_ndim);
    lib = d.get_library(lib_name, kernel_source.str());
  }

  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  auto kernel = d.get_kernel(kernel_name, lib);
  compute_encoder->setComputePipelineState(kernel);

  size_t slice_size = 1;
@@ -102,12 +139,12 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);

  // Set index buffers
  for (int i = 1; i < nidx + 1; ++i) {
    compute_encoder.set_input_array(inputs[i], 20 + i);
  for (int i = 0; i < nidx; ++i) {
    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
  }

  // Launch grid
  compute_encoder->dispatchThreads(grid_dims, group_dims);
  compute_encoder.dispatchThreads(grid_dims, group_dims);
}

void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -139,10 +176,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  // Get kernel name
  std::ostringstream kname;
  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";

  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  bool index_nd1_specialization = (idx_ndim == 1);

@@ -159,32 +192,86 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
    index_nd1_specialization &= inputs[i].flags().row_contiguous;
  }

  if (index_nd1_specialization) {
    kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
  } else {
    kname << "scatter" << type_to_name(out) << idx_type_name;
  }
  std::string lib_name;
  std::string kernel_name;
  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
  std::string op_name;
  switch (reduce_type_) {
    case Scatter::None:
      kname << "_none";
      op_name = "none";
      break;
    case Scatter::Sum:
      kname << "_sum";
      op_name = "sum";
      break;
    case Scatter::Prod:
      kname << "_prod";
      op_name = "prod";
      break;
    case Scatter::Max:
      kname << "_max";
      op_name = "max";
      break;
    case Scatter::Min:
      kname << "_min";
      op_name = "min";
      break;
  }
  kname << "_" << nidx;

  {
    std::ostringstream kname;
    if (index_nd1_specialization) {
      kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
    } else {
      kname << "scatter" << type_to_name(out) << idx_type_name;
    }
    kname << "_" << op_name << "_" << nidx;
    lib_name = kname.str();
    kernel_name = kname.str();
  }

  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::reduce_utils()
                  << metal::scatter();

    std::string out_type_str = get_type_string(out.dtype());
    std::string idx_type_str =
        nidx ? get_type_string(inputs[1].dtype()) : "bool";
    std::string op_type;
    switch (reduce_type_) {
      case Scatter::None:
        op_type = "None";
        break;
      case Scatter::Sum:
        op_type = "Sum<{0}>";
        break;
      case Scatter::Prod:
        op_type = "Prod<{0}>";
        break;
      case Scatter::Max:
        op_type = "Max<{0}>";
        break;
      case Scatter::Min:
        op_type = "Min<{0}>";
        break;
    }
    if (reduce_type_ != Scatter::None) {
      op_type = fmt::format(op_type, out_type_str);
    }
    auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);

    kernel_source << fmt::format(
        scatter_kernels,
        type_to_name(out) + idx_type_name + "_" + op_name,
        out_type_str,
        idx_type_str,
        op_type,
        nidx,
        idx_args,
        idx_arr);
    lib = d.get_library(lib_name, kernel_source.str());
  }

  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  auto kernel = d.get_kernel(kernel_name, lib);

  auto& upd = inputs.back();
  size_t nthreads = upd.size();
@@ -209,14 +296,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);

    // Set index buffers
    for (int i = 1; i < nidx + 1; ++i) {
      compute_encoder.set_input_array(inputs[i], 20 + i);
    for (int i = 0; i < nidx; ++i) {
      compute_encoder.set_input_array(inputs[i + 1], 20 + i);
    }

    // Launch grid
    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
    compute_encoder.dispatchThreads(grid_dims, group_dims);

  } else {
    // Collect all idx shapes and strides into one place
@@ -279,14 +366,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);

    // Set index buffers
    for (int i = 1; i < nidx + 1; ++i) {
      compute_encoder.set_input_array(inputs[i], 20 + i);
    for (int i = 0; i < nidx; ++i) {
      compute_encoder.set_input_array(inputs[i + 1], 20 + i);
    }

    // Launch grid
    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }
}
9  mlx/backend/metal/jit/arange.h  Normal file
@@ -0,0 +1,9 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
    constant const {1}& start,
    constant const {1}& step,
    device {1}* out,
    uint index [[thread_position_in_grid]]);
)";
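
Illustration only, not part of the diff: a standalone sketch of instantiating the template string above with fmt, as the JIT path presumably does. The host name "arangefloat32" is an assumption based on the type_to_name naming scheme used elsewhere in this diff.

#include <fmt/format.h>
#include <iostream>
#include <string_view>

int main() {
  constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
    constant const {1}& start,
    constant const {1}& step,
    device {1}* out,
    uint index [[thread_position_in_grid]]);
)";
  // Emits the Metal source for a float32 arange kernel.
  std::cout << fmt::format(arange_kernels, "arangefloat32", "float");
}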
87  mlx/backend/metal/jit/binary.h  Normal file
@@ -0,0 +1,87 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view binary_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    constant const int shape[4],
    constant const size_t a_strides[4],
    constant const size_t b_strides[4],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    constant const int shape[5],
    constant const size_t a_strides[5],
    constant const size_t b_strides[5],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);

template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);

template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
)";
98  mlx/backend/metal/jit/binary_two.h  Normal file
@@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view binary_two_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    uint index [[thread_position_in_grid]]);

template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    constant const int shape[4],
    constant const size_t a_strides[4],
    constant const size_t b_strides[4],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    constant const int shape[5],
    constant const size_t a_strides[5],
    constant const size_t b_strides[5],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);

template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);

template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
)";
100 mlx/backend/metal/jit/copy.h Normal file
@@ -0,0 +1,100 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view copy_kernels = R"(
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]);
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]);

template [[host_name("g4_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 4>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg4_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 4>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 5>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg5_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 5>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg1_{0}")]] [[kernel]] void
copy_gg_nd1<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    uint index [[thread_position_in_grid]]);
template [[host_name("gg2_{0}")]] [[kernel]] void
copy_gg_nd2<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint2 index [[thread_position_in_grid]]);
template [[host_name("gg3_{0}")]] [[kernel]] void
copy_gg_nd3<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]);

template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]]);
)";
34 mlx/backend/metal/jit/includes.h Normal file
@@ -0,0 +1,34 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

namespace mlx::core::metal {

const char* utils();
const char* binary_ops();
const char* unary_ops();
const char* ternary_ops();
const char* reduce_utils();
const char* gather();
const char* scatter();

const char* arange();
const char* unary();
const char* binary();
const char* binary_two();
const char* copy();
const char* ternary();
const char* scan();
const char* softmax();
const char* sort();
const char* reduce();

const char* gemm();
const char* steel_gemm_fused();
const char* steel_gemm_masked();
const char* steel_gemm_splitk();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();

} // namespace mlx::core::metal
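Each accessor above returns a Metal source string embedded in the binary at build time; jit_kernels.cpp, later in this diff, concatenates those strings ahead of the formatted instantiation before handing the result to the Metal compiler. A hedged sketch of the definition side — the literal and its contents are stand-ins, since the real definitions are generated from the kernel headers during the build:

    namespace mlx::core::metal {
    // Stand-in for the build-generated source literal.
    static const char binary_ops_source[] = "/* contents of binary_ops.h */";
    const char* binary_ops() {
      return binary_ops_source;
    }
    } // namespace mlx::core::metal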
81 mlx/backend/metal/jit/indexing.h Normal file
@@ -0,0 +1,81 @@
// Copyright © 2023-2024 Apple Inc.

constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}(
    const device {1}* src [[buffer(0)]],
    device {1}* out [[buffer(1)]],
    const constant int* src_shape [[buffer(2)]],
    const constant size_t* src_strides [[buffer(3)]],
    const constant size_t& src_ndim [[buffer(4)]],
    const constant int* slice_sizes [[buffer(5)]],
    const constant int* axes [[buffer(6)]],
    const constant int* idx_shapes [[buffer(7)]],
    const constant size_t* idx_strides [[buffer(8)]],
    const constant int& idx_ndim [[buffer(9)]],
    {4}
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {{
  Indices<{2}, {3}> idxs{{
      {{ {5} }}, idx_shapes, idx_strides, idx_ndim}};

  return gather_impl<{1}, {2}, {3}, {6}>(
      src,
      out,
      src_shape,
      src_strides,
      src_ndim,
      slice_sizes,
      axes,
      idxs,
      index,
      grid_dim);
}}
)";

constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter_1d_index{0}_{4}(
    const device {1}* updates [[buffer(1)]],
    device mlx_atomic<{1}>* out [[buffer(2)]],
    const constant int* out_shape [[buffer(3)]],
    const constant size_t* out_strides [[buffer(4)]],
    const constant size_t& upd_size [[buffer(5)]],
    {5}
    uint2 gid [[thread_position_in_grid]]) {{
  const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
  return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
      updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
}}

[[kernel]] void scatter{0}_{4}(
    const device {1}* updates [[buffer(1)]],
    device mlx_atomic<{1}>* out [[buffer(2)]],
    const constant int* upd_shape [[buffer(3)]],
    const constant size_t* upd_strides [[buffer(4)]],
    const constant size_t& upd_ndim [[buffer(5)]],
    const constant size_t& upd_size [[buffer(6)]],
    const constant int* out_shape [[buffer(7)]],
    const constant size_t* out_strides [[buffer(8)]],
    const constant size_t& out_ndim [[buffer(9)]],
    const constant int* axes [[buffer(10)]],
    const constant int* idx_shapes [[buffer(11)]],
    const constant size_t* idx_strides [[buffer(12)]],
    const constant int& idx_ndim [[buffer(13)]],
    {5}
    uint2 gid [[thread_position_in_grid]]) {{
  Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};

  return scatter_impl<{1}, {2}, {3}, {4}>(
      updates,
      out,
      upd_shape,
      upd_strides,
      upd_ndim,
      upd_size,
      out_shape,
      out_strides,
      out_ndim,
      axes,
      idxs,
      gid);
}}
)";
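Unlike the purely declarative instantiation strings above, these gather/scatter templates contain kernel bodies, so literal braces are doubled (`{{`, `}}`) to survive fmt::format while single-brace fields such as `{4}` and `{5}` remain substitution points. A small self-contained illustration of that escaping rule (the strings here are invented for the example, not from the diff):

    #include <fmt/format.h>
    #include <iostream>

    int main() {
      // "{{" and "}}" emit literal braces; "{0}" is substituted.
      std::string s = fmt::format("void f() {{ g<{0}>(); }}", "int");
      std::cout << s << '\n'; // prints: void f() { g<int>(); }
    }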
168 mlx/backend/metal/jit/reduce.h Normal file
@@ -0,0 +1,168 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view reduce_init_kernels = R"(
[[kernel]] void {0}(
    device {1}* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {{
  out[tid] = {2}<{1}>::init;
}}
)";

constexpr std::string_view reduce_kernels = R"(
template [[host_name("all_{0}")]] [[kernel]] void
all_reduce<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device mlx_atomic<{2}>* out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("colGeneral_{0}")]] [[kernel]] void
col_reduce_general<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device mlx_atomic<{2}>* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup {2}* local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    const constant size_t& non_col_reductions [[buffer(8)]],
    const constant int* non_col_shapes [[buffer(9)]],
    const constant size_t* non_col_strides [[buffer(10)]],
    const constant int& non_col_ndim [[buffer(11)]],
    uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void
row_reduce_general_med<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("rowGeneral_{0}")]] [[kernel]] void
row_reduce_general<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device mlx_atomic<{2}>* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

constexpr std::string_view reduce_non_atomic_kernels = R"(
template [[host_name("allNoAtomics_{0}")]] [[kernel]] void
all_reduce_no_atomics<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint thread_group_id [[threadgroup_position_in_grid]]);

template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void
col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup {2}* local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 gid [[thread_position_in_grid]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    const constant size_t& non_col_reductions [[buffer(8)]],
    const constant int* non_col_shapes [[buffer(9)]],
    const constant size_t* non_col_strides [[buffer(10)]],
    const constant int& non_col_ndim [[buffer(11)]],
    uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void
row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";
26 mlx/backend/metal/jit/scan.h Normal file
@@ -0,0 +1,26 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view scan_kernels = R"(
template [[host_name("contig_{0}")]] [[kernel]] void
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_size [[threads_per_simdgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);

template [[host_name("strided_{0}")]] [[kernel]] void
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    const constant size_t& stride [[buffer(3)]],
    uint2 gid [[thread_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 lsize [[threads_per_threadgroup]],
    uint simd_size [[threads_per_simdgroup]]);
)";
23 mlx/backend/metal/jit/softmax.h Normal file
@@ -0,0 +1,23 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view softmax_kernels = R"(
template [[host_name("block_{0}")]] [[kernel]] void
softmax_single_row<{1}, {2}>(
    const device {1}* in,
    device {1}* out,
    constant int& axis_size,
    uint gid [[thread_position_in_grid]],
    uint _lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("looped_{0}")]] [[kernel]] void
softmax_looped<{1}, {2}>(
    const device {1}* in,
    device {1}* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";
81 mlx/backend/metal/jit/sort.h Normal file
@@ -0,0 +1,81 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view block_sort_kernels = R"(
template [[host_name("carg_{0}")]] [[kernel]] void
block_sort<{1}, {2}, true, {3}, {4}>(
    const device {1}* inp [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& stride_sorted_axis [[buffer(3)]],
    const constant int& stride_segment_axis [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("ncarg_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, true, {3}, {4}>(
    const device {1}* inp [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& stride_sorted_axis [[buffer(3)]],
    const constant int& nc_dim [[buffer(4)]],
    const device int* nc_shape [[buffer(5)]],
    const device size_t* nc_strides [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("c_{0}")]] [[kernel]] void
block_sort<{1}, {2}, false, {3}, {4}>(
    const device {1}* inp [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& stride_sorted_axis [[buffer(3)]],
    const constant int& stride_segment_axis [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("nc_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, false, {3}, {4}>(
    const device {1}* inp [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& stride_sorted_axis [[buffer(3)]],
    const constant int& nc_dim [[buffer(4)]],
    const device int* nc_shape [[buffer(5)]],
    const device size_t* nc_strides [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

constexpr std::string_view multiblock_sort_kernels = R"(
template [[host_name("sort_{0}")]] [[kernel]] void
mb_block_sort<{1}, {2}, true, {3}, {4}>(
    const device {1}* inp [[buffer(0)]],
    device {1}* out_vals [[buffer(1)]],
    device {2}* out_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& stride_sorted_axis [[buffer(4)]],
    const constant int& nc_dim [[buffer(5)]],
    const device int* nc_shape [[buffer(6)]],
    const device size_t* nc_strides [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("partition_{0}")]] [[kernel]] void
mb_block_partition<{1}, {2}, true, {3}, {4}>(
    device {2}* block_partitions [[buffer(0)]],
    const device {1}* dev_vals [[buffer(1)]],
    const device {2}* dev_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& merge_tiles [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_dims [[threads_per_threadgroup]]);
template [[host_name("merge_{0}")]] [[kernel]] void
mb_block_merge<{1}, {2}, true, {3}, {4}>(
    const device {2}* block_partitions [[buffer(0)]],
    const device {1}* dev_vals_in [[buffer(1)]],
    const device {2}* dev_idxs_in [[buffer(2)]],
    device {1}* dev_vals_out [[buffer(3)]],
    device {2}* dev_idxs_out [[buffer(4)]],
    const constant int& size_sorted_axis [[buffer(5)]],
    const constant int& merge_tiles [[buffer(6)]],
    const constant int& num_tiles [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";
32 mlx/backend/metal/jit/steel_conv.h Normal file
@@ -0,0 +1,32 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view steel_conv_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
    const device {itype}* A [[buffer(0)]],
    const device {itype}* B [[buffer(1)]],
    device {itype}* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]);
)";

constexpr std::string_view steel_conv_general_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
    const device {itype}* A [[buffer(0)]],
    const device {itype}* B [[buffer(1)]],
    device {itype}* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
    const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
    const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]);
)";
106 mlx/backend/metal/jit/steel_gemm.h Normal file
@@ -0,0 +1,106 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view steel_gemm_fused_kernels = R"(
template [[host_name("{name}")]]
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
    const device {itype} *A [[buffer(0)]],
    const device {itype} *B [[buffer(1)]],
    const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
    device {itype} *D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
    const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
    const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
    const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
    const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
    const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

constexpr std::string_view steel_gemm_masked_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
block_masked_gemm<
    {itype},
    {outmasktype},
    {opmasktype},
    {bm},
    {bn},
    {bk},
    {wm},
    {wn},
    {trans_a},
    {trans_b},
    {mn_aligned},
    {k_aligned}>(
    const device {itype}* A [[buffer(0)]],
    const device {itype}* B [[buffer(1)]],
    device {itype}* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const device {outmasktype}* out_mask [[buffer(10)]],
    const device {opmasktype}* lhs_mask [[buffer(11)]],
    const device {opmasktype}* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

constexpr std::string_view steel_gemm_splitk_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk<
    {itype},
    {otype},
    {bm},
    {bn},
    {bk},
    {wm},
    {wn},
    {trans_a},
    {trans_b},
    {mn_aligned},
    {k_aligned}>(
    const device {itype}* A [[buffer(0)]],
    const device {itype}* B [[buffer(1)]],
    device {otype}* C [[buffer(2)]],
    const constant GEMMSpiltKParams* params [[buffer(3)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum<{atype}, {otype}>(
    const device {atype}* C_split [[buffer(0)]],
    device {otype}* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]);
)";

constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum_axpby<{atype}, {otype}>(
    const device {atype}* C_split [[buffer(0)]],
    device {otype}* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    const device {otype}* C [[buffer(5)]],
    const constant int& ldc [[buffer(6)]],
    const constant int& fdc [[buffer(7)]],
    const constant float& alpha [[buffer(8)]],
    const constant float& beta [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]);
)";
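The steel templates switch from positional to named fmt fields ({name}, {itype}, {bm}, ...), which the call sites later in this diff fill with the `"name"_a = ...` literals from fmt::literals. A minimal sketch of that style, with invented field values:

    #include <fmt/format.h>
    using namespace fmt::literals;

    // Named fields keep long argument lists readable and order-independent.
    std::string s = fmt::format(
        "gemm<{itype}, {bm}, {bn}>",
        "itype"_a = "float",
        "bm"_a = 32,
        "bn"_a = 32);
    // s == "gemm<float, 32, 32>"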
80 mlx/backend/metal/jit/ternary.h Normal file
@@ -0,0 +1,80 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view ternary_kernels = R"(
template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    uint index [[thread_position_in_grid]]);

template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const size_t* c_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);

template [[host_name("g1_{0}")]] [[kernel]] void
ternary_g_nd1<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const size_t& a_strides,
    constant const size_t& b_strides,
    constant const size_t& c_strides,
    uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void
ternary_g_nd2<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    constant const size_t c_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void
ternary_g_nd3<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    constant const size_t c_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 4>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const int shape[4],
    constant const size_t a_strides[4],
    constant const size_t b_strides[4],
    constant const size_t c_strides[4],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 5>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const int shape[5],
    constant const size_t a_strides[5],
    constant const size_t b_strides[5],
    constant const size_t c_strides[5],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
)";
16 mlx/backend/metal/jit/unary.h Normal file
@@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view unary_kernels = R"(
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
    device const {1}* in,
    device {1}* out,
    uint index [[thread_position_in_grid]]);

template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>(
    device const {1}* in,
    device {1}* out,
    device const int* in_shape,
    device const size_t* in_strides,
    device const int& ndim,
    uint index [[thread_position_in_grid]]);
)";
486 mlx/backend/metal/jit_kernels.cpp Normal file
@@ -0,0 +1,486 @@
// Copyright © 2024 Apple Inc.
#include <fmt/format.h>

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/binary.h"
#include "mlx/backend/metal/jit/binary_two.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/sort.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/jit/ternary.h"
#include "mlx/backend/metal/jit/unary.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"

using namespace fmt::literals;

namespace mlx::core {

std::string op_name(const array& arr) {
  std::ostringstream op_t;
  arr.primitive().print(op_t);
  return op_t.str();
}

MTL::ComputePipelineState* get_arange_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source
        << metal::utils() << metal::arange()
        << fmt::format(arange_kernels, lib_name, get_type_string(out.dtype()));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_unary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
  std::string lib_name = kernel_name.substr(1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
                  << fmt::format(
                         unary_kernels,
                         lib_name,
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_binary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  std::string lib_name = kernel_name.substr(2);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::binary_ops() << metal::binary()
                  << fmt::format(
                         binary_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_binary_two_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  std::string lib_name = kernel_name.substr(2);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::binary_ops()
                  << metal::binary_two()
                  << fmt::format(
                         binary_two_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_ternary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary()
                  << fmt::format(
                         ternary_kernels,
                         lib_name,
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::copy()
                  << fmt::format(
                         copy_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_softmax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool precise,
    const array& out) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::softmax()
                  << fmt::format(
                         softmax_kernels,
                         lib_name,
                         get_type_string(out.dtype()),
                         get_type_string(precise ? float32 : out.dtype()));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_scan_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool reverse,
    bool inclusive,
    const array& in,
    const array& out) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::scan()
                  << fmt::format(
                         scan_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()),
                         op_name(out),
                         inclusive,
                         reverse);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    int bn,
    int tn) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::sort()
                  << fmt::format(
                         block_sort_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()),
                         bn,
                         tn);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_mb_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& idx,
    int bn,
    int tn) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::sort()
                  << fmt::format(
                         multiblock_sort_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(idx.dtype()),
                         bn,
                         tn);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_reduce_init_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
  auto lib = d.get_library(kernel_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::reduce_utils()
                  << fmt::format(
                         reduce_init_kernels,
                         kernel_name,
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(kernel_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_reduce_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
                  << fmt::format(
                         non_atomic ? reduce_non_atomic_kernels
                                    : reduce_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::gemm()
                  << metal::steel_gemm_fused()
                  << fmt::format(
                         steel_gemm_fused_kernels,
                         "name"_a = lib_name,
                         "itype"_a = get_type_string(out.dtype()),
                         "bm"_a = bm,
                         "bn"_a = bn,
                         "bk"_a = bk,
                         "wm"_a = wm,
                         "wn"_a = wn,
                         "trans_a"_a = transpose_a,
                         "trans_b"_a = transpose_b);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}

MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    bool mn_aligned,
    bool k_aligned) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::gemm()
                  << metal::steel_gemm_splitk()
                  << fmt::format(
                         steel_gemm_splitk_kernels,
                         "name"_a = lib_name,
                         "itype"_a = get_type_string(in.dtype()),
                         "otype"_a = get_type_string(out.dtype()),
                         "bm"_a = bm,
                         "bn"_a = bn,
                         "bk"_a = bk,
                         "wm"_a = wm,
                         "wn"_a = wn,
                         "trans_a"_a = transpose_a,
                         "trans_b"_a = transpose_b,
                         "mn_aligned"_a = mn_aligned,
                         "k_aligned"_a = k_aligned);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    bool axbpy) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::gemm()
                  << metal::steel_gemm_splitk()
                  << fmt::format(
                         axbpy ? steel_gemm_splitk_accum_axbpy_kernels
                               : steel_gemm_splitk_accum_kernels,
                         "name"_a = lib_name,
                         "atype"_a = get_type_string(in.dtype()),
                         "otype"_a = get_type_string(out.dtype()));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    const std::optional<array>& mask_out,
    const std::optional<array>& mask_op,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    bool mn_aligned,
    bool k_aligned) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    auto out_mask_type = mask_out.has_value()
        ? get_type_string((*mask_out).dtype())
        : "nomask_t";
    auto op_mask_type =
        mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
    kernel_source << metal::utils() << metal::gemm()
                  << metal::steel_gemm_masked()
                  << fmt::format(
                         steel_gemm_masked_kernels,
                         "name"_a = lib_name,
                         "itype"_a = get_type_string(out.dtype()),
                         "outmasktype"_a = out_mask_type,
                         "opmasktype"_a = op_mask_type,
                         "bm"_a = bm,
                         "bn"_a = bn,
                         "bk"_a = bk,
                         "wm"_a = wm,
                         "wn"_a = wn,
                         "trans_a"_a = transpose_a,
                         "trans_b"_a = transpose_b,
                         "mn_aligned"_a = mn_aligned,
                         "k_aligned"_a = k_aligned);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_steel_conv_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    int n_channel_specialization,
    bool small_filter) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
                  << fmt::format(
                         steel_conv_kernels,
                         "name"_a = lib_name,
                         "itype"_a = get_type_string(out.dtype()),
                         "bm"_a = bm,
                         "bn"_a = bn,
                         "bk"_a = bk,
                         "wm"_a = wm,
                         "wn"_a = wn,
                         "n_channels"_a = n_channel_specialization,
                         "small_filter"_a = small_filter);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_steel_conv_general_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::conv()
                  << metal::steel_conv_general()
                  << fmt::format(
                         steel_conv_general_kernels,
                         "name"_a = lib_name,
                         "itype"_a = get_type_string(out.dtype()),
                         "bm"_a = bm,
                         "bn"_a = bn,
                         "bk"_a = bk,
                         "wm"_a = wm,
                         "wn"_a = wn);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

} // namespace mlx::core
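All of the getters above share one shape: look the library up by name, JIT-compile the source only on a miss, then fetch the pipeline state from the (possibly cached) library, so each specialization is built at most once per process. A condensed sketch of that idiom, using only the `Device` methods this diff declares (the function name and `source` parameter are illustrative):

    // Sketch of the lookup-or-compile pattern used by every getter above.
    MTL::ComputePipelineState* get_kernel_cached(
        metal::Device& d,
        const std::string& kernel_name,
        const std::string& source) {
      const auto& lib_name = kernel_name;
      auto lib = d.get_library(lib_name);      // nullptr on first use
      if (lib == nullptr) {
        lib = d.get_library(lib_name, source); // compiles and caches
      }
      return d.get_kernel(kernel_name, lib);
    }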
156 mlx/backend/metal/kernels.h Normal file
@@ -0,0 +1,156 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
MTL::ComputePipelineState* get_arange_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_binary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out);

MTL::ComputePipelineState* get_binary_two_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out);

MTL::ComputePipelineState* get_ternary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out);

MTL::ComputePipelineState* get_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out);

MTL::ComputePipelineState* get_softmax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool precise,
    const array& out);

MTL::ComputePipelineState* get_scan_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool reverse,
    bool inclusive,
    const array& in,
    const array& out);

MTL::ComputePipelineState* get_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    int bn,
    int tn);

MTL::ComputePipelineState* get_mb_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& idx,
    int bn,
    int tn);

MTL::ComputePipelineState* get_reduce_init_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out);

MTL::ComputePipelineState* get_reduce_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out);

MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn);

MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    bool mn_aligned,
    bool k_aligned);

MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    bool axbpy);

MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    const std::optional<array>& mask_out,
    const std::optional<array>& mask_op,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    bool mn_aligned,
    bool k_aligned);

MTL::ComputePipelineState* get_steel_conv_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    int n_channel_specialization,
    bool small_filter);

MTL::ComputePipelineState* get_steel_conv_general_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn);

} // namespace mlx::core
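These getters give the JIT path one place to build, or fetch from a cache, the pipeline state for a kernel name. A minimal call-site sketch follows; only the get_binary_kernel signature comes from this header, the host name follows the "vv" #name #tname scheme in binary.metal further down, and the encoder helper is a hypothetical stand-in:

// Sketch: resolve and bind a contiguous binary add kernel (names assumed).
void dispatch_vv_add(metal::Device& d, const array& in, const array& out) {
  MTL::ComputePipelineState* kernel =
      get_binary_kernel(d, "vvaddfloat32", in, out);
  auto* enc = d.new_compute_encoder(); // placeholder, not the backend's real API
  enc->setComputePipelineState(kernel);
  // ... bind buffers for a, b, c and dispatch one thread per output element ...
}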
@@ -1,26 +1,17 @@
set(
  HEADERS
  ${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
  ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
  ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
  ${CMAKE_CURRENT_SOURCE_DIR}/binary.h
  ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
  ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
  ${CMAKE_CURRENT_SOURCE_DIR}/erf.h
  ${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
  ${CMAKE_CURRENT_SOURCE_DIR}/unary.h
  ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
  bf16.h
  bf16_math.h
  complex.h
  defines.h
  utils.h
  steel/conv/params.h
)

set(
  KERNELS
  "arange"
  "arg_reduce"
  "binary"
  "binary_two"
  "conv"
  "copy"
  "fft"
  "gemv"
  "quantized"
@@ -28,18 +19,48 @@ set(
  "rms_norm"
  "layer_norm"
  "rope"
  "scan"
  "scaled_dot_product_attention"
  "softmax"
  "sort"
  "ternary"
  "unary"
  "gather"
  "scatter"
)

if (NOT MLX_METAL_JIT)
  set(
    KERNELS
    ${KERNELS}
    "arange"
    "binary"
    "binary_two"
    "unary"
    "ternary"
    "copy"
    "softmax"
    "sort"
    "scan"
    "reduce"
  )
  set(
    HEADERS
    ${HEADERS}
    atomic.h
    arange.h
    unary_ops.h
    unary.h
    binary_ops.h
    binary.h
    ternary.h
    copy.h
    softmax.h
    sort.h
    scan.h
    reduction/ops.h
    reduction/reduce_init.h
    reduction/reduce_all.h
    reduction/reduce_col.h
    reduction/reduce_row.h
  )
endif()

function(build_kernel_base TARGET SRCFILE DEPS)
  set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
  set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
  if(MLX_METAL_DEBUG)
    set(METAL_FLAGS ${METAL_FLAGS}
      -gline-tables-only
@@ -68,23 +89,40 @@ foreach(KERNEL ${KERNELS})
  set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
endforeach()

file(GLOB_RECURSE STEEL_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.metal)
file(GLOB_RECURSE STEEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.h)

foreach(KERNEL ${STEEL_KERNELS})
  cmake_path(GET KERNEL STEM TARGET)
  build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
  set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()

file(GLOB_RECURSE REDUCE_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.metal)
file(GLOB_RECURSE REDUCE_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.h)

foreach(KERNEL ${REDUCE_KERNELS})
  cmake_path(GET KERNEL STEM TARGET)
  build_kernel_base(${TARGET} ${KERNEL} "${REDUCE_HEADERS}")
  set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
if (NOT MLX_METAL_JIT)
  set(
    STEEL_KERNELS
    ${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv.metal
    ${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv_general.metal
    ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_fused.metal
    ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_masked.metal
    ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_splitk.metal
  )
  set(
    STEEL_HEADERS
    steel/defines.h
    steel/utils.h
    steel/conv/conv.h
    steel/conv/loader.h
    steel/conv/loaders/loader_channel_l.h
    steel/conv/loaders/loader_channel_n.h
    steel/conv/loaders/loader_general.h
    steel/conv/kernels/steel_conv.h
    steel/conv/kernels/steel_conv_general.h
    steel/gemm/gemm.h
    steel/gemm/mma.h
    steel/gemm/loader.h
    steel/gemm/transforms.h
    steel/gemm/kernels/steel_gemm_fused.h
    steel/gemm/kernels/steel_gemm_masked.h
    steel/gemm/kernels/steel_gemm_splitk.h
  )
  foreach(KERNEL ${STEEL_KERNELS})
    cmake_path(GET KERNEL STEM TARGET)
    build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
    set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
  endforeach()
endif()

add_custom_command(
  OUTPUT ${MLX_METAL_PATH}/mlx.metallib
mlx/backend/metal/kernels/arange.h (new file, 9 lines)
@@ -0,0 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
template <typename T>
[[kernel]] void arange(
    constant const T& start,
    constant const T& step,
    device T* out,
    uint index [[thread_position_in_grid]]) {
  out[index] = start + index * step;
}
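The kernel is an embarrassingly parallel ramp fill. For reference, a host-side C++ analogue of what the grid computes, one loop iteration per GPU thread:

// CPU reference for the arange kernel above: thread i writes start + i * step.
void arange_ref(float start, float step, float* out, int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = start + i * step; // start=1, step=0.5 -> 1.0, 1.5, 2.0, ...
  }
}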
@@ -1,15 +1,8 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"

template <typename T>
[[kernel]] void arange(
    constant const T& start,
    constant const T& step,
    device T* out,
    uint index [[thread_position_in_grid]]) {
  out[index] = start + index * step;
}
#include "mlx/backend/metal/kernels/arange.h"

#define instantiate_arange(tname, type) \
  template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
@@ -18,15 +11,14 @@ template <typename T>
      device type* out, \
      uint index [[thread_position_in_grid]]);

// clang-format off
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)
instantiate_arange(uint64, uint64_t)
instantiate_arange(int8, int8_t)
instantiate_arange(int16, int16_t)
instantiate_arange(int32, int32_t)
instantiate_arange(int64, int64_t)
instantiate_arange(float16, half)
instantiate_arange(float32, float)
instantiate_arange(bfloat16, bfloat16_t) // clang-format on
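After this change the kernel body comes from the shared header and only the instantiations live in the .metal file. As a sketch, instantiate_arange(float32, float) expands to roughly the following (formatting aside; the start/step parameters are filled in from the kernel signature in arange.h):

template [[host_name("arangefloat32")]] [[kernel]] void arange<float>(
    constant const float& start,
    constant const float& step,
    device float* out,
    uint index [[thread_position_in_grid]]);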
@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.

#include <metal_atomic>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/utils.h"
@@ -194,4 +193,4 @@ instantiate_arg_reduce(int32, int32_t)
instantiate_arg_reduce(int64, int64_t)
instantiate_arg_reduce(float16, half)
instantiate_arg_reduce(float32, float)
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on

@@ -4,7 +4,6 @@

#include <metal_atomic>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"

using namespace metal;

@@ -6,7 +6,7 @@

using namespace metal;

#if defined(__HAVE_BFLOAT__)
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)

typedef bfloat bfloat16_t;

@@ -312,6 +312,6 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {

#pragma METAL internals : disable

#endif // defined(__HAVE_BFLOAT__)
#endif

#include "mlx/backend/metal/kernels/bf16_math.h"
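The guard change above keys bfloat support off the Metal language version instead of the __HAVE_BFLOAT__ capability macro; METAL_3_1 is supplied by the new -D${MLX_METAL_VERSION} compile flag in the CMake section. A sketch of the resulting selection (the emulation-side typedef is an assumption, inferred from the _MLX_BFloat16 type this header defines):

#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;        // native hardware bfloat
#else
typedef _MLX_BFloat16 bfloat16_t; // software emulation (assumed fallback)
#endif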
@@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
    return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
  }

#if defined(__HAVE_BFLOAT__)
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)

#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
@@ -391,4 +391,4 @@ instantiate_metal_simd_comm_funcs(
    uint16_to_bfloat16);
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);

} // namespace metal
@@ -1,266 +1,113 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2024 Apple Inc.

#pragma once
template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[0], b[0]);
}

#include <metal_integer>
#include <metal_math>
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[0], b[index]);
}

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[index], b[0]);
}

struct Add {
  template <typename T>
  T operator()(T x, T y) {
    return x + y;
  }
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[index], b[index]);
}

struct Divide {
  template <typename T>
  T operator()(T x, T y) {
    return x / y;
  }
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  auto a_idx = elem_to_loc_1(index, a_stride);
  auto b_idx = elem_to_loc_1(index, b_stride);
  c[index] = Op()(a[a_idx], b[b_idx]);
}

struct Remainder {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    return x % y;
  }
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    auto r = x % y;
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    T r = fmod(x, y);
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    return x % y;
  }
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_2(index, a_strides);
  auto b_idx = elem_to_loc_2(index, b_strides);
  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
  c[out_idx] = Op()(a[a_idx], b[b_idx]);
}

struct Equal {
  template <typename T>
  bool operator()(T x, T y) {
    return x == y;
  }
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_3(index, a_strides);
  auto b_idx = elem_to_loc_3(index, b_strides);
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  c[out_idx] = Op()(a[a_idx], b[b_idx]);
}

struct NaNEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x == y || (metal::isnan(x) && metal::isnan(y));
  }
  template <>
  bool operator()(complex64_t x, complex64_t y) {
    return x == y ||
        (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
         metal::isnan(y.imag)) ||
        (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
        (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
  }
};
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  c[out_idx] = Op()(a[idx.x], b[idx.y]);
}

struct Greater {
  template <typename T>
  bool operator()(T x, T y) {
    return x > y;
  }
};

struct GreaterEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x >= y;
  }
};

struct Less {
  template <typename T>
  bool operator()(T x, T y) {
    return x < y;
  }
};

struct LessEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x <= y;
  }
};

struct LogAddExp {
  template <typename T>
  T operator()(T x, T y) {
    if (metal::isnan(x) || metal::isnan(y)) {
      return metal::numeric_limits<T>::quiet_NaN();
    }
    constexpr T inf = metal::numeric_limits<T>::infinity();
    T maxval = metal::max(x, y);
    T minval = metal::min(x, y);
    return (minval == -inf || maxval == inf)
        ? maxval
        : (maxval + log1p(metal::exp(minval - maxval)));
  };
};

struct Maximum {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::max(x, y);
  }

  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x > y ? x : y;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x > y ? x : y;
  }
};

struct Minimum {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::min(x, y);
  }

  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x < y ? x : y;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x < y ? x : y;
  }
};

struct Multiply {
  template <typename T>
  T operator()(T x, T y) {
    return x * y;
  }
};

struct NotEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x != y;
  }
  template <>
  bool operator()(complex64_t x, complex64_t y) {
    return x.real != y.real || x.imag != y.imag;
  }
};

struct Power {
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
    return metal::pow(base, exp);
  }

  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
    T res = 1;
    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      base *= base;
    }
    return res;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    auto x_theta = metal::atan(x.imag / x.real);
    auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
    auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
    auto phase = y.imag * x_ln_r + y.real * x_theta;
    return {mag * metal::cos(phase), mag * metal::sin(phase)};
  }
};

struct Subtract {
  template <typename T>
  T operator()(T x, T y) {
    return x - y;
  }
};

struct LogicalAnd {
  template <typename T>
  T operator()(T x, T y) {
    return x && y;
  };
};

struct LogicalOr {
  template <typename T>
  T operator()(T x, T y) {
    return x || y;
  };
};

struct BitwiseAnd {
  template <typename T>
  T operator()(T x, T y) {
    return x & y;
  };
};

struct BitwiseOr {
  template <typename T>
  T operator()(T x, T y) {
    return x | y;
  };
};

struct BitwiseXor {
  template <typename T>
  T operator()(T x, T y) {
    return x ^ y;
  };
};

struct LeftShift {
  template <typename T>
  T operator()(T x, T y) {
    return x << y;
  };
};

struct RightShift {
  template <typename T>
  T operator()(T x, T y) {
    return x >> y;
  };
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
  c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
@@ -1,130 +1,24 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2024 Apple Inc.

#include <metal_integer>
#include <metal_math>

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_ss(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[0], b[0]);
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_sv(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[0], b[index]);
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_vs(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[index], b[0]);
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_vv(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[index], b[index]);
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  auto a_idx = elem_to_loc_1(index, a_stride);
  auto b_idx = elem_to_loc_1(index, b_stride);
  c[index] = Op()(a[a_idx], b[b_idx]);
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_2(index, a_strides);
  auto b_idx = elem_to_loc_2(index, b_strides);
  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
  c[out_idx] = Op()(a[a_idx], b[b_idx]);
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_3(index, a_strides);
  auto b_idx = elem_to_loc_3(index, b_strides);
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  c[out_idx] = Op()(a[a_idx], b[b_idx]);
}

template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_op_g_nd(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  c[out_idx] = Op()(a[idx.x], b[idx.y]);
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
  c[out_idx] = Op()(a[idx.x], b[idx.y]);
}

#define instantiate_binary(name, itype, otype, op, bopt) \
  template \
  [[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
  [[host_name(name)]] [[kernel]] void binary_##bopt<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      uint index [[thread_position_in_grid]]);

#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
  template [[host_name(name "_" #dims)]] [[kernel]] void \
  binary_op_g_nd<itype, otype, op, dims>( \
  template [[host_name("g" #dims name)]] [[kernel]] void \
  binary_g_nd<itype, otype, op, dims>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -135,16 +29,16 @@ template <typename T, typename U, typename Op>
      uint3 grid_dim [[threads_per_grid]]);

#define instantiate_binary_g_nd(name, itype, otype, op) \
  template [[host_name(name "_1")]] [[kernel]] void \
  binary_op_g_nd1<itype, otype, op>( \
  template [[host_name("g1" name)]] [[kernel]] void \
  binary_g_nd1<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      constant const size_t& a_stride, \
      constant const size_t& b_stride, \
      uint index [[thread_position_in_grid]]); \
  template [[host_name(name "_2")]] [[kernel]] void \
  binary_op_g_nd2<itype, otype, op>( \
  template [[host_name("g2" name)]] [[kernel]] void \
  binary_g_nd2<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -152,8 +46,8 @@ template <typename T, typename U, typename Op>
      constant const size_t b_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name(name "_3")]] [[kernel]] void \
  binary_op_g_nd3<itype, otype, op>( \
  template [[host_name("g3" name)]] [[kernel]] void \
  binary_g_nd3<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -162,30 +56,28 @@ template <typename T, typename U, typename Op>
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  instantiate_binary_g_dim(name, itype, otype, op, 4) \
  instantiate_binary_g_dim(name, itype, otype, op, 5)

#define instantiate_binary_g(name, itype, otype, op) \
  template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      constant const int* shape, \
      constant const size_t* a_strides, \
      constant const size_t* b_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]], \
#define instantiate_binary_g(name, itype, otype, op) \
  template [[host_name("gn" name)]] [[kernel]] void binary_g<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      constant const int* shape, \
      constant const size_t* a_strides, \
      constant const size_t* b_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);

// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op) \
  instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
  instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
  instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
  instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
  instantiate_binary_g("g" #name #tname, itype, otype, op) \
  instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
  instantiate_binary_g(#name #tname, itype, otype, op) \
  instantiate_binary_g_nd(#name #tname, itype, otype, op)

// clang-format off
#define instantiate_binary_integer(name, op) \
  instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
  instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
@@ -194,22 +86,19 @@ template <typename T, typename U, typename Op>
  instantiate_binary_all(name, int8, int8_t, int8_t, op) \
  instantiate_binary_all(name, int16, int16_t, int16_t, op) \
  instantiate_binary_all(name, int32, int32_t, int32_t, op) \
  instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
  instantiate_binary_all(name, int64, int64_t, int64_t, op)

// clang-format off
#define instantiate_binary_float(name, op) \
  instantiate_binary_all(name, float16, half, half, op) \
  instantiate_binary_all(name, float32, float, float, op) \
  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)

// clang-format off
#define instantiate_binary_types(name, op) \
  instantiate_binary_all(name, bool_, bool, bool, op) \
  instantiate_binary_integer(name, op) \
  instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
  instantiate_binary_float(name, op) // clang-format on
  instantiate_binary_float(name, op)

// clang-format off
#define instantiate_binary_types_bool(name, op) \
  instantiate_binary_all(name, bool_, bool, bool, op) \
  instantiate_binary_all(name, uint8, uint8_t, bool, op) \
@@ -223,9 +112,8 @@ template <typename T, typename U, typename Op>
  instantiate_binary_all(name, float16, half, bool, op) \
  instantiate_binary_all(name, float32, float, bool, op) \
  instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
  instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
  instantiate_binary_all(name, complex64, complex64_t, bool, op)

// clang-format off
instantiate_binary_types(add, Add)
instantiate_binary_types(div, Divide)
instantiate_binary_types_bool(eq, Equal)
@@ -241,6 +129,7 @@ instantiate_binary_types(mul, Multiply)
instantiate_binary_types(sub, Subtract)
instantiate_binary_types(pow, Power)
instantiate_binary_types(rem, Remainder)
instantiate_binary_float(arctan2, ArcTan2)

// NaNEqual only needed for floating point types with boolean output
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
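The hunks above move the rank prefix to the front of the host name. For one op/type pair, the macros now generate the following host-visible kernel names, taking add/float32 as the example:

//   "ssaddfloat32", "svaddfloat32", "vsaddfloat32", "vvaddfloat32"
//   "g1addfloat32", "g2addfloat32", "g3addfloat32"
//   "g4addfloat32", "g5addfloat32"  // fixed-rank generic kernels
//   "gnaddfloat32"                  // rank-agnostic fallback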
mlx/backend/metal/kernels/binary_ops.h (new file, 296 lines)
@@ -0,0 +1,296 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <metal_integer>
#include <metal_math>

struct Add {
  template <typename T>
  T operator()(T x, T y) {
    return x + y;
  }
};

struct FloorDivide {
  template <typename T>
  T operator()(T x, T y) {
    return x / y;
  }
  template <>
  float operator()(float x, float y) {
    return trunc(x / y);
  }
  template <>
  half operator()(half x, half y) {
    return trunc(x / y);
  }
  template <>
  bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
    return trunc(x / y);
  }
};

struct Divide {
  template <typename T>
  T operator()(T x, T y) {
    return x / y;
  }
};

struct Remainder {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    return x % y;
  }
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    auto r = x % y;
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    T r = fmod(x, y);
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    return x % y;
  }
};

struct Equal {
  template <typename T>
  bool operator()(T x, T y) {
    return x == y;
  }
};

struct NaNEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x == y || (metal::isnan(x) && metal::isnan(y));
  }
  template <>
  bool operator()(complex64_t x, complex64_t y) {
    return x == y ||
        (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
         metal::isnan(y.imag)) ||
        (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
        (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
  }
};

struct Greater {
  template <typename T>
  bool operator()(T x, T y) {
    return x > y;
  }
};

struct GreaterEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x >= y;
  }
};

struct Less {
  template <typename T>
  bool operator()(T x, T y) {
    return x < y;
  }
};

struct LessEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x <= y;
  }
};

struct LogAddExp {
  template <typename T>
  T operator()(T x, T y) {
    if (metal::isnan(x) || metal::isnan(y)) {
      return metal::numeric_limits<T>::quiet_NaN();
    }
    constexpr T inf = metal::numeric_limits<T>::infinity();
    T maxval = metal::max(x, y);
    T minval = metal::min(x, y);
    return (minval == -inf || maxval == inf)
        ? maxval
        : (maxval + log1p(metal::exp(minval - maxval)));
  };
};

struct Maximum {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::max(x, y);
  }

  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x > y ? x : y;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x > y ? x : y;
  }
};

struct Minimum {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::min(x, y);
  }

  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x < y ? x : y;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x < y ? x : y;
  }
};

struct Multiply {
  template <typename T>
  T operator()(T x, T y) {
    return x * y;
  }
};

struct NotEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x != y;
  }
  template <>
  bool operator()(complex64_t x, complex64_t y) {
    return x.real != y.real || x.imag != y.imag;
  }
};

struct Power {
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
    return metal::pow(base, exp);
  }

  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
    T res = 1;
    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      base *= base;
    }
    return res;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    auto x_theta = metal::atan(x.imag / x.real);
    auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
    auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
    auto phase = y.imag * x_ln_r + y.real * x_theta;
    return {mag * metal::cos(phase), mag * metal::sin(phase)};
  }
};

struct Subtract {
  template <typename T>
  T operator()(T x, T y) {
    return x - y;
  }
};

struct LogicalAnd {
  template <typename T>
  T operator()(T x, T y) {
    return x && y;
  };
};

struct LogicalOr {
  template <typename T>
  T operator()(T x, T y) {
    return x || y;
  };
};

struct BitwiseAnd {
  template <typename T>
  T operator()(T x, T y) {
    return x & y;
  };
};

struct BitwiseOr {
  template <typename T>
  T operator()(T x, T y) {
    return x | y;
  };
};

struct BitwiseXor {
  template <typename T>
  T operator()(T x, T y) {
    return x ^ y;
  };
};

struct LeftShift {
  template <typename T>
  T operator()(T x, T y) {
    return x << y;
  };
};

struct RightShift {
  template <typename T>
  T operator()(T x, T y) {
    return x >> y;
  };
};

struct ArcTan2 {
  template <typename T>
  T operator()(T y, T x) {
    return metal::precise::atan2(y, x);
  }
};

struct DivMod {
  template <typename T>
  metal::array<T, 2> operator()(T x, T y) {
    return {FloorDivide{}(x, y), Remainder{}(x, y)};
  };
};
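For reference, the complex Power specialization above is the usual polar-form identity. Writing $x = r e^{i\theta}$ (the code's x_ln_r $= \ln r$ and x_theta $= \theta$, computed with metal::atan, hence valid up to the branch of $\theta$), it evaluates

$$x^y = e^{y \log x} = e^{y_r \ln r - y_i \theta}\left(\cos(y_i \ln r + y_r \theta) + i \sin(y_i \ln r + y_r \theta)\right)$$

which matches mag and phase in the code.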
mlx/backend/metal/kernels/binary_two.h (new file, 140 lines)
@@ -0,0 +1,140 @@
// Copyright © 2024 Apple Inc.

template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto out = Op()(a[0], b[0]);
  c[index] = out[0];
  d[index] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto out = Op()(a[0], b[index]);
  c[index] = out[0];
  d[index] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto out = Op()(a[index], b[0]);
  c[index] = out[0];
  d[index] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto out = Op()(a[index], b[index]);
  c[index] = out[0];
  d[index] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  auto a_idx = elem_to_loc_1(index, a_stride);
  auto b_idx = elem_to_loc_1(index, b_stride);
  auto out = Op()(a[a_idx], b[b_idx]);
  c[index] = out[0];
  d[index] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_2(index, a_strides);
  auto b_idx = elem_to_loc_2(index, b_strides);
  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
  auto out = Op()(a[a_idx], b[b_idx]);
  c[out_idx] = out[0];
  d[out_idx] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_3(index, a_strides);
  auto b_idx = elem_to_loc_3(index, b_strides);
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  auto out = Op()(a[a_idx], b[b_idx]);
  c[out_idx] = out[0];
  d[out_idx] = out[1];
}

template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  auto out = Op()(a[idx.x], b[idx.y]);
  c[out_idx] = out[0];
  d[out_idx] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
  auto out = Op()(a[idx.x], b[idx.y]);
  c[out_idx] = out[0];
  d[out_idx] = out[1];
}
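With Op = DivMod from binary_ops.h, each thread now produces both outputs from a single evaluation instead of the old Op1/Op2 pair. A small worked example of the scattering:

// One thread of binary_vv<int, int, DivMod>, spelled out:
//   auto out = DivMod{}(7, 3);  // out[0] == 7 / 3 == 2 (FloorDivide)
//                               // out[1] == 7 % 3 == 1 (Remainder)
//   c[index] = out[0];          // quotient buffer
//   d[index] = out[1];          // remainder buffer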
@@ -1,212 +1,24 @@
// Copyright © 2023 Apple Inc.

// Copyright © 2024 Apple Inc.
#include <metal_integer>
#include <metal_math>

#include "mlx/backend/metal/kernels/bf16.h"
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"

struct FloorDivide {
  template <typename T>
  T operator()(T x, T y) {
    return x / y;
  }
  template <>
  float operator()(float x, float y) {
    return trunc(x / y);
  }
  template <>
  half operator()(half x, half y) {
    return trunc(x / y);
  }
  template <>
  bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
    return trunc(x / y);
  }
};

struct Remainder {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    return x % y;
  }
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    auto r = x % y;
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    T r = fmod(x, y);
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    return x % y;
  }
};

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_s2s(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op1()(a[0], b[0]);
  d[index] = Op2()(a[0], b[0]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_ss(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op1()(a[0], b[0]);
  d[index] = Op2()(a[0], b[0]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_sv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op1()(a[0], b[index]);
  d[index] = Op2()(a[0], b[index]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vs(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op1()(a[index], b[0]);
  d[index] = Op2()(a[index], b[0]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op1()(a[index], b[index]);
  d[index] = Op2()(a[index], b[index]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  auto a_idx = elem_to_loc_1(index, a_stride);
  auto b_idx = elem_to_loc_1(index, b_stride);
  c[index] = Op1()(a[a_idx], b[b_idx]);
  d[index] = Op2()(a[a_idx], b[b_idx]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_2(index, a_strides);
  auto b_idx = elem_to_loc_2(index, b_strides);
  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
  c[out_idx] = Op1()(a[a_idx], b[b_idx]);
  d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_3(index, a_strides);
  auto b_idx = elem_to_loc_3(index, b_strides);
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  c[out_idx] = Op1()(a[a_idx], b[b_idx]);
  d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}

template <typename T, typename U, typename Op1, typename Op2, int DIM>
[[kernel]] void binary_op_g_nd(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  c[out_idx] = Op1()(a[idx.x], b[idx.y]);
  d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
  c[out_idx] = Op1()(a[idx.x], b[idx.y]);
  d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}

#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
#define instantiate_binary(name, itype, otype, op, bopt) \
  template [[host_name(name)]] [[kernel]] void \
  binary_op_##bopt<itype, otype, op1, op2>( \
  binary_##bopt<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      uint index [[thread_position_in_grid]]);

#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
  template [[host_name(name "_" #dims)]] [[kernel]] void \
  binary_op_g_nd<itype, otype, op1, op2, dims>( \
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
  template [[host_name("g" #dims name)]] [[kernel]] void \
  binary_g_nd<itype, otype, op, dims>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -217,10 +29,9 @@ template <typename T, typename U, typename Op1, typename Op2>
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);

// clang-format off
#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
  template [[host_name(name "_1")]] [[kernel]] void \
  binary_op_g_nd1<itype, otype, op1, op2>( \
#define instantiate_binary_g_nd(name, itype, otype, op) \
  template [[host_name("g1" name)]] [[kernel]] void \
  binary_g_nd1<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -228,8 +39,8 @@ template <typename T, typename U, typename Op1, typename Op2>
      constant const size_t& a_stride, \
      constant const size_t& b_stride, \
      uint index [[thread_position_in_grid]]); \
  template [[host_name(name "_2")]] [[kernel]] void \
  binary_op_g_nd2<itype, otype, op1, op2>( \
  template [[host_name("g2" name)]] [[kernel]] void \
  binary_g_nd2<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -238,8 +49,8 @@ template <typename T, typename U, typename Op1, typename Op2>
      constant const size_t b_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name(name "_3")]] [[kernel]] void \
  binary_op_g_nd3<itype, otype, op1, op2>( \
  template [[host_name("g3" name)]] [[kernel]] void \
  binary_g_nd3<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -248,12 +59,12 @@ template <typename T, typename U, typename Op1, typename Op2>
      constant const size_t b_strides[3], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
  instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
  instantiate_binary_g_dim(name, itype, otype, op, 4) \
  instantiate_binary_g_dim(name, itype, otype, op, 5)

#define instantiate_binary_g(name, itype, otype, op1, op2) \
  template [[host_name(name)]] [[kernel]] void \
  binary_op_g<itype, otype, op2, op2>( \
#define instantiate_binary_g(name, itype, otype, op) \
  template [[host_name("gn" name)]] [[kernel]] void \
  binary_g<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -265,33 +76,30 @@ template <typename T, typename U, typename Op1, typename Op2>
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);

// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
  instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
  instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
  instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
  instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
  instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
  instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
#define instantiate_binary_all(name, tname, itype, otype, op) \
  instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
  instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
  instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
  instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
  instantiate_binary_g(#name #tname, itype, otype, op) \
  instantiate_binary_g_nd(#name #tname, itype, otype, op)

// clang-format off
#define instantiate_binary_float(name, op1, op2) \
  instantiate_binary_all(name, float16, half, half, op1, op2) \
  instantiate_binary_all(name, float32, float, float, op1, op2) \
  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
#define instantiate_binary_float(name, op) \
  instantiate_binary_all(name, float16, half, half, op) \
  instantiate_binary_all(name, float32, float, float, op) \
  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)

// clang-format off
#define instantiate_binary_types(name, op1, op2) \
  instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
  instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
  instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
  instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
  instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
  instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
  instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
  instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
  instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
  instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
  instantiate_binary_float(name, op1, op2)
#define instantiate_binary_types(name, op) \
  instantiate_binary_all(name, bool_, bool, bool, op) \
  instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
  instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
  instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
  instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
  instantiate_binary_all(name, int8, int8_t, int8_t, op) \
  instantiate_binary_all(name, int16, int16_t, int16_t, op) \
  instantiate_binary_all(name, int32, int32_t, int32_t, op) \
  instantiate_binary_all(name, int64, int64_t, int64_t, op) \
  instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
  instantiate_binary_float(name, op)

instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on
instantiate_binary_types(divmod, DivMod) // clang-format on
@@ -1,7 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/binary.h"
#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/unary.h"

typedef half float16_t;
@@ -109,6 +109,7 @@ template <typename T, int N>
  bool valid = n < params->N;

  // Unroll dimensions
  int kernel_stride = 1;
  for (int i = N - 1; i >= 0; --i) {
    int os_ = (oS % params->oS[i]);
    int ws_ = (wS % params->wS[i]);
@@ -125,7 +126,8 @@ template <typename T, int N>
    oS /= params->oS[i];
    wS /= params->wS[i];

    out += ws_ * params->str[i];
    out += ws_ * kernel_stride;
    kernel_stride *= params->wS[i];
  }

  if (valid) {
@@ -648,4 +650,4 @@ winograd_conv_2d_output_transform(

// clang-format off
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half); // clang-format on
144 mlx/backend/metal/kernels/copy.h Normal file
@@ -0,0 +1,144 @@
// Copyright © 2024 Apple Inc.

template <typename T, typename U>
[[kernel]] void copy_s(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[0]);
}

template <typename T, typename U>
[[kernel]] void copy_v(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[index]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    uint index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_1(index, src_stride);
  dst[index] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_2(index, src_strides);
  int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_3(index, src_strides);
  int64_t dst_idx =
      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  int64_t dst_idx =
      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
  int64_t dst_idx =
      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    uint index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_1(index, src_stride);
  auto dst_idx = elem_to_loc_1(index, dst_stride);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint2 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_2(index, src_strides);
  auto dst_idx = elem_to_loc_2(index, dst_strides);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_3(index, src_strides);
  auto dst_idx = elem_to_loc_3(index, dst_strides);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
  auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}
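All of the _g/_gg variants above funnel through the same idea: decompose a linear (or 2-D/3-D) thread position into per-dimension coordinates and dot them with the strides. A short sketch of the computation the elem_to_loc family is assumed to perform; the real implementations live in kernels/utils.h:

// Assumed semantics of the general elem_to_loc: peel off coordinates from the
// fastest-moving dimension outward, scaling each by its stride.
template <typename StrideT>
METAL_FUNC size_t elem_to_loc_sketch(
    uint elem,
    constant const int* shape,
    constant const StrideT* strides,
    int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i]; // coordinate along dim i times its stride
    elem /= shape[i];                      // carry into the next slower dimension
  }
  return loc;
}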
@@ -1,150 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/kernels/bf16.h"
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"

template <typename T, typename U>
[[kernel]] void copy_s(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[0]);
}

template <typename T, typename U>
[[kernel]] void copy_v(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[index]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    uint index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_1(index, src_stride);
  dst[index] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_2(index, src_strides);
  int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_3(index, src_strides);
  int64_t dst_idx =
      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  int64_t dst_idx =
      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
  int64_t dst_idx =
      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    uint index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_1(index, src_stride);
  auto dst_idx = elem_to_loc_1(index, dst_stride);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint2 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_2(index, src_strides);
  auto dst_idx = elem_to_loc_2(index, dst_strides);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_3(index, src_strides);
  auto dst_idx = elem_to_loc_3(index, dst_strides);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
  auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/copy.h"

#define instantiate_copy(name, itype, otype, ctype) \
  template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
@@ -152,92 +11,90 @@ template <typename T, typename U>
      device otype* dst [[buffer(1)]], \
      uint index [[thread_position_in_grid]]);

#define instantiate_copy_g_dim(name, itype, otype, dims) \
  template [[host_name(name "_" #dims)]] [[kernel]] void \
  copy_g_nd<itype, otype, dims>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int* src_shape [[buffer(2)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name "_" #dims)]] [[kernel]] void \
  copy_gg_nd<itype, otype, dims>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int* src_shape [[buffer(2)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int64_t* dst_strides [[buffer(4)]], \
#define instantiate_copy_g_dim(name, itype, otype, dims) \
  template [[host_name("g" #dims "_" name)]] [[kernel]] void \
  copy_g_nd<itype, otype, dims>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int* src_shape [[buffer(2)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("gg" #dims "_" name)]] [[kernel]] void \
  copy_gg_nd<itype, otype, dims>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int* src_shape [[buffer(2)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int64_t* dst_strides [[buffer(4)]], \
      uint3 index [[thread_position_in_grid]]);

#define instantiate_copy_g_nd(name, itype, otype) \
  template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t& src_stride [[buffer(3)]], \
      uint index [[thread_position_in_grid]]); \
  template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name "_1")]] [[kernel]] void \
  copy_gg_nd1<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t& src_stride [[buffer(3)]], \
      constant const int64_t& dst_stride [[buffer(4)]], \
      uint index [[thread_position_in_grid]]); \
  template [[host_name("g" name "_2")]] [[kernel]] void \
  copy_gg_nd2<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int64_t* dst_strides [[buffer(4)]], \
      uint2 index [[thread_position_in_grid]]); \
  template [[host_name("g" name "_3")]] [[kernel]] void \
  copy_gg_nd3<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int64_t* dst_strides [[buffer(4)]], \
      uint3 index [[thread_position_in_grid]]); \
  instantiate_copy_g_dim(name, itype, otype, 4) \
  instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g_nd(name, itype, otype) \
  template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t& src_stride [[buffer(3)]], \
      uint index [[thread_position_in_grid]]); \
  template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("gg1_" name )]] [[kernel]] void \
  copy_gg_nd1<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t& src_stride [[buffer(3)]], \
      constant const int64_t& dst_stride [[buffer(4)]], \
      uint index [[thread_position_in_grid]]); \
  template [[host_name("gg2_" name)]] [[kernel]] void \
  copy_gg_nd2<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int64_t* dst_strides [[buffer(4)]], \
      uint2 index [[thread_position_in_grid]]); \
  template [[host_name("gg3_" name)]] [[kernel]] void \
  copy_gg_nd3<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int64_t* dst_strides [[buffer(4)]], \
      uint3 index [[thread_position_in_grid]]); \
  instantiate_copy_g_dim(name, itype, otype, 4) \
  instantiate_copy_g_dim(name, itype, otype, 5)

#define instantiate_copy_g(name, itype, otype) \
  template [[host_name(name)]] [[kernel]] void copy_g<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int* src_shape [[buffer(2)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int& ndim [[buffer(5)]], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name)]] [[kernel]] void copy_gg<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int* src_shape [[buffer(2)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int64_t* dst_strides [[buffer(4)]], \
      constant const int& ndim [[buffer(5)]], \
#define instantiate_copy_g(name, itype, otype) \
  template [[host_name("g_" name)]] [[kernel]] void copy_g<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int* src_shape [[buffer(2)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int& ndim [[buffer(5)]], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("gg_" name)]] [[kernel]] void copy_gg<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int* src_shape [[buffer(2)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int64_t* dst_strides [[buffer(4)]], \
      constant const int& ndim [[buffer(5)]], \
      uint3 index [[thread_position_in_grid]]);

// clang-format off
#define instantiate_copy_all(tname, itype, otype) \
  instantiate_copy("scopy" #tname, itype, otype, s) \
  instantiate_copy("vcopy" #tname, itype, otype, v) \
  instantiate_copy_g("gcopy" #tname, itype, otype) \
  instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on
#define instantiate_copy_all(tname, itype, otype) \
  instantiate_copy("s_copy" #tname, itype, otype, s) \
  instantiate_copy("v_copy" #tname, itype, otype, v) \
  instantiate_copy_g("copy" #tname, itype, otype) \
  instantiate_copy_g_nd("copy" #tname, itype, otype)

// clang-format off
#define instantiate_copy_itype(itname, itype) \
  instantiate_copy_all(itname ##bool_, itype, bool) \
  instantiate_copy_all(itname ##uint8, itype, uint8_t) \
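The net effect of this hunk is a naming flip: dimension tags move from suffix to prefix (gcopyfloat32float32_2 becomes g2_copyfloat32float32) and the two-strided variants get a distinct gg prefix instead of overloading a single leading g. A hypothetical sketch of how a host dispatcher could assemble the new names; the CopyType cases and the "copy" base tag here are assumptions, not the actual dispatch code:

// Sketch: the variant prefix is chosen first, then the dtype-tagged base name.
std::string prefix;
switch (ctype) {
  case CopyType::Scalar:  prefix = "s_"; break;
  case CopyType::Vector:  prefix = "v_"; break;
  case CopyType::General: // g1_ .. g5_ when specialized, g_ otherwise
    prefix = ndim <= 5 ? "g" + std::to_string(ndim) + "_" : "g_";
    break;
  case CopyType::GeneralGeneral:
    prefix = ndim <= 5 ? "gg" + std::to_string(ndim) + "_" : "gg_";
    break;
}
std::string kname = prefix + "copy" + type_to_name(in) + type_to_name(out);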
@@ -2,17 +2,14 @@

#pragma once

#ifdef __METAL__
#if defined __METAL__ || defined MLX_METAL_JIT
#define MTL_CONST constant
#else
#define MTL_CONST
#endif

static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
@@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_math>

/*
@@ -67,4 +66,4 @@ float erfinv(float a) {
    p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
  }
  return a * p;
}
45 mlx/backend/metal/kernels/gather.h Normal file
@@ -0,0 +1,45 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/indexing.h"

template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
    const device T* src [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant int* src_shape [[buffer(2)]],
    const constant size_t* src_strides [[buffer(3)]],
    const constant size_t& src_ndim [[buffer(4)]],
    const constant int* slice_sizes [[buffer(5)]],
    const constant int* axes [[buffer(6)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto ind_idx = index.x;
  auto ind_offset = index.y;

  size_t src_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    size_t idx_loc;
    if (IDX_NDIM == 0) {
      idx_loc = 0;
    } else if (IDX_NDIM == 1) {
      idx_loc = ind_idx * indices.strides[indices.ndim * i];
    } else {
      idx_loc = elem_to_loc(
          ind_idx,
          &indices.shapes[indices.ndim * i],
          &indices.strides[indices.ndim * i],
          indices.ndim);
    }
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
    src_idx += idx_val * src_strides[ax];
  }

  auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);

  size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
  out[out_idx] = src[src_offset + src_idx];
}
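gather_impl splits its 2-D grid so index.x picks an entry of the index arrays and index.y enumerates elements of the gathered slice; the two offsets are computed independently and summed. A plain reference of the same access pattern for the one-axis, one-index-array case (hypothetical helper, for intuition only):

// out[i, :] = src[idx[i], :] -- the degenerate NIDX = 1, axis = 0 case.
void gather_ref(
    const float* src,
    const int* idx,
    float* out,
    int n_idx,      // plays the role of index.x / grid_dim.x
    int slice_size, // plays the role of index.y / grid_dim.y
    int src_stride) {
  for (int i = 0; i < n_idx; ++i) {
    for (int j = 0; j < slice_size; ++j) {
      out[i * slice_size + j] = src[idx[i] * src_stride + j];
    }
  }
}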
@@ -1,173 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_atomic>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
    const device T* src [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant int* src_shape [[buffer(2)]],
    const constant size_t* src_strides [[buffer(3)]],
    const constant size_t& src_ndim [[buffer(4)]],
    const constant int* slice_sizes [[buffer(5)]],
    const constant int* axes [[buffer(6)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto ind_idx = index.x;
  auto ind_offset = index.y;

  size_t src_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    size_t idx_loc;
    if (IDX_NDIM == 0) {
      idx_loc = 0;
    } else if (IDX_NDIM == 1) {
      idx_loc = ind_idx * indices.strides[indices.ndim * i];
    } else {
      idx_loc = elem_to_loc(
          ind_idx,
          &indices.shapes[indices.ndim * i],
          &indices.strides[indices.ndim * i],
          indices.ndim);
    }
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
    src_idx += idx_val * src_strides[ax];
  }

  auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);

  size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
  out[out_idx] = src[src_offset + src_idx];
}

#define make_gather_impl(IDX_ARG, IDX_ARR) \
  template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
  [[kernel]] void gather( \
      const device T* src [[buffer(0)]], \
      device T* out [[buffer(1)]], \
      const constant int* src_shape [[buffer(2)]], \
      const constant size_t* src_strides [[buffer(3)]], \
      const constant size_t& src_ndim [[buffer(4)]], \
      const constant int* slice_sizes [[buffer(5)]], \
      const constant int* axes [[buffer(6)]], \
      const constant int* idx_shapes [[buffer(7)]], \
      const constant size_t* idx_strides [[buffer(8)]], \
      const constant int& idx_ndim [[buffer(9)]], \
      IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]) { \
    Indices<IdxT, NIDX> idxs{ \
        {{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
  \
    return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
        src, \
        out, \
        src_shape, \
        src_strides, \
        src_ndim, \
        slice_sizes, \
        axes, \
        idxs, \
        index, \
        grid_dim); \
  }

#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n)

make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4)
make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9)
make_gather(10)

/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////

#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
  template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \
  gather<src_t, idx_t, nidx, nd>( \
      const device src_t* src [[buffer(0)]], \
      device src_t* out [[buffer(1)]], \
      const constant int* src_shape [[buffer(2)]], \
      const constant size_t* src_strides [[buffer(3)]], \
      const constant size_t& src_ndim [[buffer(4)]], \
      const constant int* slice_sizes [[buffer(5)]], \
      const constant int* axes [[buffer(6)]], \
      const constant int* idx_shapes [[buffer(7)]], \
      const constant size_t* idx_strides [[buffer(8)]], \
      const constant int& idx_ndim [[buffer(9)]], \
      IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]);

// clang-format off
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
  instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on

// clang-format off
#define instantiate_gather4(name, src_t, idx_t, nidx) \
  instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
  instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
  instantiate_gather5(name, src_t, idx_t, nidx, 2, )


// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on

// clang-format off
#define instantiate_gather3(name, src_type, ind_type) \
  instantiate_gather4(name, src_type, ind_type, 1) \
  instantiate_gather4(name, src_type, ind_type, 2) \
  instantiate_gather4(name, src_type, ind_type, 3) \
  instantiate_gather4(name, src_type, ind_type, 4) \
  instantiate_gather4(name, src_type, ind_type, 5) \
  instantiate_gather4(name, src_type, ind_type, 6) \
  instantiate_gather4(name, src_type, ind_type, 7) \
  instantiate_gather4(name, src_type, ind_type, 8) \
  instantiate_gather4(name, src_type, ind_type, 9) \
  instantiate_gather4(name, src_type, ind_type, 10) // clang-format on

// clang-format off
#define instantiate_gather(name, src_type) \
  instantiate_gather3(#name "bool_", src_type, bool) \
  instantiate_gather3(#name "uint8", src_type, uint8_t) \
  instantiate_gather3(#name "uint16", src_type, uint16_t) \
  instantiate_gather3(#name "uint32", src_type, uint32_t) \
  instantiate_gather3(#name "uint64", src_type, uint64_t) \
  instantiate_gather3(#name "int8", src_type, int8_t) \
  instantiate_gather3(#name "int16", src_type, int16_t) \
  instantiate_gather3(#name "int32", src_type, int32_t) \
  instantiate_gather3(#name "int64", src_type, int64_t)

instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t) // clang-format on
@@ -1,13 +1,9 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <metal_stdlib>

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Indexing utils
/////////////////////////////////////////////////////////////////////

template <typename IdxT, int NIDX>
struct Indices {
  const array<const device IdxT*, NIDX> buffers;
@@ -24,31 +20,3 @@ METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
  return (idx < 0) ? idx + size : idx;
}

#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]],

#define IDX_ARG_0(idx_t)
#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)
#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23)
#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24)
#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25)
#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26)
#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27)
#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28)
#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29)
#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30)

#define IDX_ARR_N(n) idx##n,

#define IDX_ARR_0()
#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21)
#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22)
#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23)
#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24)
#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25)
#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26)
#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27)
#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28)
#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29)
#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30)
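Each rung of these two ladders adds one more index-array parameter, bound to buffers 21 through 30, plus the matching initializer token. Hand-expanding the n = 2 case makes the pattern concrete (schematic, derived directly from the definitions above):

// IDX_ARG_2(IdxT) expands to:
//   const device IdxT *idx21 [[buffer(21)]], const device IdxT *idx22 [[buffer(22)]],
// IDX_ARR_2() expands to:
//   idx21, idx22,
// so make_gather(2) emits a gather kernel whose Indices<IdxT, 2> is
// brace-initialized with exactly those two buffer pointers.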
File diff suppressed because it is too large
4 mlx/backend/metal/kernels/reduce.h Normal file
@@ -0,0 +1,4 @@
#pragma once
#include "mlx/backend/metal/kernels/reduction/reduce_all.h"
#include "mlx/backend/metal/kernels/reduction/reduce_col.h"
#include "mlx/backend/metal/kernels/reduction/reduce_row.h"
293 mlx/backend/metal/kernels/reduce.metal Normal file
@@ -0,0 +1,293 @@
// Copyright © 2024 Apple Inc.

#include <metal_atomic>
#include <metal_simdgroup>

// clang-format off
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
#include "mlx/backend/metal/kernels/reduce.h"

#define instantiate_reduce_helper_floats(inst_f, name, op) \
  inst_f(name, float16, half, op) \
  inst_f(name, float32, float, op) \
  inst_f(name, bfloat16, bfloat16_t, op)

#define instantiate_reduce_helper_uints(inst_f, name, op) \
  inst_f(name, uint8, uint8_t, op) \
  inst_f(name, uint16, uint16_t, op) \
  inst_f(name, uint32, uint32_t, op)

#define instantiate_reduce_helper_ints(inst_f, name, op) \
  inst_f(name, int8, int8_t, op) \
  inst_f(name, int16, int16_t, op) \
  inst_f(name, int32, int32_t, op)

#define instantiate_reduce_helper_64b(inst_f, name, op) \
  inst_f(name, int64, int64_t, op) \
  inst_f(name, uint64, uint64_t, op)

#define instantiate_reduce_helper_types(inst_f, name, op) \
  instantiate_reduce_helper_floats(inst_f, name, op) \
  instantiate_reduce_helper_uints(inst_f, name, op) \
  instantiate_reduce_helper_ints(inst_f, name, op)

#define instantiate_reduce_ops(inst_f, type_f) \
  type_f(inst_f, sum, Sum) \
  type_f(inst_f, prod, Prod) \
  type_f(inst_f, min_, Min) \
  type_f(inst_f, max_, Max)

// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
    inst_f, name, tname, itype, otype, op) \
  inst_f(name##tname, itype, otype, op)

#define instantiate_reduce_from_types(inst_f, name, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, bool_, bool, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint8, uint8_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint16, uint16_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint32, uint32_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint64, uint64_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int8, int8_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int16, int16_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int32, int32_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int64, int64_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, float16, half, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, \
      name, \
      float32, \
      float, \
      otype, \
      op) \
  instantiate_reduce_from_types_helper( \
      inst_f, \
      name, \
      bfloat16, \
      bfloat16_t, \
      otype, \
      op)

#define instantiate_init_reduce(name, otype, op) \
  template [[host_name("i_reduce_" #name)]] [[kernel]] void \
  init_reduce<otype, op>( \
      device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);

#define instantiate_init_reduce_helper(name, tname, type, op) \
  instantiate_init_reduce(name##tname, type, op<type>)

instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)

instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)

#define instantiate_all_reduce(name, itype, otype, op) \
  template [[host_name("all_reduce_" #name)]] [[kernel]] void \
  all_reduce<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
  template [[host_name("allNoAtomics_reduce_" #name)]] [[kernel]] void \
  all_reduce_no_atomics<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint thread_group_id [[threadgroup_position_in_grid]]);

#define instantiate_same_all_reduce_helper(name, tname, type, op) \
  instantiate_all_reduce(name##tname, type, type, op<type>)

#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
  instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)

// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)

#define instantiate_col_reduce_general(name, itype, otype, op) \
  template [[host_name("colGeneral_reduce_" #name)]] [[kernel]] void \
  col_reduce_general<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      threadgroup otype* local_data [[threadgroup(0)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]]);

#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
  template \
      [[host_name("colGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
      col_reduce_general_no_atomics<itype, otype, op>( \
          const device itype* in [[buffer(0)]], \
          device otype* out [[buffer(1)]], \
          const constant size_t& reduction_size [[buffer(2)]], \
          const constant size_t& reduction_stride [[buffer(3)]], \
          const constant size_t& out_size [[buffer(4)]], \
          const constant int* shape [[buffer(5)]], \
          const constant size_t* strides [[buffer(6)]], \
          const constant int& ndim [[buffer(7)]], \
          threadgroup otype* local_data [[threadgroup(0)]], \
          uint3 tid [[threadgroup_position_in_grid]], \
          uint3 lid [[thread_position_in_threadgroup]], \
          uint3 gid [[thread_position_in_grid]], \
          uint3 lsize [[threads_per_threadgroup]], \
          uint3 gsize [[threads_per_grid]]);

#define instantiate_col_reduce_small(name, itype, otype, op) \
  template [[host_name("colSmall_reduce_" #name)]] [[kernel]] void \
  col_reduce_small<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      const constant size_t& non_col_reductions [[buffer(8)]], \
      const constant int* non_col_shapes [[buffer(9)]], \
      const constant size_t* non_col_strides [[buffer(10)]], \
      const constant int& non_col_ndim [[buffer(11)]], \
      uint tid [[thread_position_in_grid]]);

#define instantiate_same_col_reduce_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
  instantiate_col_reduce_general(name ##tname, type, type, op<type>)

#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
  instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)

instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or<bool>)

#define instantiate_row_reduce_small(name, itype, otype, op) \
  template [[host_name("rowGeneralSmall_reduce_" #name)]] [[kernel]] void \
  row_reduce_general_small<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint lid [[thread_position_in_grid]]); \
  template [[host_name("rowGeneralMed_reduce_" #name)]] [[kernel]] void \
  row_reduce_general_med<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_row_reduce_general(name, itype, otype, op) \
  instantiate_row_reduce_small(name, itype, otype, op) \
  template \
      [[host_name("rowGeneral_reduce_" #name)]] [[kernel]] void \
      row_reduce_general<itype, otype, op>( \
          const device itype* in [[buffer(0)]], \
          device mlx_atomic<otype>* out [[buffer(1)]], \
          const constant size_t& reduction_size [[buffer(2)]], \
          const constant size_t& out_size [[buffer(3)]], \
          const constant size_t& non_row_reductions [[buffer(4)]], \
          const constant int* shape [[buffer(5)]], \
          const constant size_t* strides [[buffer(6)]], \
          const constant int& ndim [[buffer(7)]], \
          uint3 lid [[thread_position_in_threadgroup]], \
          uint3 lsize [[threads_per_threadgroup]], \
          uint3 tid [[threadgroup_position_in_grid]], \
          uint simd_lane_id [[thread_index_in_simdgroup]], \
          uint simd_per_group [[simdgroups_per_threadgroup]], \
          uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
  instantiate_row_reduce_small(name, itype, otype, op) \
  template \
      [[host_name("rowGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
      row_reduce_general_no_atomics<itype, otype, op>( \
          const device itype* in [[buffer(0)]], \
          device otype* out [[buffer(1)]], \
          const constant size_t& reduction_size [[buffer(2)]], \
          const constant size_t& out_size [[buffer(3)]], \
          const constant size_t& non_row_reductions [[buffer(4)]], \
          const constant int* shape [[buffer(5)]], \
          const constant size_t* strides [[buffer(6)]], \
          const constant int& ndim [[buffer(7)]], \
          uint3 lid [[thread_position_in_threadgroup]], \
          uint3 lsize [[threads_per_threadgroup]], \
          uint3 gsize [[threads_per_grid]], \
          uint3 tid [[threadgroup_position_in_grid]], \
          uint simd_lane_id [[thread_index_in_simdgroup]], \
          uint simd_per_group [[simdgroups_per_threadgroup]], \
          uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_same_row_reduce_helper(name, tname, type, op) \
  instantiate_row_reduce_general(name##tname, type, type, op<type>)

#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
  instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)

instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
// clang-format on
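Under the consolidated file, every reduction specialization follows the shape <variant>_reduce_<op><dtype>: all_reduce_sumfloat32, colGeneral_reduce_maxint32, rowGeneralNoAtomics_reduce_prodint64, and so on. A hypothetical host-side sketch of resolving one of them; the variant strings come from the host_name attributes above, while the lookup helper is an assumption:

// Sketch: variant + "_reduce_" + op tag + dtype tag, matched against [[host_name]].
std::string variant = use_atomics ? "rowGeneral" : "rowGeneralNoAtomics";
std::string kname = variant + "_reduce_" + op_name + type_to_name(in);
auto kernel = device.get_kernel(kname);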
6 mlx/backend/metal/kernels/reduce_utils.h Normal file
@@ -0,0 +1,6 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
@@ -1,32 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Reduce init
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename Op>
[[kernel]] void init_reduce(
    device T* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  out[tid] = Op::init;
}

#define instantiate_init_reduce(name, otype, op) \
  template [[host_name("i" #name)]] [[kernel]] void init_reduce<otype, op>( \
      device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);

#define instantiate_init_reduce_helper(name, tname, type, op) \
  instantiate_init_reduce(name##tname, type, op<type>)

// clang-format off
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)

instantiate_init_reduce(andbool_, bool, And)
instantiate_init_reduce(orbool_, bool, Or) // clang-format on
@@ -5,9 +5,7 @@
#include <metal_atomic>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
static constant constexpr const uint8_t simd_size = 32;

union bool4_or_uint {
  bool4 b;
@@ -21,6 +19,7 @@ struct None {
  }
};

template <typename U = bool>
struct And {
  bool simd_reduce(bool val) {
    return simd_all(val);
@@ -58,6 +57,7 @@ struct And {
  }
};

template <typename U = bool>
struct Or {
  bool simd_reduce(bool val) {
    return simd_any(val);
@@ -1,11 +1,5 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// All reduce helper
///////////////////////////////////////////////////////////////////////////////
@@ -139,50 +133,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
    out[thread_group_id] = total_val;
  }
}

#define instantiate_all_reduce(name, itype, otype, op) \
  template [[host_name("all_reduce_" #name)]] [[kernel]] void \
  all_reduce<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
  template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \
  all_reduce_no_atomics<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint thread_group_id [[threadgroup_position_in_grid]]);

///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_same_all_reduce_helper(name, tname, type, op) \
  instantiate_all_reduce(name##tname, type, type, op<type>)

#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
  instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)

// clang-format off
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)

// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on
@@ -1,11 +1,5 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Small column reduce kernel
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -52,23 +46,6 @@ template <typename T, typename U, typename Op>
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op) \
|
||||
template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \
|
||||
col_reduce_small<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
const constant size_t& non_col_reductions [[buffer(8)]], \
|
||||
const constant int* non_col_shapes [[buffer(9)]], \
|
||||
const constant size_t* non_col_strides [[buffer(10)]], \
|
||||
const constant int& non_col_ndim [[buffer(11)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Column reduce helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -186,64 +163,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
    }
  }
}

#define instantiate_col_reduce_general(name, itype, otype, op) \
  template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \
  col_reduce_general<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      threadgroup otype* local_data [[threadgroup(0)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]]);

#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
  template \
      [[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \
      col_reduce_general_no_atomics<itype, otype, op>( \
          const device itype* in [[buffer(0)]], \
          device otype* out [[buffer(1)]], \
          const constant size_t& reduction_size [[buffer(2)]], \
          const constant size_t& reduction_stride [[buffer(3)]], \
          const constant size_t& out_size [[buffer(4)]], \
          const constant int* shape [[buffer(5)]], \
          const constant size_t* strides [[buffer(6)]], \
          const constant int& ndim [[buffer(7)]], \
          threadgroup otype* local_data [[threadgroup(0)]], \
          uint3 tid [[threadgroup_position_in_grid]], \
          uint3 lid [[thread_position_in_threadgroup]], \
          uint3 gid [[thread_position_in_grid]], \
          uint3 lsize [[threads_per_threadgroup]], \
          uint3 gsize [[threads_per_grid]]);

///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////

// clang-format off
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
  instantiate_col_reduce_general(name ##tname, type, type, op<type>) // clang-format on

// clang-format off
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
  instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>) // clang-format on

// clang-format off
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)

instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on
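Taken together, the two instantiate_reduce_ops calls generate a small/general kernel pair per (op, type) combination; the 64-bit integer types go through the no_atomics path, presumably because Metal lacks 64-bit atomics. A few of the resulting host names, listed by hand for illustration:

// Hand-listed examples, derived from the macros above:
//   col_reduce_small_sumfloat32,  col_reduce_general_sumfloat32
//   col_reduce_small_max_int32,   col_reduce_general_max_int32
//   col_reduce_small_sumint64,    col_reduce_general_no_atomics_sumint64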
mlx/backend/metal/kernels/reduction/reduce_init.h (new file, 8 lines)
@@ -0,0 +1,8 @@
// Copyright © 2023-2024 Apple Inc.

template <typename T, typename Op>
[[kernel]] void init_reduce(
    device T* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  out[tid] = Op::init;
}
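init_reduce fills the output with the reduction's identity element before the accumulation kernels run, so it only requires that the op expose a static init member. A minimal sketch of a compatible op (hypothetical SumOp, not the actual definition in ops.h):

// Hypothetical op for illustration; the real ops live in
// mlx/backend/metal/kernels/reduction/ops.h.
template <typename U>
struct SumOp {
  // Identity element written by init_reduce before accumulation.
  static constexpr constant U init = U(0);
};
// Dispatching init_reduce<float, SumOp<float>> over a grid of out_size
// threads would zero-fill the output buffer.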
@@ -1,71 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_atomic>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"

#define instantiate_reduce_helper_floats(inst_f, name, op) \
  inst_f(name, float16, half, op) inst_f(name, float32, float, op) \
      inst_f(name, bfloat16, bfloat16_t, op)

#define instantiate_reduce_helper_uints(inst_f, name, op) \
  inst_f(name, uint8, uint8_t, op) inst_f(name, uint16, uint16_t, op) \
      inst_f(name, uint32, uint32_t, op)

#define instantiate_reduce_helper_ints(inst_f, name, op) \
  inst_f(name, int8, int8_t, op) inst_f(name, int16, int16_t, op) \
      inst_f(name, int32, int32_t, op)

#define instantiate_reduce_helper_64b(inst_f, name, op) \
  inst_f(name, int64, int64_t, op) inst_f(name, uint64, uint64_t, op)

#define instantiate_reduce_helper_types(inst_f, name, op) \
  instantiate_reduce_helper_floats(inst_f, name, op) \
      instantiate_reduce_helper_uints(inst_f, name, op) \
          instantiate_reduce_helper_ints(inst_f, name, op)

#define instantiate_reduce_ops(inst_f, type_f) \
  type_f(inst_f, sum, Sum) type_f(inst_f, prod, Prod) \
      type_f(inst_f, min_, Min) type_f(inst_f, max_, Max)

// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
    inst_f, name, tname, itype, otype, op) \
  inst_f(name##tname, itype, otype, op)

#define instantiate_reduce_from_types(inst_f, name, otype, op) \
  instantiate_reduce_from_types_helper(inst_f, name, bool_, bool, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint8, uint8_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint16, uint16_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint32, uint32_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int8, int8_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int16, int16_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int32, int32_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int64, int64_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, float16, half, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, float32, float, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, bfloat16, bfloat16_t, otype, op)
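The from_types macros fan a single fixed-output-type op out across every input dtype. Tracing instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And) by hand gives (illustrative, abbreviated):

// Hand-expanded illustration of
// instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And):
instantiate_all_reduce(andbool_, bool, bool, And)
instantiate_all_reduce(anduint8, uint8_t, bool, And)
instantiate_all_reduce(anduint16, uint16_t, bool, And)
// ...one per remaining dtype (uint32, int8..int64, float16, float32,
// bfloat16), each producing a bool-output all-reduce kernel.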