Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: v0.23.2...steel-refa (92 commits)
Commit SHAs in this comparison:

4c46e17a5d, 99eefd2ec0, e9e268336b, 7275ac7523, c4189a38e4, 68d1b3256b, 9c6953bda7, ef7ece9851, ddaa4b7dcb, dfae2c6989,
515f104926, 9ecefd56db, e5d35aa187, 00794c42bc, 08a1bf3f10, 60c4154346, f2c85308c1, 1a28b69ee2, ba09f01ce8, 6cf48872b7,
7b3b8fa000, ec5e2aae61, 86389bf970, 3290bfa690, 8777fd104f, c41f7565ed, 9ba81e3da4, c23888acd7, f98ce25ab9, de5f38fd48,
ec2854b13a, 90823d2938, 5f5770e3a2, 28f39e9038, b2d2b37888, fe597e141c, 72ca1539e0, 13b26775f1, 05d7118561, 98b901ad66,
5580b47291, bc62932984, a6b5d6e759, a8931306e1, fecdb8717e, 916fd273ea, 0da8506552, eda7a7b43e, 022eabb734, aba899cef8,
6a40e1c176, 9307b2ab8b, 522d8d3917, a84cc0123f, f018e248cd, cfd7237a80, 4eef8102c9, 69e4dd506b, 25814a9458, 2a980a76ce,
d343782c8b, 4e1994e9d7, 65a38c452b, 7b7e2352cd, 1177d28395, 005e7efa64, b42d13ec84, 9adcd1a650, 3c164fca8c, 95e335db7b,
f90206ad74, 3779150750, 0a9777aa5c, 45ad06aac8, c6ea2ba329, 2770a10240, d2a94f9e6a, 32da94507a, 736a340478, 117e1355a2,
3c3e558c60, cffceda6ee, 048805ad2c, d14c9fe7ea, 5db90ce822, d699cc1330, c4230747a1, 5245f12a46, a198b2787e, 04edad8c59,
392b3060b0, 85b34d59bc
CI configuration:

```diff
@@ -24,8 +24,8 @@ jobs:
       type: boolean
       default: false
     macos:
-      xcode: "15.2.0"
-    resource_class: macos.m1.medium.gen1
+      xcode: "16.2.0"
+    resource_class: m2pro.medium
     steps:
       - checkout
       - run:
@@ -89,6 +89,7 @@ jobs:
             pip install numpy
             sudo apt-get update
             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
       - run:
           name: Install Python package
           command: |
@@ -108,6 +109,8 @@ jobs:
           name: Run Python tests
           command: |
             python3 -m unittest discover python/tests -v
+            mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
+            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
       - run:
           name: Build CPP only
           command: |
@@ -122,10 +125,15 @@ jobs:
     parameters:
       xcode_version:
         type: string
-        default: "15.2.0"
+        default: "16.2.0"
+      macosx_deployment_target:
+        type: string
+        default: ""
     macos:
       xcode: << parameters.xcode_version >>
-    resource_class: macos.m1.medium.gen1
+    environment:
+      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
+    resource_class: m2pro.medium
     steps:
       - checkout
       - run:
@@ -146,7 +154,9 @@ jobs:
           name: Install Python package
           command: |
             source env/bin/activate
-            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
+            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
+              CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+              pip install -e . -v
       - run:
           name: Generate package stubs
           command: |
@@ -209,13 +219,18 @@ jobs:
         default: "3.9"
       xcode_version:
         type: string
-        default: "15.2.0"
+        default: "16.2.0"
       build_env:
         type: string
         default: ""
+      macosx_deployment_target:
+        type: string
+        default: ""
     macos:
       xcode: << parameters.xcode_version >>
-    resource_class: macos.m1.medium.gen1
+    resource_class: m2pro.medium
+    environment:
+      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
     steps:
       - checkout
       - run:
@@ -236,7 +251,7 @@ jobs:
           name: Install Python package
           command: |
             source env/bin/activate
-            DEV_RELEASE=1 \
+            env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
               CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
               pip install . -v
       - run:
@@ -331,7 +346,7 @@ workflows:
       - mac_build_and_test:
           matrix:
             parameters:
-              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+              macosx_deployment_target: ["13.5", "14.0"]
       - linux_build_and_test
       - build_documentation
```
@@ -351,8 +366,70 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- build_documentation:
|
||||
filters:
|
||||
tags:
|
||||
@@ -375,7 +452,7 @@ workflows:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
@@ -388,7 +465,54 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
weekly_build:
|
||||
when:
|
||||
and:
|
||||
@@ -399,8 +523,70 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
linux_test_release:
|
||||
when:
|
||||
and:
|
||||
|
||||
Top-level CMake configuration:

```diff
@@ -9,6 +9,7 @@ if(NOT MLX_VERSION)
   string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
   set(_patch ${CMAKE_MATCH_1})
   set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
+  set(MLX_VERSION ${MLX_PROJECT_VERSION})
 else()
   string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
                        ${MLX_VERSION})
@@ -41,8 +42,6 @@ option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
 
-add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
-
 # --------------------- Processor tests -------------------------
 message(
   STATUS
@@ -77,7 +76,6 @@ include(FetchContent)
 cmake_policy(SET CMP0135 NEW)
 
 add_library(mlx)
-set_target_properties(mlx PROPERTIES COMPILE_WARNING_AS_ERROR ON)
 
 if(MLX_BUILD_METAL)
   set(METAL_LIB "-framework Metal")
@@ -214,23 +212,13 @@ else()
   set(MLX_BUILD_ACCELERATE OFF)
 endif()
 
 find_package(MPI)
 if(MPI_FOUND)
   execute_process(
     COMMAND zsh "-c" "mpirun --version"
     OUTPUT_VARIABLE MPI_VERSION
     ERROR_QUIET)
   if(${MPI_VERSION} MATCHES ".*Open MPI.*")
     target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
   elseif(MPI_VERSION STREQUAL "")
     set(MPI_FOUND FALSE)
     message(
       WARNING "MPI found but mpirun is not available. Building without MPI.")
   else()
     set(MPI_FOUND FALSE)
     message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
   endif()
 endif()
 message(STATUS "Downloading json")
 FetchContent_Declare(
   json
   URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
 FetchContent_MakeAvailable(json)
 target_include_directories(
   mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
 
 add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
```
Contributing guide:

````diff
@@ -5,26 +5,26 @@ possible.

 ## Pull Requests

 1. Fork and submit pull requests to the repo.
 2. If you've added code that should be tested, add tests.
 3. If a change is likely to impact efficiency, run some of the benchmarks before
    and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
 4. If you've changed APIs, update the documentation.
 5. Every PR should have passing tests and at least one review.
 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
    This should install hooks for running `black` and `clang-format` to ensure
    consistent style for C++ and python code.

 You can also run the formatters manually as follows:

-```
-clang-format -i file.cpp
-```
-
-```
-black file.py
-```
+```shell
+clang-format -i file.cpp
+```
+
+```shell
+black file.py
+```

 or run `pre-commit run --all-files` to check all files in the repo.

 ## Issues
````
@@ -157,7 +157,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
|
||||
|
||||
|
||||
def get_gflop_count(B, M, N, K):
|
||||
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
|
||||
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1000.0**3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -175,6 +175,8 @@ if __name__ == "__main__":
|
||||
(1, 4096, 4096, 4096),
|
||||
)
|
||||
|
||||
print(f" B, M, N, K, dtype, t, gflops_pt, gflops_mx, diff%")
|
||||
|
||||
for dtype in dtypes:
|
||||
for transpose in transposes:
|
||||
for B, M, N, K in shapes:
|
||||
@@ -187,7 +189,7 @@ if __name__ == "__main__":
|
||||
diff = gflops_mx / gflops_pt - 1.0
|
||||
|
||||
print(
|
||||
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
|
||||
f"{B:3d}, {M:4d}, {N:4d}, {K:5d}, {dtype}, {transpose}, {gflops_pt:8.2f}, {gflops_mx:8.2f}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if gflops_pt >= 2.0 * gflops_mx:
|
||||
print("ATTENTION ^^^^^^^")
|
||||
|
||||
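The GFLOP accounting above switches from a binary divisor (1024**3) to a decimal one (1000**3), i.e. from GiFLOP to GFLOP units. A minimal sketch of that convention (the function and variable names here are illustrative, not taken from the benchmark):

```python
def matmul_gflops(B, M, N, K, n_iter, seconds):
    # A matmul performs roughly 2 * M * N * K flops (one multiply and one add
    # per inner-product term) for each of the B batch elements, n_iter times.
    flops = 2.0 * B * M * N * K * n_iter
    return flops / seconds / 1000.0**3  # decimal GFLOP/s, matching the updated divisor


# Example: one 4096 x 4096 x 4096 matmul finishing in 25 ms is roughly 5500 GFLOP/s.
print(f"{matmul_gflops(1, 4096, 4096, 4096, 1, 0.025):.0f} GFLOP/s")
```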
Benchmark imports:

```diff
@@ -1,7 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.

 import argparse
 from time import time

 import mlx.core as mx
 import torch
```
New file: benchmarks/python/gather_mm_bench.py (74 lines):

```python
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_mm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = x @ w1.T
        x = x @ w2.T
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_mm()
```
@@ -28,11 +28,34 @@ def bench(f, *args):
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def mlx_sdpa_fused_inner(q, k, v, scale):
|
||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
|
||||
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
|
||||
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
|
||||
|
||||
scale = 1.0 / math.sqrt(D)
|
||||
|
||||
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
|
||||
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||
|
||||
q_mx = mx.array(q_np)
|
||||
k_mx = mx.array(k_np)
|
||||
v_mx = mx.array(v_np)
|
||||
|
||||
if mask is not None:
|
||||
if mask == "additive":
|
||||
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
|
||||
mask = mx.array(mask_np)
|
||||
elif mask == "bool":
|
||||
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
|
||||
mask = mx.array(mask_np)
|
||||
|
||||
return q_mx, k_mx, v_mx, scale, mask
|
||||
|
||||
|
||||
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
||||
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
|
||||
q_dtype = q.dtype
|
||||
q = q * mx.array(scale, q_dtype)
|
||||
n_q_heads = q.shape[-3]
|
||||
@@ -41,6 +64,7 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
||||
|
||||
B = q.shape[0]
|
||||
L = q.shape[2]
|
||||
kL = k.shape[2]
|
||||
|
||||
if n_repeats > 1:
|
||||
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
|
||||
@@ -48,10 +72,27 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
||||
v = mx.expand_dims(v, 2)
|
||||
|
||||
scores = q @ mx.swapaxes(k, -1, -2)
|
||||
if f32softmax:
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
|
||||
else:
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
|
||||
if mask is not None:
|
||||
|
||||
if mask == "causal":
|
||||
q_offset = max(0, kL - L)
|
||||
q_indices = mx.arange(q_offset, q_offset + L)
|
||||
k_indices = mx.arange(kL)
|
||||
mask = q_indices[:, None] >= k_indices[None]
|
||||
|
||||
if n_repeats > 1 and mask.ndim >= 3:
|
||||
if mask.shape[-3] == 1:
|
||||
mask = mx.expand_dims(mask, -3)
|
||||
else:
|
||||
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
|
||||
|
||||
if mask.dtype == mx.bool_:
|
||||
scores = mx.where(mask, scores, -np.float32(np.inf))
|
||||
else:
|
||||
scores += mask
|
||||
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
out = scores @ v
|
||||
if n_repeats > 1:
|
||||
@@ -60,74 +101,55 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
||||
return out
|
||||
|
||||
|
||||
def mlx_spda_unfused(q, k, v, scale, transpose):
|
||||
q_out = q
|
||||
def mlx_fused_attn(q, k, v, scale, mask):
|
||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||
|
||||
|
||||
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
|
||||
if transpose:
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
q_t = mx.transpose(q, (0, 2, 1, 3))
|
||||
k_t = mx.transpose(k, (0, 2, 1, 3))
|
||||
v_t = mx.transpose(v, (0, 2, 1, 3))
|
||||
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
|
||||
return mx.transpose(o_t, (0, 2, 1, 3))
|
||||
else:
|
||||
return f(q, k, v, scale=scale, mask=mask)
|
||||
|
||||
|
||||
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
|
||||
q_out = q
|
||||
|
||||
for i in range(N_iter_func):
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
|
||||
|
||||
mx.eval(q_out)
|
||||
return q_out
|
||||
|
||||
|
||||
def mlx_spda_fused(q, k, v, scale, transpose):
|
||||
q_out = q
|
||||
if transpose:
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
for i in range(N_iter_func):
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
|
||||
mx.eval(q_out)
|
||||
return q_out
|
||||
|
||||
|
||||
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
|
||||
shape_q = (
|
||||
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
|
||||
)
|
||||
shape_kv = (
|
||||
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
|
||||
def bench_shape(
|
||||
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
|
||||
):
|
||||
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
|
||||
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
|
||||
)
|
||||
|
||||
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
|
||||
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
||||
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
||||
time_mlx_unfused = bench(
|
||||
do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||
)
|
||||
time_mlx_fused = bench(
|
||||
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||
)
|
||||
|
||||
scale = math.sqrt(1.0 / head_dim)
|
||||
o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
|
||||
o_mlx_unfused = do_attention(
|
||||
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||
)
|
||||
|
||||
q_mx = mx.array(q_np)
|
||||
k_mx = mx.array(k_np)
|
||||
v_mx = mx.array(v_np)
|
||||
atol = 1e-5 if dtype == "float32" else 2e-4
|
||||
|
||||
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
|
||||
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
|
||||
|
||||
if transpose:
|
||||
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
|
||||
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
|
||||
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
|
||||
|
||||
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
|
||||
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
|
||||
|
||||
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
|
||||
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
|
||||
print(
|
||||
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
|
||||
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
|
||||
)
|
||||
|
||||
return time_mlx_fused, time_mlx_unfused
|
||||
@@ -151,39 +173,51 @@ if __name__ == "__main__":
|
||||
( 1, 128, 128, 64, 32, 32),
|
||||
( 1, 256, 256, 64, 32, 32),
|
||||
( 1, 512, 512, 64, 32, 32),
|
||||
( 1, 1024, 1024, 64, 32, 32),
|
||||
( 1, 2048, 2048, 64, 32, 32),
|
||||
( 1, 4096, 4096, 64, 32, 32),
|
||||
( 1, 1024, 1024, 64, 32, 8),
|
||||
( 1, 2048, 2048, 64, 32, 8),
|
||||
( 1, 4096, 4096, 64, 32, 8),
|
||||
)
|
||||
|
||||
shapes_80 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 1024, 1024, 80, 32, 32),
|
||||
( 1, 2048, 2048, 80, 32, 32),
|
||||
( 1, 4096, 4096, 80, 32, 32),
|
||||
( 1, 1024, 1024, 80, 32, 8),
|
||||
( 1, 2048, 2048, 80, 32, 8),
|
||||
( 1, 4096, 4096, 80, 32, 8),
|
||||
)
|
||||
|
||||
shapes_128 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 1024, 1024, 128, 32, 32),
|
||||
( 1, 2048, 2048, 128, 32, 32),
|
||||
( 1, 4096, 4096, 128, 32, 32),
|
||||
( 1, 1024, 1024, 128, 32, 8),
|
||||
( 1, 2048, 2048, 128, 32, 8),
|
||||
( 1, 4096, 4096, 128, 32, 8),
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
shapes = shapes_64 + shapes_80 + shapes_128
|
||||
|
||||
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
|
||||
masks = [None, "bool", "causal"]
|
||||
|
||||
print(
|
||||
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
for transpose in transposes:
|
||||
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx_fused, time_mlx_unfused = bench_shape(
|
||||
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
|
||||
)
|
||||
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
||||
t_str = 1 if transpose else 0
|
||||
print(
|
||||
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
for mask_in in masks:
|
||||
time_mlx_fused, time_mlx_unfused = bench_shape(
|
||||
B,
|
||||
qsl,
|
||||
ksl,
|
||||
head_dim,
|
||||
n_q_heads,
|
||||
n_kv_heads,
|
||||
dtype,
|
||||
transpose,
|
||||
mask_in,
|
||||
)
|
||||
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
||||
t_str = 1 if transpose else 0
|
||||
print(
|
||||
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
|
||||
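The attention benchmark above now builds boolean, additive, and causal masks and times the fused `mx.fast.scaled_dot_product_attention` against an unfused reference. A minimal sketch of that comparison for the causal case (shapes and tolerances here are illustrative):

```python
import math
import mlx.core as mx

B, H, L, D = 1, 8, 128, 64
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
scale = 1.0 / math.sqrt(D)

# Fused kernel with a causal mask, as exercised by the benchmark.
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")

# Unfused reference: masked scores followed by a softmax.
scores = (q * scale) @ k.swapaxes(-1, -2)
causal = mx.arange(L)[:, None] >= mx.arange(L)[None]
scores = mx.where(causal, scores, -float("inf"))
ref = mx.softmax(scores, axis=-1) @ v

print(mx.allclose(out, ref, atol=1e-4).item())
```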
Doxygen configuration:

```diff
@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
 CREATE_SUBDIRS = NO
 FULL_PATH_NAMES = YES
 RECURSIVE = YES
-GENERATE_HTML = YES
+GENERATE_HTML = NO
 GENERATE_LATEX = NO
 GENERATE_XML = YES
 XML_PROGRAMLISTING = YES
```
Custom-extensions documentation, introduction through the operation definition:

```diff
@@ -22,12 +22,12 @@ You can do that in MLX directly:
 This function performs that operation while leaving the implementation and
 function transformations to MLX.

-However you may need to customize the underlying implementation, perhaps to
-make it faster or for custom differentiation. In this tutorial we will go
-through adding custom extensions. It will cover:
+However, you may want to customize the underlying implementation, perhaps to
+make it faster. In this tutorial we will go through adding custom extensions.
+It will cover:

 * The structure of the MLX library.
-* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
+* Implementing a CPU operation.
 * Implementing a GPU operation using metal.
 * Adding the ``vjp`` and ``jvp`` function transformation.
 * Building a custom extension and binding it to python.
@@ -45,7 +45,7 @@ Operations
 Operations are the front-end functions that operate on arrays. They are defined
 in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.

-We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
+We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
 ``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
 C++:

@@ -55,7 +55,7 @@ C++:
 * Scale and sum two vectors element-wise
 * z = alpha * x + beta * y
 *
-* Follow numpy style broadcasting between x and y
+* Use NumPy-style broadcasting between x and y
 * Inputs are upcasted to floats if needed
 **/
 array axpby(
@@ -66,7 +66,7 @@ C++:
     StreamOrDevice s = {} // Stream on which to schedule the operation
 );

-The simplest way to this operation is in terms of existing operations:
+The simplest way to implement this is with existing operations:

 .. code-block:: C++

@@ -93,9 +93,9 @@ Primitives
 ^^^^^^^^^^^

 A :class:`Primitive` is part of the computation graph of an :class:`array`. It
-defines how to create outputs arrays given a input arrays. Further, a
+defines how to create output arrays given input arrays. Further, a
 :class:`Primitive` has methods to run on the CPU or GPU and for function
-transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
+transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
 more concrete:

 .. code-block:: C++
@@ -128,7 +128,7 @@ more concrete:
     /** The vector-Jacobian product. */
     std::vector<array> vjp(
         const std::vector<array>& primals,
-        const array& cotan,
+        const std::vector<array>& cotangents,
         const std::vector<int>& argnums,
         const std::vector<array>& outputs) override;

@@ -153,9 +153,6 @@ more concrete:
   private:
    float alpha_;
    float beta_;
-
-   /** Fall back implementation for evaluation on CPU */
-   void eval(const std::vector<array>& inputs, array& out);
 };

 The :class:`Axpby` class derives from the base :class:`Primitive` class. The
@@ -188,7 +185,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
     auto promoted_dtype = promote_types(x.dtype(), y.dtype());

     // Upcast to float32 for non-floating point inputs x and y
-    auto out_dtype = is_floating_point(promoted_dtype)
+    auto out_dtype = issubdtype(promoted_dtype, float32)
         ? promoted_dtype
         : promote_types(promoted_dtype, float32);
```
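For orientation, a minimal sketch of how the finished extension is called from Python once the tutorial's `mlx_sample_extensions` package is built (the input values here are illustrative):

```python
import mlx.core as mx
from mlx_sample_extensions import axpby

x = mx.ones((3, 4))
y = mx.ones((3, 4))

# z = alpha * x + beta * y, evaluated by the custom Axpby primitive.
z = axpby(x, y, 4.0, 2.0)
mx.eval(z)
print(z)  # every element is 4.0 * 1 + 2.0 * 1 = 6.0
```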
Custom-extensions documentation, CPU back-end introduction:

```diff
@@ -234,49 +231,57 @@ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
 Implementing the CPU Back-end
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Let's start by implementing a naive and generic version of
-:meth:`Axpby::eval_cpu`. We declared this as a private member function of
-:class:`Axpby` earlier called :meth:`Axpby::eval`.
+Let's start by implementing :meth:`Axpby::eval_cpu`.

-Our naive method will go over each element of the output array, find the
+The method will go over each element of the output array, find the
 corresponding input elements of ``x`` and ``y`` and perform the operation
 point-wise. This is captured in the templated function :meth:`axpby_impl`.

 .. code-block:: C++
```
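The element-wise loop below leans on `elem_to_loc` to map a linear output index onto offsets into `x` and `y`, which may be broadcast or otherwise non-contiguous. A rough Python sketch of that mapping (row-major traversal; illustrative, not the C++ implementation):

```python
def elem_to_loc(index, shape, strides):
    # Convert a linear row-major index into a strided memory offset.
    loc = 0
    for dim in reversed(range(len(shape))):
        loc += (index % shape[dim]) * strides[dim]
        index //= shape[dim]
    return loc


# A 2 x 3 array stored column-major (strides (1, 2)): element (1, 2) lives at offset 5.
print(elem_to_loc(1 * 3 + 2, (2, 3), (1, 2)))  # -> 5
```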
template <typename T>
|
||||
void axpby_impl(
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// We only allocate memory when we are ready to fill the output
|
||||
// malloc_or_wait synchronously allocates available memory
|
||||
// There may be a wait executed here if the allocation is requested
|
||||
// under memory-pressured conditions
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
template <typename T>
|
||||
void axpby_impl(
|
||||
const mx::array& x,
|
||||
const mx::array& y,
|
||||
mx::array& out,
|
||||
float alpha_,
|
||||
float beta_,
|
||||
mx::Stream stream) {
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
|
||||
// Collect input and output data pointers
|
||||
const T* x_ptr = x.data<T>();
|
||||
const T* y_ptr = y.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
// Get the CPU command encoder and register input and output arrays
|
||||
auto& encoder = mx::cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(y);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
// Cast alpha and beta to the relevant types
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
// Launch the CPU kernel
|
||||
encoder.dispatch([x_ptr = x.data<T>(),
|
||||
y_ptr = y.data<T>(),
|
||||
out_ptr = out.data<T>(),
|
||||
size = out.size(),
|
||||
shape = out.shape(),
|
||||
x_strides = x.strides(),
|
||||
y_strides = y.strides(),
|
||||
alpha_,
|
||||
beta_]() {
|
||||
|
||||
// Do the element-wise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||
// Cast alpha and beta to the relevant types
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additional mapping
|
||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||
}
|
||||
}
|
||||
// Do the element-wise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < size; out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
|
||||
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
|
||||
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additional mapping
|
||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Our implementation should work for all incoming floating point arrays.
|
||||
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
|
||||
@@ -284,112 +289,32 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == float32) {
|
||||
return axpby_impl<float>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == float16) {
|
||||
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == complex64) {
|
||||
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[Axpby] Only supports floating point types.");
|
||||
}
|
||||
}
|
||||
|
||||
This is good as a fallback implementation. We can use the ``axpby`` routine
|
||||
provided by the Accelerate_ framework for a faster implementation in certain
|
||||
cases:
|
||||
|
||||
#. Accelerate does not provide implementations of ``axpby`` for half precision
|
||||
floats. We can only use it for ``float32`` types.
|
||||
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
|
||||
elements have fixed strides between them. We only direct to Accelerate
|
||||
if both ``x`` and ``y`` are row contiguous or column contiguous.
|
||||
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
|
||||
MLX expects to write the output to a new array. We must copy the elements
|
||||
of ``y`` into the output and use that as an input to ``axpby``.
|
||||
|
||||
Let's write an implementation that uses Accelerate in the right conditions.
|
||||
It allocates data for the output, copies ``y`` into it, and then calls the
|
||||
:func:`catlas_saxpby` from accelerate.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl_accelerate(
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// Accelerate library provides catlas_saxpby which does
|
||||
// Y = (alpha * X) + (beta * Y) in place
|
||||
// To use it, we first copy the data in y over to the output array
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, CopyType::Vector);
|
||||
|
||||
// Get x and y pointers for catlas_saxpby
|
||||
const T* x_ptr = x.data<T>();
|
||||
T* y_ptr = out.data<T>();
|
||||
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
|
||||
// Call the inplace accelerate operator
|
||||
catlas_saxpby(
|
||||
/* N = */ out.size(),
|
||||
/* ALPHA = */ alpha,
|
||||
/* X = */ x_ptr,
|
||||
/* INCX = */ 1,
|
||||
/* BETA = */ beta,
|
||||
/* Y = */ y_ptr,
|
||||
/* INCY = */ 1);
|
||||
}
|
||||
|
||||
For inputs that do not fit the criteria for accelerate, we fall back to
|
||||
:meth:`Axpby::eval`. With this in mind, let's finish our
|
||||
:meth:`Axpby::eval_cpu`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == float32 &&
|
||||
((x.flags().row_contiguous && y.flags().row_contiguous) ||
|
||||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
|
||||
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to common back-end if specializations are not available
|
||||
eval(inputs, outputs);
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == mx::float32) {
|
||||
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
|
||||
} else if (out.dtype() == mx::float16) {
|
||||
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
|
||||
} else if (out.dtype() == mx::bfloat16) {
|
||||
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
|
||||
} else if (out.dtype() == mx::complex64) {
|
||||
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Axpby is only supported for floating point types.");
|
||||
}
|
||||
}
|
||||
|
||||
Custom-extensions documentation, GPU back-end, transforms, and the final benchmark:

```diff
 Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
 you do not plan on running the operation on the GPU or using transforms on
 computation graphs that contain :class:`Axpby`, you can stop implementing the
-primitive here and enjoy the speed-ups you get from the Accelerate library.
+primitive here.

 Implementing the GPU Back-end
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -466,7 +391,7 @@ below.
     auto& d = metal::device(s.device);

     // Allocate output memory
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));

     // Resolve name of kernel
     std::ostringstream kname;
@@ -544,7 +469,7 @@ one we just defined:
       const std::vector<array>& tangents,
       const std::vector<int>& argnums) {
     // Forward mode diff that pushes along the tangents
-    // The jvp transform on the primitive can built with ops
+    // The jvp transform on the primitive can be built with ops
     // that are scheduled on the same stream as the primitive

     // If argnums = {0}, we only push along x in which case the
@@ -556,7 +481,7 @@ one we just defined:
       auto scale_arr = array(scale, tangents[0].dtype());
       return {multiply(scale_arr, tangents[0], stream())};
     }
-    // If, argnums = {0, 1}, we take contributions from both
+    // If argnums = {0, 1}, we take contributions from both
     // which gives us jvp = tangent_x * alpha + tangent_y * beta
     else {
       return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
@@ -810,7 +735,7 @@ Let's look at a simple script and its results:

     print(f"c shape: {c.shape}")
     print(f"c dtype: {c.dtype}")
-    print(f"c correct: {mx.all(c == 6.0).item()}")
+    print(f"c is correct: {mx.all(c == 6.0).item()}")

 Output:

@@ -818,13 +743,13 @@ Output:

     c shape: [3, 4]
     c dtype: float32
-    c correctness: True
+    c is correct: True

 Results
 ^^^^^^^

 Let's run a quick benchmark and see how our new ``axpby`` operation compares
-with the naive :meth:`simple_axpby` we first defined on the CPU.
+with the naive :meth:`simple_axpby` we first defined.

 .. code-block:: python

@@ -832,13 +757,11 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
     from mlx_sample_extensions import axpby
     import time

-    mx.set_default_device(mx.cpu)
-
     def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
         return alpha * x + beta * y

-    M = 256
-    N = 512
+    M = 4096
+    N = 4096

     x = mx.random.normal((M, N))
     y = mx.random.normal((M, N))
@@ -849,24 +772,24 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.

     def bench(f):
         # Warm up
-        for i in range(100):
+        for i in range(5):
             z = f(x, y, alpha, beta)
             mx.eval(z)

         # Timed run
         s = time.time()
-        for i in range(5000):
+        for i in range(100):
             z = f(x, y, alpha, beta)
             mx.eval(z)
         e = time.time()
-        return e - s
+        return 1000 * (e - s) / 100

     simple_time = bench(simple_axpby)
     custom_time = bench(axpby)

-    print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
+    print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")

-The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
+The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
 modest improvements right away!

 This operation is now good to be used to build other operations, in
```
Python documentation: index, array methods, linear algebra, and the new memory-management page:

```diff
@@ -70,6 +70,7 @@ are the CPU and GPU.
    python/fft
    python/linalg
    python/metal
+   python/memory_management
    python/nn
    python/optimizers
    python/distributed
@@ -38,6 +38,7 @@ Array
    array.log10
    array.log1p
    array.log2
+   array.logcumsumexp
    array.logsumexp
    array.max
    array.mean
@@ -20,5 +20,6 @@ Linear Algebra
    eigh
    lu
    lu_factor
    pinv
    solve
    solve_triangular
```

New file: docs/src/python/memory_management.rst (16 lines):

```rst
Memory Management
=================

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   get_active_memory
   get_peak_memory
   reset_peak_memory
   get_cache_memory
   set_memory_limit
   set_cache_limit
   set_wired_limit
   clear_cache
```
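The memory helpers are now documented on the top-level `mlx.core` module (the next hunk removes them from the Metal page). A small usage sketch (sizes are illustrative):

```python
import mlx.core as mx

mx.reset_peak_memory()
a = mx.random.normal((4096, 4096))
mx.eval(a @ a)

print(mx.get_active_memory() / 2**20, "MiB active")
print(mx.get_peak_memory() / 2**20, "MiB peak")

# Bound the buffer cache (in bytes); the previous limit is returned.
old_limit = mx.set_cache_limit(512 * 2**20)
mx.clear_cache()
```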
Metal and operations documentation pages:

```diff
@@ -8,13 +8,5 @@ Metal

    is_available
    device_info
-   get_active_memory
-   get_peak_memory
-   reset_peak_memory
-   get_cache_memory
-   set_memory_limit
-   set_cache_limit
-   set_wired_limit
-   clear_cache
    start_capture
    stop_capture

@@ -36,10 +36,12 @@ Operations
    bitwise_or
    bitwise_xor
    block_masked_mm
+   broadcast_arrays
    broadcast_to
    ceil
    clip
    concatenate
+   contiguous
    conj
    conjugate
    convolve
@@ -101,6 +103,7 @@ Operations
    log10
    log1p
    logaddexp
+   logcumsumexp
    logical_not
    logical_and
    logical_or
```
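Two of the newly listed operations in a short sketch (values are illustrative):

```python
import mlx.core as mx

a = mx.array([[1.0], [2.0]])      # shape (2, 1)
b = mx.array([10.0, 20.0, 30.0])  # shape (3,)
x, y = mx.broadcast_arrays(a, b)  # both broadcast to shape (2, 3)
print(x.shape, y.shape)

# Running log-sum-exp, a cumulative counterpart to logsumexp.
z = mx.logcumsumexp(mx.array([0.0, 1.0, 2.0]), axis=0)
print(z)
```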
Optimizer and transform documentation pages:

```diff
@@ -18,3 +18,4 @@ Common Optimizers
    AdamW
    Adamax
    Lion
+   MultiOptimizer
@@ -9,6 +9,7 @@ Transforms
    :toctree: _autosummary

    eval
+   async_eval
    compile
    custom_function
    disable_compile
```
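`async_eval` is the newly documented transform here; a minimal sketch of how it differs from a blocking `eval`:

```python
import mlx.core as mx

a = mx.random.normal((1024, 1024))
b = a @ a            # lazy: nothing is computed yet

mx.async_eval(b)     # kick off evaluation without blocking the Python thread
# ... other Python-side work can happen here ...
mx.eval(b)           # block until b is actually materialized
```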
@@ -10,7 +10,6 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
|
||||
|
||||
# ----------------------------- Dependencies -----------------------------
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
find_package(
|
||||
Python 3.8
|
||||
COMPONENTS Interpreter Development.Module
|
||||
@@ -21,6 +20,12 @@ execute_process(
|
||||
OUTPUT_VARIABLE nanobind_ROOT)
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE MLX_ROOT)
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
|
||||
# ----------------------------- Extensions -----------------------------
|
||||
|
||||
# Add library
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
// Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include "axpby/axpby.h"
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <vecLib/cblas_new.h>
|
||||
#endif
|
||||
|
||||
#ifdef _METAL_
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
@@ -76,136 +70,65 @@ void axpby_impl(
|
||||
const mx::array& y,
|
||||
mx::array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// We only allocate memory when we are ready to fill the output
|
||||
// malloc_or_wait synchronously allocates available memory
|
||||
// There may be a wait executed here if the allocation is requested
|
||||
// under memory-pressured conditions
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
float beta_,
|
||||
mx::Stream stream) {
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
|
||||
// Collect input and output data pointers
|
||||
const T* x_ptr = x.data<T>();
|
||||
const T* y_ptr = y.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
// Get the CPU command encoder and register input and output arrays
|
||||
auto& encoder = mx::cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(y);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
// Cast alpha and beta to the relevant types
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
// Launch the CPU kernel
|
||||
encoder.dispatch([x_ptr = x.data<T>(),
|
||||
y_ptr = y.data<T>(),
|
||||
out_ptr = out.data<T>(),
|
||||
size = out.size(),
|
||||
shape = out.shape(),
|
||||
x_strides = x.strides(),
|
||||
y_strides = y.strides(),
|
||||
alpha_,
|
||||
beta_]() {
|
||||
// Cast alpha and beta to the relevant types
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
|
||||
// Do the element-wise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
|
||||
auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());
|
||||
// Do the element-wise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < size; out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
|
||||
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
|
||||
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additional mapping
|
||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||
}
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additional mapping
|
||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == mx::float32) {
|
||||
return axpby_impl<float>(x, y, out, alpha_, beta_);
|
||||
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
|
||||
} else if (out.dtype() == mx::float16) {
|
||||
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
|
||||
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
|
||||
} else if (out.dtype() == mx::bfloat16) {
|
||||
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
|
||||
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
|
||||
} else if (out.dtype() == mx::complex64) {
|
||||
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
|
||||
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Axpby is only supported for floating point types.");
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Primitive Accelerate Backend Implementation
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl_accelerate(
|
||||
const mx::array& x,
|
||||
const mx::array& y,
|
||||
mx::array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// Accelerate library provides catlas_saxpby which does
|
||||
// Y = (alpha * X) + (beta * Y) in place
|
||||
// To use it, we first copy the data in y over to the output array
|
||||
|
||||
// This specialization requires both x and y be contiguous in the same mode
|
||||
// i.e: corresponding linear indices in both point to corresponding elements
|
||||
// The data in the output array is allocated to match the strides in y
|
||||
// such that x, y, and out are contiguous in the same mode and
|
||||
// no transposition is needed
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, mx::CopyType::Vector);
|
||||
|
||||
// Get x and y pointers for catlas_saxpby
|
||||
const T* x_ptr = x.data<T>();
|
||||
T* y_ptr = out.data<T>();
|
||||
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
|
||||
// Call the inplace accelerate operator
|
||||
catlas_saxpby(
|
||||
/* N = */ out.size(),
|
||||
/* ALPHA = */ alpha,
|
||||
/* X = */ x_ptr,
|
||||
/* INCX = */ 1,
|
||||
/* BETA = */ beta,
|
||||
/* Y = */ y_ptr,
|
||||
/* INCY = */ 1);
|
||||
}
|
||||
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == mx::float32 &&
|
||||
((x.flags().row_contiguous && y.flags().row_contiguous) ||
|
||||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
|
||||
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to common backend if specializations are not available
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
#else // Accelerate not available
|
||||
|
||||
/** Evaluate primitive on CPU falling back to common backend */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Primitive Metal Backend Implementation
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -217,7 +140,6 @@ void Axpby::eval_gpu(
    const std::vector<mx::array>& inputs,
    std::vector<mx::array>& outputs) {
  // Prepare inputs
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];
  auto& out = outputs[0];
@@ -236,12 +158,12 @@ void Axpby::eval_gpu(
  // Allocate output memory with strides based on specialization
  if (contiguous_kernel) {
    out.set_data(
        mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
        mx::allocator::malloc(x.data_size() * out.itemsize()),
        x.data_size(),
        x.strides(),
        x.flags());
  } else {
    out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
    out.set_data(mx::allocator::malloc(out.nbytes()));
  }

  // Resolve name of kernel (corresponds to axpby.metal)
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -85,11 +85,6 @@ class Axpby : public mx::Primitive {
|
||||
private:
|
||||
float alpha_;
|
||||
float beta_;
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void eval(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs);
|
||||
};
|
||||
|
||||
} // namespace my_ext
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
|
||||
@@ -17,9 +17,13 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/version.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||
|
||||
# Define MLX_VERSION only in the version.cpp file.
|
||||
add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
|
||||
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
|
||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
|
||||
|
||||
if(MSVC)
|
||||
# Disable some MSVC warnings to speed up compilation.
|
||||
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
|
||||
|
||||
@@ -4,12 +4,11 @@
#include <sstream>

#include "mlx/allocator.h"
#include "mlx/scheduler.h"

namespace mlx::core::allocator {

Buffer malloc(size_t size) {
  auto buffer = allocator().malloc(size, /* allow_swap */ true);
  auto buffer = allocator().malloc(size);
  if (size && !buffer.ptr()) {
    std::ostringstream msg;
    msg << "[malloc] Unable to allocate " << size << " bytes.";
@@ -22,45 +21,4 @@ void free(Buffer buffer) {
  allocator().free(buffer);
}

Buffer CommonAllocator::malloc(size_t size, bool) {
  void* ptr = std::malloc(size + sizeof(size_t));
  if (ptr != nullptr) {
    *static_cast<size_t*>(ptr) = size;
  }
  return Buffer{ptr};
}

void CommonAllocator::free(Buffer buffer) {
  std::free(buffer.ptr());
}

size_t CommonAllocator::size(Buffer buffer) const {
  if (buffer.ptr() == nullptr) {
    return 0;
  }
  return *static_cast<size_t*>(buffer.ptr());
}

Buffer malloc_or_wait(size_t size) {
  auto buffer = allocator().malloc(size);

  while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
    scheduler::wait_for_one();
    buffer = allocator().malloc(size);
  }

  // Try swapping if needed
  if (size && !buffer.ptr()) {
    buffer = allocator().malloc(size, /* allow_swap = */ true);
  }

  if (size && !buffer.ptr()) {
    std::ostringstream msg;
    msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
    throw std::runtime_error(msg.str());
  }

  return buffer;
}

} // namespace mlx::core::allocator
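The hunks above and below drop `malloc_or_wait` (and the swap retry loop) from the allocator surface, leaving a plain `malloc`/`free` pair. A minimal sketch of the resulting call pattern, assuming only what the hunks show; the error path visible above suggests `malloc` raises on failure rather than returning a null buffer:

    // Illustrative only, not part of the diff.
    auto buf = mlx::core::allocator::malloc(1024 * sizeof(float));
    float* data = static_cast<float*>(buf.ptr());
    // ... fill data ...
    mlx::core::allocator::free(buf);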
@@ -32,14 +32,10 @@ Buffer malloc(size_t size);
|
||||
|
||||
void free(Buffer buffer);
|
||||
|
||||
// Wait for running tasks to finish and free up memory
|
||||
// if allocation fails
|
||||
Buffer malloc_or_wait(size_t size);
|
||||
|
||||
class Allocator {
|
||||
/** Abstract base class for a memory allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
|
||||
virtual Buffer malloc(size_t size) = 0;
|
||||
virtual void free(Buffer buffer) = 0;
|
||||
virtual size_t size(Buffer buffer) const = 0;
|
||||
|
||||
@@ -53,16 +49,4 @@ class Allocator {
|
||||
|
||||
Allocator& allocator();
|
||||
|
||||
class CommonAllocator : public Allocator {
|
||||
/** A general CPU allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
|
||||
private:
|
||||
CommonAllocator() = default;
|
||||
friend Allocator& allocator();
|
||||
};
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
|
||||
@@ -56,6 +56,18 @@ std::vector<array> array::make_arrays(
  return outputs;
}

array array::unsafe_weak_copy(const array& other) {
  auto cpy = array(other.shape(), other.dtype(), nullptr, {});
  cpy.set_data(
      other.buffer(),
      other.data_size(),
      other.strides(),
      other.flags(),
      [](auto) {});
  cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
  return cpy;
}
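The no-op deleter (`[](auto) {}`) is what makes the copy "weak": it aliases the donor's buffer without taking ownership and is detached from the graph. A sketch of the capture pattern the CPU backend hunks later in this diff use it for (illustrative; the encoder API appears in those hunks):

    encoder.dispatch([in = array::unsafe_weak_copy(in),
                      out = array::unsafe_weak_copy(out)]() mutable {
      // the task reads in.data<T>() and writes out.data<T>() without
      // extending the lifetime of the original graph arrays
    });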
|
||||
array::array(std::initializer_list<float> data)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
Shape{static_cast<ShapeElem>(data.size())},
|
||||
@@ -76,35 +88,27 @@ array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
|
||||
set_data(data, deleter);
|
||||
}
|
||||
|
||||
array::array(
|
||||
allocator::Buffer data,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
Strides strides,
|
||||
size_t data_size,
|
||||
Flags flags,
|
||||
Deleter deleter)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
set_data(data, data_size, std::move(strides), flags, deleter);
|
||||
}
|
||||
|
||||
void array::detach() {
|
||||
array_desc_->primitive = nullptr;
|
||||
for (auto& s : array_desc_->siblings) {
|
||||
s.array_desc_->primitive = nullptr;
|
||||
}
|
||||
for (auto& s : array_desc_->siblings) {
|
||||
s.array_desc_->inputs.clear();
|
||||
s.array_desc_->siblings.clear();
|
||||
s.array_desc_->position = 0;
|
||||
s.array_desc_->primitive = nullptr;
|
||||
}
|
||||
array_desc_->inputs.clear();
|
||||
array_desc_->siblings.clear();
|
||||
array_desc_->position = 0;
|
||||
array_desc_->primitive = nullptr;
|
||||
}
|
||||
|
||||
bool array::is_available() const {
|
||||
if (status() == Status::available) {
|
||||
return true;
|
||||
} else if (status() == Status::evaluated && event().is_signaled()) {
|
||||
} else if (
|
||||
status() == Status::evaluated &&
|
||||
(!event().valid() || event().is_signaled())) {
|
||||
set_status(Status::available);
|
||||
return true;
|
||||
}
|
||||
@@ -113,7 +117,10 @@ bool array::is_available() const {
|
||||
|
||||
void array::wait() {
|
||||
if (!is_available()) {
|
||||
event().wait();
|
||||
if (event().valid()) {
|
||||
event().wait();
|
||||
detach_event();
|
||||
}
|
||||
set_status(Status::available);
|
||||
}
|
||||
}
|
||||
@@ -174,34 +181,13 @@ void array::copy_shared_buffer(const array& other) {
|
||||
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||
}
|
||||
|
||||
void array::move_shared_buffer(
|
||||
array other,
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
array_desc_->data = std::move(other.array_desc_->data);
|
||||
array_desc_->strides = strides;
|
||||
array_desc_->flags = flags;
|
||||
array_desc_->data_size = data_size;
|
||||
auto char_offset = sizeof(char) * itemsize() * offset;
|
||||
auto data_ptr = other.array_desc_->data_ptr;
|
||||
other.array_desc_->data_ptr = nullptr;
|
||||
array_desc_->data_ptr =
|
||||
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
|
||||
}
|
||||
|
||||
void array::move_shared_buffer(array other) {
|
||||
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||
}
|
||||
|
||||
array::~array() {
|
||||
if (array_desc_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore arrays that might be detached during eval
|
||||
if (status() == array::Status::scheduled) {
|
||||
// Detached/detaching
|
||||
if (array_desc_->primitive == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
mlx/array.h (37 changes)
@@ -199,6 +199,13 @@ class array {
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
/**
|
||||
* Get a new array that refers to the same data as the input but with a
|
||||
* non-owning pointer to it. Note the array is detached from the graph and has
|
||||
* no inputs, siblings or primitive.
|
||||
*/
|
||||
static array unsafe_weak_copy(const array& other);
|
||||
|
||||
/** A unique identifier for an array. */
|
||||
std::uintptr_t id() const {
|
||||
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
|
||||
@@ -243,18 +250,6 @@ class array {
|
||||
bool col_contiguous : 1;
|
||||
};
|
||||
|
||||
/** Build an array from all the info held by the array description. Including
|
||||
* the buffer, strides, flags.
|
||||
*/
|
||||
explicit array(
|
||||
allocator::Buffer data,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
Strides strides,
|
||||
size_t data_size,
|
||||
Flags flags,
|
||||
Deleter deleter = allocator::free);
|
||||
|
||||
/** The array's primitive. */
|
||||
Primitive& primitive() const {
|
||||
return *(array_desc_->primitive);
|
||||
@@ -365,11 +360,6 @@ class array {
|
||||
// For example, the status of `x` in `auto x = a + b`.
|
||||
unscheduled,
|
||||
|
||||
// The output of a computation which has been scheduled but `eval_*` has
|
||||
// not yet been called on the array's primitive. A possible
|
||||
// status of `x` in `auto x = a + b; eval(x);`
|
||||
scheduled,
|
||||
|
||||
// The array's `eval_*` function has been run, but the computation is not
|
||||
// necessarily complete. The array will have memory allocated and if it is
|
||||
// not a tracer then it will be detached from the graph.
|
||||
@@ -406,6 +396,10 @@ class array {
|
||||
array_desc_->event = std::move(e);
|
||||
}
|
||||
|
||||
void detach_event() const {
|
||||
array_desc_->event = Event{};
|
||||
}
|
||||
|
||||
// Mark the array as a tracer array (true) or not.
|
||||
void set_tracer(bool is_tracer) {
|
||||
array_desc_->is_tracer = is_tracer;
|
||||
@@ -431,15 +425,6 @@ class array {
|
||||
|
||||
void copy_shared_buffer(const array& other);
|
||||
|
||||
void move_shared_buffer(
|
||||
array other,
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
|
||||
void move_shared_buffer(array other);
|
||||
|
||||
void overwrite_descriptor(const array& other) {
|
||||
array_desc_ = other.array_desc_;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
|
||||
@@ -38,25 +38,20 @@ inline void set_binary_op_output_data(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
BinaryOpType bopt,
|
||||
bool donate_with_move = false) {
|
||||
BinaryOpType bopt) {
|
||||
bool b_donatable = is_donatable(b, out);
|
||||
bool a_donatable = is_donatable(a, out);
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
||||
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
if (b_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
out.copy_shared_buffer(b);
|
||||
}
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
|
||||
allocator::malloc(b.data_size() * out.itemsize()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
@@ -64,14 +59,10 @@ inline void set_binary_op_output_data(
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
if (a_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
out.copy_shared_buffer(a);
|
||||
}
|
||||
out.copy_shared_buffer(a);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@@ -79,20 +70,12 @@ inline void set_binary_op_output_data(
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
if (a_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
out.copy_shared_buffer(a);
|
||||
}
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (b_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
out.copy_shared_buffer(b);
|
||||
}
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@@ -100,20 +83,12 @@ inline void set_binary_op_output_data(
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
out.copy_shared_buffer(a);
|
||||
}
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (
|
||||
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
out.copy_shared_buffer(b);
|
||||
}
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
mlx/backend/common/broadcasting.cpp (new file, 24 lines)
@@ -0,0 +1,24 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/common/utils.h"

namespace mlx::core {

void broadcast(const array& in, array& out) {
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }
  Strides strides(out.ndim(), 0);
  int diff = out.ndim() - in.ndim();
  for (int i = in.ndim() - 1; i >= 0; --i) {
    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
  }
  auto flags = in.flags();
  if (out.size() > in.size()) {
    flags.row_contiguous = flags.col_contiguous = false;
  }
  out.copy_shared_buffer(in, strides, flags, in.data_size());
}

} // namespace mlx::core
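A worked example of the stride computation above (illustrative values only): broadcasting an input of shape {4, 1} with strides {1, 1} into an output of shape {3, 4, 5} gives every missing or size-1 axis stride 0, so the buffer is shared rather than copied:

    // in : shape {4, 1}, strides {1, 1}
    // out: shape {3, 4, 5}  ->  broadcast strides {0, 1, 0}
    // out.size() > in.size(), so the row/col contiguity flags are cleared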
@@ -1,10 +1,11 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void encode_wait(Event e);
|
||||
|
||||
void encode_signal(Event e);
|
||||
void broadcast(const array& in, array& out);
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/broadcasting.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -39,24 +40,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
// rely on data_size anyway.
|
||||
size_t data_size = out.size();
|
||||
|
||||
return move_or_copy(in, out, strides_, flags, data_size, offset_);
|
||||
}
|
||||
|
||||
void broadcast(const array& in, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
Strides strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
move_or_copy(in, out, strides, flags, in.data_size());
|
||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -69,7 +53,7 @@ void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
move_or_copy(inputs[0], out);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void CustomTransforms::eval(
|
||||
@@ -78,7 +62,7 @@ void CustomTransforms::eval(
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
|
||||
i++, j++) {
|
||||
move_or_copy(inputs[j], outputs[i]);
|
||||
outputs[i].copy_shared_buffer(inputs[j]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,7 +71,7 @@ void Depends::eval(
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
move_or_copy(inputs[i], outputs[i]);
|
||||
outputs[i].copy_shared_buffer(inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,12 +82,12 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
|
||||
for (auto ax : axes_) {
|
||||
strides.insert(strides.begin() + ax, 1);
|
||||
}
|
||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||
out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
double numel = 1;
|
||||
for (auto ax : axes_) {
|
||||
@@ -210,7 +194,7 @@ void shared_buffer_reshape(
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
}
|
||||
move_or_copy(in, out, out_strides, flags, in.data_size());
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Split::eval(
|
||||
@@ -276,12 +260,12 @@ void Squeeze::eval(const std::vector<array>& inputs, array& out) {
|
||||
strides.push_back(in.strides(i));
|
||||
}
|
||||
}
|
||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||
out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
|
||||
}
|
||||
|
||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
move_or_copy(inputs[0], out);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -315,7 +299,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
b_stride *= out.shape(ri);
|
||||
}
|
||||
}
|
||||
move_or_copy(in, out, out_strides, flags, in.data_size());
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -161,8 +161,7 @@ void compiled_allocate_outputs(
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
bool contiguous,
|
||||
bool move_buffers /* = false */) {
|
||||
bool contiguous) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
Strides strides;
|
||||
@@ -178,11 +177,7 @@ void compiled_allocate_outputs(
|
||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o++].move_shared_buffer(in);
|
||||
} else {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
if (strides.empty() && in.size() == outputs[0].size()) {
|
||||
@@ -193,7 +188,7 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||
allocator::malloc(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
@@ -210,18 +205,13 @@ void compiled_allocate_outputs(
|
||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o].move_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
} else {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
}
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,7 +62,6 @@ void compiled_allocate_outputs(
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
bool contiguous,
|
||||
bool move_buffers = false);
|
||||
bool contiguous);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -22,4 +22,25 @@ enum class CopyType {
  GeneralGeneral
};

inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
  if (ctype == CopyType::Vector) {
    // If the input is donateable, we are doing a vector copy and the types
    // have the same size, then the input buffer can hold the output.
    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
      out.copy_shared_buffer(in);
      return true;
    } else {
      out.set_data(
          allocator::malloc(in.data_size() * out.itemsize()),
          in.data_size(),
          in.strides(),
          in.flags());
      return false;
    }
  } else {
    out.set_data(allocator::malloc(out.nbytes()));
    return false;
  }
}

} // namespace mlx::core
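A minimal sketch of how a caller might use the helper above (illustrative; `copy_inplace` is assumed from elsewhere in the backend and is not part of this hunk):

    bool donated = set_copy_output_data(in, out, ctype);
    if (!donated) {
      // out now has its own storage (or shares in's strides), so run the
      // actual element copy; a donated buffer already holds the data.
      copy_inplace(in, out, ctype);
    }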
|
||||
@@ -3,7 +3,8 @@
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -26,26 +27,31 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void load(
|
||||
array& out,
|
||||
size_t offset,
|
||||
const std::shared_ptr<io::Reader>& reader,
|
||||
bool swap_endianness_) {
|
||||
reader->read(out.data<char>(), out.nbytes(), offset);
|
||||
|
||||
if (swap_endianness_) {
|
||||
switch (out.itemsize()) {
|
||||
case 2:
|
||||
swap_endianness<2>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
case 4:
|
||||
swap_endianness<4>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
case 8:
|
||||
swap_endianness<8>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto read_task = [out_ptr = out.data<char>(),
|
||||
size = out.size(),
|
||||
itemsize = out.itemsize(),
|
||||
offset = offset_,
|
||||
reader = reader_,
|
||||
swap_endianness_ = swap_endianness_]() mutable {
|
||||
reader->read(out_ptr, size * itemsize, offset);
|
||||
if (swap_endianness_) {
|
||||
switch (itemsize) {
|
||||
case 2:
|
||||
swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
|
||||
break;
|
||||
case 4:
|
||||
swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
|
||||
break;
|
||||
case 8:
|
||||
swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
|
||||
scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
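The `Load::eval_cpu` rewrite above pushes the file read onto an io thread pool and only makes the compute stream wait on the shared future. The same hand-off can be sketched with the standard library alone (a simplification, not the code in this diff):

    #include <future>

    auto fut = std::async(std::launch::async, [] {
      // read bytes from disk into the destination buffer
    }).share();

    // later, on the stream's worker:
    fut.wait(); // blocks only if the read has not finished yet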
|
||||
@@ -1,14 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/io/load.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void load(
|
||||
array& out,
|
||||
size_t offset,
|
||||
const std::shared_ptr<io::Reader>& reader,
|
||||
bool swap_endianess);
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -36,7 +36,7 @@ void shared_buffer_slice(
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
flags.contiguous = (no_bsx_size == data_size);
|
||||
|
||||
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
void slice(
|
||||
|
||||
@@ -36,15 +36,10 @@ inline void set_ternary_op_output_data(
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
TernaryOpType topt,
|
||||
bool donate_with_move = false) {
|
||||
auto maybe_donate = [&out, donate_with_move](const array& x) {
|
||||
TernaryOpType topt) {
|
||||
auto maybe_donate = [&out](const array& x) {
|
||||
if (is_donatable(x, out)) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(x);
|
||||
} else {
|
||||
out.copy_shared_buffer(x);
|
||||
}
|
||||
out.copy_shared_buffer(x);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
@@ -53,12 +48,12 @@ inline void set_ternary_op_output_data(
|
||||
switch (topt) {
|
||||
case TernaryOpType::ScalarScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
|
||||
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
|
||||
break;
|
||||
case TernaryOpType::VectorVectorVector:
|
||||
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
|
||||
allocator::malloc(out.itemsize() * b.data_size()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
@@ -69,7 +64,7 @@ inline void set_ternary_op_output_data(
|
||||
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
||||
(b.flags().row_contiguous && maybe_donate(b)) ||
|
||||
(c.flags().row_contiguous && maybe_donate(c)))) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -4,28 +4,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void move_or_copy(const array& in, array& out) {
|
||||
if (in.is_donatable()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void move_or_copy(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& strides,
|
||||
array::Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
if (in.is_donatable()) {
|
||||
out.move_shared_buffer(in, strides, flags, data_size, offset);
|
||||
} else {
|
||||
out.copy_shared_buffer(in, strides, flags, data_size, offset);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
const Shape& shape,
|
||||
const std::vector<Strides>& strides,
|
||||
|
||||
@@ -159,15 +159,6 @@ inline bool is_donatable(const array& in, const array& out) {
|
||||
in.buffer_size() <= out.nbytes() + donation_extra;
|
||||
}
|
||||
|
||||
void move_or_copy(const array& in, array& out);
|
||||
void move_or_copy(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& strides,
|
||||
array::Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
|
||||
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
|
||||
|
||||
void shared_buffer_reshape(
|
||||
|
||||
@@ -44,7 +44,9 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
@@ -56,6 +58,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
@@ -65,13 +68,14 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
|
||||
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
|
||||
endif()
|
||||
|
||||
if(IOS)
|
||||
|
||||
@@ -2,76 +2,27 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
void arange(T start, T next, array& out, size_t size) {
|
||||
void arange(T start, T next, array& out, size_t size, Stream stream) {
|
||||
auto ptr = out.data<T>();
|
||||
auto step_size = next - start;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
ptr[i] = start;
|
||||
start += step_size;
|
||||
}
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([ptr, start, step_size, size]() mutable {
|
||||
for (int i = 0; i < size; ++i) {
|
||||
ptr[i] = start;
|
||||
start += step_size;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void arange(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
double start,
|
||||
double step) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
throw std::runtime_error("Bool type unsupported for arange.");
|
||||
break;
|
||||
case uint8:
|
||||
arange<uint8_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case uint16:
|
||||
arange<uint16_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case uint32:
|
||||
arange<uint32_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case uint64:
|
||||
arange<uint64_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case int8:
|
||||
arange<int8_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case int16:
|
||||
arange<int16_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case int32:
|
||||
arange<int32_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case int64:
|
||||
arange<int64_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case float16:
|
||||
arange<float16_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case float32:
|
||||
arange<float>(start, start + step, out, out.size());
|
||||
break;
|
||||
case float64:
|
||||
arange<double>(start, start + step, out, out.size());
|
||||
break;
|
||||
case bfloat16:
|
||||
arange<bfloat16_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case complex64:
|
||||
arange<complex64_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -17,15 +18,18 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||
Shape shape = in.shape();
|
||||
strides.erase(strides.begin() + axis);
|
||||
shape.erase(shape.begin() + axis);
|
||||
auto in_ptr = in.data<InT>();
|
||||
auto out_ptr = out.data<uint32_t>();
|
||||
|
||||
for (uint32_t i = 0; i < out.size(); ++i) {
|
||||
auto loc = elem_to_loc(i, shape, strides);
|
||||
auto in_ptr = in.data<InT>() + loc;
|
||||
auto local_in_ptr = in_ptr + loc;
|
||||
uint32_t ind_v = 0;
|
||||
InT v = (*in_ptr);
|
||||
for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
|
||||
op(j, (*in_ptr), &ind_v, &v);
|
||||
InT v = (*local_in_ptr);
|
||||
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
|
||||
op(j, (*local_in_ptr), &ind_v, &v);
|
||||
}
|
||||
out.data<uint32_t>()[i] = ind_v;
|
||||
out_ptr[i] = ind_v;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,52 +68,59 @@ void arg_reduce_dispatch(
|
||||
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint8:
|
||||
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint16:
|
||||
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint32:
|
||||
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint64:
|
||||
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int8:
|
||||
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int16:
|
||||
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int32:
|
||||
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int64:
|
||||
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float16:
|
||||
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float32:
|
||||
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case bfloat16:
|
||||
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float64:
|
||||
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case complex64:
|
||||
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
reduce_type_ = reduce_type_,
|
||||
axis_ = axis_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint8:
|
||||
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint16:
|
||||
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint32:
|
||||
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint64:
|
||||
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int8:
|
||||
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int16:
|
||||
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int32:
|
||||
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int64:
|
||||
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float16:
|
||||
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float32:
|
||||
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case bfloat16:
|
||||
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float64:
|
||||
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case complex64:
|
||||
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
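The `ArgReduce::eval_cpu` hunk above is the shape most CPU primitives take after this refactor: allocate the output, register inputs and outputs with the stream's command encoder, then dispatch a lambda that captures weak copies of the arrays and does the type dispatch off the main thread. A condensed sketch of that pattern (`SomePrimitive` is a placeholder, not a real op):

    void SomePrimitive::eval_cpu(const std::vector<array>& inputs, array& out) {
      auto& in = inputs[0];
      out.set_data(allocator::malloc(out.nbytes()));

      auto& encoder = cpu::get_command_encoder(stream());
      encoder.set_input_array(in);
      encoder.set_output_array(out);
      encoder.dispatch([in = array::unsafe_weak_copy(in),
                        out = array::unsafe_weak_copy(out)]() mutable {
        // type-dispatched kernel body runs on the stream's worker thread
      });
    }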
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "mlx/backend/cpu/binary.h"
|
||||
#include "mlx/backend/cpu/binary_ops.h"
|
||||
#include "mlx/backend/cpu/binary_two.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -16,51 +17,218 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op(const array& a, const array& b, array& out, Op op) {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool>(a, b, out, op);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool>(a, b, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
}
|
||||
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_float(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[binary_float] Only supports non-complex floating point types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_int(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[binary_int] Type not supported");
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -69,7 +237,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Add());
|
||||
binary(a, b, out, detail::Add(), stream());
|
||||
}
|
||||
|
||||
void DivMod::eval_cpu(
|
||||
@@ -78,70 +246,89 @@ void DivMod::eval_cpu(
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto integral_op = [](auto x, auto y) {
|
||||
return std::make_pair(x / y, x % y);
|
||||
};
|
||||
auto float_op = [](auto x, auto y) {
|
||||
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
|
||||
};
|
||||
switch (outputs[0].dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, outputs, integral_op);
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, outputs, float_op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, outputs, float_op);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double>(a, b, outputs, float_op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, outputs, float_op);
|
||||
break;
|
||||
case complex64:
|
||||
// Should never get here
|
||||
throw std::runtime_error("[DivMod] Complex type not supported");
|
||||
break;
|
||||
}
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
auto& out_a = outputs[0];
|
||||
auto& out_b = outputs[1];
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
encoder.set_output_array(out_b);
|
||||
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out_a = array::unsafe_weak_copy(out_a),
|
||||
out_b = array::unsafe_weak_copy(out_b),
|
||||
bopt]() mutable {
|
||||
auto integral_op = [](auto x, auto y) {
|
||||
return std::make_pair(x / y, x % y);
|
||||
};
|
||||
auto float_op = [](auto x, auto y) {
|
||||
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
|
||||
};
|
||||
|
||||
switch (out_a.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, out_a, out_b, float_op, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double>(a, b, out_a, out_b, float_op, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
// Should never get here
|
||||
throw std::runtime_error("[DivMod] Complex type not supported");
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Divide());
|
||||
binary(a, b, out, detail::Divide(), stream());
|
||||
}
|
||||
|
||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Remainder());
|
||||
binary(a, b, out, detail::Remainder(), stream());
|
||||
}
|
||||
|
||||
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -149,181 +336,143 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (equal_nan_) {
|
||||
switch (a.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool>(a, b, out, detail::NaNEqual());
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool>(a, b, out, detail::NaNEqual());
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[NanEqual::eval_cpu] Only for floating point types.");
|
||||
}
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (a.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[NanEqual::eval_cpu] Only for floating point types.");
|
||||
}
|
||||
});
|
||||
} else {
|
||||
comparison_op(a, b, out, detail::Equal());
|
||||
comparison_op(a, b, out, detail::Equal(), stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Greater());
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
|
||||
}
|
||||
|
||||
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
|
||||
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
|
||||
}
|
||||
|
||||
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Less());
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
|
||||
}
|
||||
|
||||
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
|
||||
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
|
||||
}
|
||||
|
||||
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, out, detail::LogAddExp());
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double>(a, b, out, detail::LogAddExp());
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
|
||||
}
|
||||
binary_float(a, b, out, detail::LogAddExp(), stream());
|
||||
}
|
||||
|
||||
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2); // LogicalAnd requires two input arrays
  auto& in1 = inputs[0];
  auto& in2 = inputs[1];
  binary(in1, in2, out, detail::LogicalAnd());
  binary(in1, in2, out, detail::LogicalAnd(), stream());
}

void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2); // LogicalOr requires two input arrays
  auto& in1 = inputs[0];
  auto& in2 = inputs[1];
  binary(in1, in2, out, detail::LogicalOr());
  binary(in1, in2, out, detail::LogicalOr(), stream());
}

void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Maximum());
  binary(a, b, out, detail::Maximum(), stream());
}

void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Minimum());
  binary(a, b, out, detail::Minimum(), stream());
}

void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Multiply());
  binary(a, b, out, detail::Multiply(), stream());
}

void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
  comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
}

void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Power());
  binary(a, b, out, detail::Power(), stream());
}

void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Subtract());
  binary(a, b, out, detail::Subtract(), stream());
}

void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto dispatch_type = [&a, &b, &out](auto op) {
    switch (out.dtype()) {
      case bool_:
        binary_op<bool>(a, b, out, op);
      case uint8:
        binary_op<uint8_t>(a, b, out, op);
        break;
      case uint16:
        binary_op<uint16_t>(a, b, out, op);
        break;
      case uint32:
        binary_op<uint32_t>(a, b, out, op);
        break;
      case uint64:
        binary_op<uint64_t>(a, b, out, op);
        break;
      case int8:
        binary_op<int8_t>(a, b, out, op);
        break;
      case int16:
        binary_op<int16_t>(a, b, out, op);
        break;
      case int32:
        binary_op<int32_t>(a, b, out, op);
        break;
      case int64:
        binary_op<int64_t>(a, b, out, op);
        break;
      default:
        throw std::runtime_error(
            "[BitwiseBinary::eval_cpu] Type not supported");
        break;
    }
  };
  switch (op_) {
    case BitwiseBinary::And:
      dispatch_type(detail::BitwiseAnd());
      binary_int(a, b, out, detail::BitwiseAnd(), stream());
      break;
    case BitwiseBinary::Or:
      dispatch_type(detail::BitwiseOr());
      binary_int(a, b, out, detail::BitwiseOr(), stream());
      break;
    case BitwiseBinary::Xor:
      dispatch_type(detail::BitwiseXor());
      binary_int(a, b, out, detail::BitwiseXor(), stream());
      break;
    case BitwiseBinary::LeftShift:
      dispatch_type(detail::LeftShift());
      binary_int(a, b, out, detail::LeftShift(), stream());
      break;
    case BitwiseBinary::RightShift:
      dispatch_type(detail::RightShift());
      binary_int(a, b, out, detail::RightShift(), stream());
      break;
  }
}

@@ -332,23 +481,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  switch (out.dtype()) {
    case float16:
      binary_op<float16_t>(a, b, out, detail::ArcTan2());
      break;
    case float32:
      binary_op<float>(a, b, out, detail::ArcTan2());
      break;
    case float64:
      binary_op<double>(a, b, out, detail::ArcTan2());
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
      break;
    default:
      throw std::runtime_error(
          "[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
  }
  binary_float(a, b, out, detail::ArcTan2(), stream());
}

} // namespace mlx::core
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
#pragma once
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
@@ -14,22 +13,18 @@ namespace mlx::core {
|
||||
|
||||
template <typename Op>
|
||||
struct VectorScalar {
|
||||
Op op;
|
||||
|
||||
VectorScalar(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
T scalar = *b;
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
|
||||
simd::store(dst, Op{}(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
|
||||
dst += N;
|
||||
a += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(*a, scalar);
|
||||
*dst = Op{}(*a, scalar);
|
||||
dst++;
|
||||
a++;
|
||||
}
|
||||
@@ -38,22 +33,18 @@ struct VectorScalar {
|
||||
|
||||
template <typename Op>
|
||||
struct ScalarVector {
|
||||
Op op;
|
||||
|
||||
ScalarVector(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
T scalar = *a;
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
|
||||
simd::store(dst, Op{}(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
|
||||
dst += N;
|
||||
b += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(scalar, *b);
|
||||
*dst = Op{}(scalar, *b);
|
||||
dst++;
|
||||
b++;
|
||||
}
|
||||
@@ -62,22 +53,18 @@ struct ScalarVector {
|
||||
|
||||
template <typename Op>
|
||||
struct VectorVector {
|
||||
Op op;
|
||||
|
||||
VectorVector(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
|
||||
simd::store(dst, Op{}(simd::load<T, N>(a), simd::load<T, N>(b)));
|
||||
dst += N;
|
||||
a += N;
|
||||
b += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(*a, *b);
|
||||
*dst = Op{}(*a, *b);
|
||||
dst++;
|
||||
a++;
|
||||
b++;
|
||||
@@ -90,7 +77,6 @@ void binary_op_dims(
|
||||
const T* a,
|
||||
const T* b,
|
||||
U* out,
|
||||
Op op,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
@@ -104,12 +90,12 @@ void binary_op_dims(
|
||||
for (int i = 0; i < N; i++) {
|
||||
if constexpr (D > 1) {
|
||||
binary_op_dims<T, U, Op, D - 1, Strided>(
|
||||
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
|
||||
a, b, out, shape, a_strides, b_strides, out_strides, axis + 1);
|
||||
} else {
|
||||
if constexpr (Strided) {
|
||||
op(a, b, out, stride_out);
|
||||
Op{}(a, b, out, stride_out);
|
||||
} else {
|
||||
*out = op(*a, *b);
|
||||
*out = Op{}(*a, *b);
|
||||
}
|
||||
}
|
||||
out += stride_out;
|
||||
@@ -120,66 +106,38 @@ void binary_op_dims(
|
||||
|
||||
template <typename T, typename U, bool Strided, typename Op>
|
||||
void binary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
const T* a,
|
||||
const T* b,
|
||||
U* out,
|
||||
int dim,
|
||||
int size,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& out_strides) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
switch (dim) {
|
||||
case 1:
|
||||
binary_op_dims<T, U, Op, 1, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
a, b, out, shape, a_strides, b_strides, out_strides, 0);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims<T, U, Op, 2, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
a, b, out, shape, a_strides, b_strides, out_strides, 0);
|
||||
return;
|
||||
case 3:
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
a, b, out, shape, a_strides, b_strides, out_strides, 0);
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(shape, a_strides, dim - 3);
|
||||
ContiguousIterator b_it(shape, b_strides, dim - 3);
|
||||
auto stride = out_strides[dim - 4];
|
||||
for (int64_t elem = 0; elem < a.size(); elem += stride) {
|
||||
for (int64_t elem = 0; elem < size; elem += stride) {
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
out_ptr + elem,
|
||||
op,
|
||||
a + a_it.loc,
|
||||
b + b_it.loc,
|
||||
out + elem,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
@@ -191,40 +149,41 @@ void binary_op_dispatch_dims(
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
|
||||
// The full computation is scalar scalar so call the base op once
|
||||
auto a_ptr = a.data<T>();
|
||||
auto b_ptr = b.data<T>();
|
||||
|
||||
auto out_ptr = out.data<U>();
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
|
||||
*out_ptr = Op{}(*a_ptr, *b_ptr);
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is scalar vector so delegate to the op
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
|
||||
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector scalar so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorScalar) {
|
||||
VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
|
||||
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector vector so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorVector) {
|
||||
VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
|
||||
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
|
||||
return;
|
||||
}
|
||||
|
||||
// General computation so let's try to optimize
|
||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), out.strides()});
|
||||
const auto& a_strides = new_strides[0];
|
||||
const auto& b_strides = new_strides[1];
|
||||
const auto& strides = new_strides[2];
|
||||
auto& a_strides = new_strides[0];
|
||||
auto& b_strides = new_strides[1];
|
||||
auto& strides = new_strides[2];
|
||||
|
||||
// Get the left-most dim such that the array is row contiguous after
|
||||
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
|
||||
@@ -248,7 +207,8 @@ void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
|
||||
auto ndim = new_shape.size();
|
||||
|
||||
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
|
||||
// Case 1: LxM and FxM where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
int dim = ndim;
|
||||
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
|
||||
bopt = BinaryOpType::VectorVector;
|
||||
@@ -275,99 +235,59 @@ void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
|
||||
switch (bopt) {
|
||||
case BinaryOpType::VectorVector:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
VectorVector{op},
|
||||
binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
a.size(),
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
VectorScalar{op},
|
||||
binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
a.size(),
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
ScalarVector{op},
|
||||
binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
a.size(),
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
default:
|
||||
binary_op_dispatch_dims<T, U, false>(
|
||||
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
|
||||
binary_op_dispatch_dims<T, U, false, Op>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
a.size(),
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
binary_op<T, T>(a, b, out, op);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary(const array& a, const array& b, array& out, Op op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, out, op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, out, op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, out, op);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double>(a, b, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t>(a, b, out, op);
|
||||
break;
|
||||
}
|
||||
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
|
||||
binary_op<T, T, Op>(a, b, out, bopt);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
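
For context, the BinaryOpType passed into the refactored binary_op is produced by get_binary_op_type from mlx/backend/common/binary.h, which is not shown in this diff. A rough sketch of the classification it is assumed to perform (scalar and contiguity checks on the two inputs) follows; the real predicate also has to agree on row versus column contiguity, so treat this as an approximation only.

// Approximate sketch of the BinaryOpType classification; not the actual
// get_binary_op_type from mlx/backend/common/binary.h.
inline BinaryOpType classify_binary_op(const array& a, const array& b) {
  if (a.data_size() == 1 && b.data_size() == 1) {
    return BinaryOpType::ScalarScalar; // one element each
  } else if (a.data_size() == 1 && b.flags().contiguous) {
    return BinaryOpType::ScalarVector; // broadcast a over contiguous b
  } else if (b.data_size() == 1 && a.flags().contiguous) {
    return BinaryOpType::VectorScalar; // broadcast b over contiguous a
  } else if (a.flags().contiguous && b.flags().contiguous &&
             a.size() == b.size()) {
    return BinaryOpType::VectorVector; // elementwise over matching layouts
  }
  return BinaryOpType::General; // fall back to strided iteration
}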
@@ -58,14 +58,14 @@ void binary_op_dispatch_dims(
|
||||
Op op) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), out_a.strides()});
|
||||
const auto& a_strides = strides[0];
|
||||
const auto& b_strides = strides[1];
|
||||
const auto& out_strides = strides[2];
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* out_a_ptr = out_a.data<U>();
|
||||
U* out_b_ptr = out_b.data<U>();
|
||||
|
||||
const auto& a_strides = strides[0];
|
||||
const auto& b_strides = strides[1];
|
||||
const auto& out_strides = strides[2];
|
||||
int ndim = shape.size();
|
||||
switch (ndim) {
|
||||
case 1:
|
||||
@@ -120,14 +120,10 @@ template <typename T, typename U = T, typename Op>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
std::vector<array>& outputs,
|
||||
Op op) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
auto& out_a = outputs[0];
|
||||
auto& out_b = outputs[1];
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op,
|
||||
BinaryOpType bopt) {
|
||||
// The full computation is scalar scalar so call the base op once
|
||||
if (bopt == BinaryOpType::General) {
|
||||
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
|
||||
@@ -141,14 +137,14 @@ void binary_op(
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
} else if (bopt == BinaryOpType::ScalarVector) {
|
||||
for (size_t i = 0; i < b.size(); ++i) {
|
||||
for (size_t i = 0; i < b.data_size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
b_ptr++;
|
||||
}
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
for (size_t i = 0; i < a.size(); ++i) {
|
||||
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
@@ -165,58 +161,6 @@ void binary_op(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary(
|
||||
const array& a,
|
||||
const array& b,
|
||||
std::vector<array>& outputs,
|
||||
Op op) {
|
||||
switch (outputs[0].dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, outputs, op);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, outputs, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, outputs, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, outputs, op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, outputs, op);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double>(a, b, outputs, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t>(a, b, outputs, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -9,7 +10,7 @@
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
void cholesky_impl(const array& a, array& factor, bool upper) {
|
||||
void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
|
||||
// Lapack uses the column-major convention. We take advantage of the fact that
|
||||
// the matrix should be symmetric:
|
||||
// (A)ᵀ = A
|
||||
@@ -17,60 +18,63 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
|
||||
// triangular matrix, so uplo is the opposite of what we would expect from
|
||||
// upper
|
||||
|
||||
char uplo = (upper) ? 'L' : 'U';
|
||||
|
||||
// The decomposition is computed in place, so just copy the input to the
|
||||
// output.
|
||||
copy(
|
||||
a,
|
||||
factor,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
stream);
|
||||
|
||||
const int N = a.shape(-1);
|
||||
const size_t num_matrices = a.size() / (N * N);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(factor);
|
||||
encoder.dispatch([matrix = factor.data<T>(),
|
||||
upper,
|
||||
N = a.shape(-1),
|
||||
size = a.size()]() mutable {
|
||||
char uplo = (upper) ? 'L' : 'U';
|
||||
size_t num_matrices = size / (N * N);
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
// Compute Cholesky factorization.
|
||||
int info;
|
||||
potrf<T>(
|
||||
/* uplo = */ &uplo,
|
||||
/* n = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info);
|
||||
|
||||
T* matrix = factor.data<T>();
|
||||
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
// Compute Cholesky factorization.
|
||||
int info;
|
||||
potrf<T>(
|
||||
/* uplo = */ &uplo,
|
||||
/* n = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info);
|
||||
|
||||
// TODO: We do nothing when the matrix is not positive semi-definite
|
||||
// because throwing an error would result in a crash. If we figure out how
|
||||
// to catch errors from the implementation we should throw.
|
||||
if (info < 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[cholesky] Cholesky decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Zero out the upper/lower triangle while advancing the pointer to the
|
||||
// next matrix at the same time.
|
||||
for (int row = 0; row < N; row++) {
|
||||
if (upper) {
|
||||
std::fill(matrix, matrix + row, 0);
|
||||
} else {
|
||||
std::fill(matrix + row + 1, matrix + N, 0);
|
||||
// TODO: We do nothing when the matrix is not positive semi-definite
|
||||
// because throwing an error would result in a crash. If we figure out how
|
||||
// to catch errors from the implementation we should throw.
|
||||
if (info < 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Cholesky::eval_cpu] Cholesky decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Zero out the upper/lower triangle while advancing the pointer to the
|
||||
// next matrix at the same time.
|
||||
for (int row = 0; row < N; row++) {
|
||||
if (upper) {
|
||||
std::fill(matrix, matrix + row, 0);
|
||||
} else {
|
||||
std::fill(matrix + row + 1, matrix + N, 0);
|
||||
}
|
||||
matrix += N;
|
||||
}
|
||||
matrix += N;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
|
||||
switch (inputs[0].dtype()) {
|
||||
case float32:
|
||||
cholesky_impl<float>(inputs[0], output, upper_);
|
||||
cholesky_impl<float>(inputs[0], output, upper_, stream());
|
||||
break;
|
||||
case float64:
|
||||
cholesky_impl<double>(inputs[0], output, upper_);
|
||||
cholesky_impl<double>(inputs[0], output, upper_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/cpu/compiled_preamble.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/jit_compiler.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
@@ -288,6 +289,7 @@ void Compiled::eval_cpu(
|
||||
// Figure out which kernel we are using
|
||||
auto& shape = outputs[0].shape();
|
||||
auto contiguous = compiled_check_contiguity(inputs, shape);
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
|
||||
// Handle all broadcasting and collect function input arguments
|
||||
std::vector<void*> args;
|
||||
@@ -298,6 +300,7 @@ void Compiled::eval_cpu(
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
encoder.set_input_array(x);
|
||||
args.push_back((void*)x.data<void>());
|
||||
|
||||
if (contiguous || is_scalar(x)) {
|
||||
@@ -356,18 +359,25 @@ void Compiled::eval_cpu(
|
||||
});
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, false);
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous);
|
||||
|
||||
for (auto& x : outputs) {
|
||||
args.push_back(x.data<void>());
|
||||
encoder.set_output_array(x);
|
||||
}
|
||||
Shape out_shape;
|
||||
if (!contiguous) {
|
||||
args.push_back((void*)outputs[0].shape().data());
|
||||
out_shape = outputs[0].shape();
|
||||
args.push_back((void*)out_shape.data());
|
||||
} else {
|
||||
args.push_back((void*)outputs[0].data_size());
|
||||
}
|
||||
auto fun = (void (*)(void**))fn_ptr;
|
||||
fun(args.data());
|
||||
encoder.dispatch(
|
||||
[fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
File diff suppressed because it is too large
@@ -5,6 +5,7 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -13,19 +14,19 @@ namespace {
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_single(const array& src, array& dst) {
|
||||
auto val = static_cast<DstT>(src.data<SrcT>()[0]);
|
||||
auto src_ptr = src.data<SrcT>();
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
for (int i = 0; i < dst.size(); ++i) {
|
||||
dst_ptr[i] = val;
|
||||
}
|
||||
auto size = dst.size();
|
||||
auto val = static_cast<DstT>(src_ptr[0]);
|
||||
std::fill_n(dst_ptr, size, val);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_vector(const array& src, array& dst) {
|
||||
auto src_ptr = src.data<SrcT>();
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
size_t size = src.data_size();
|
||||
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
|
||||
auto size = src.data_size();
|
||||
std::copy(src_ptr, src_ptr + size, dst_ptr);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, int D>
|
||||
@@ -60,36 +61,57 @@ void copy_general_general(
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
int64_t o_offset,
|
||||
const std::optional<array>& dynamic_i_offset,
|
||||
const std::optional<array>& dynamic_o_offset) {
|
||||
auto src_ptr = src.data<SrcT>() + i_offset;
|
||||
auto dst_ptr = dst.data<DstT>() + o_offset;
|
||||
auto i_offset_ptr =
|
||||
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
|
||||
auto o_offset_ptr =
|
||||
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
|
||||
auto size = src.size();
|
||||
if (data_shape.empty()) {
|
||||
auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
|
||||
auto dst_ptr = dst.data<DstT>() + o_offset;
|
||||
auto val = static_cast<DstT>(*src_ptr);
|
||||
*dst_ptr = val;
|
||||
return;
|
||||
}
|
||||
auto [shape, strides] =
|
||||
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
|
||||
auto src_ptr = src.data<SrcT>() + i_offset;
|
||||
auto dst_ptr = dst.data<DstT>() + o_offset;
|
||||
|
||||
int ndim = shape.size();
|
||||
if (ndim == 1) {
|
||||
copy_dims<SrcT, DstT, 1>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
} else if (ndim == 2) {
|
||||
copy_dims<SrcT, DstT, 2>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
} else if (ndim == 3) {
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
if (ndim < 3) {
|
||||
if (i_offset_ptr) {
|
||||
src_ptr += i_offset_ptr[0];
|
||||
}
|
||||
if (o_offset_ptr) {
|
||||
dst_ptr += o_offset_ptr[0];
|
||||
}
|
||||
|
||||
if (ndim == 1) {
|
||||
copy_dims<SrcT, DstT, 1>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
} else if (ndim == 2) {
|
||||
copy_dims<SrcT, DstT, 2>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
} else if (ndim == 3) {
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (i_offset_ptr) {
|
||||
src_ptr += i_offset_ptr[0];
|
||||
}
|
||||
if (o_offset_ptr) {
|
||||
dst_ptr += o_offset_ptr[0];
|
||||
}
|
||||
|
||||
ContiguousIterator in(shape, strides[0], ndim - 3);
|
||||
ContiguousIterator out(shape, strides[1], ndim - 3);
|
||||
auto stride = std::accumulate(
|
||||
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
|
||||
for (int64_t elem = 0; elem < src.size(); elem += stride) {
|
||||
for (int64_t elem = 0; elem < size; elem += stride) {
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
src_ptr + in.loc,
|
||||
dst_ptr + out.loc,
|
||||
@@ -105,7 +127,15 @@ void copy_general_general(
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_general(const array& src, array& dst) {
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
|
||||
src,
|
||||
dst,
|
||||
src.shape(),
|
||||
src.strides(),
|
||||
dst.strides(),
|
||||
0,
|
||||
0,
|
||||
std::nullopt,
|
||||
std::nullopt);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
@@ -116,7 +146,9 @@ void copy_general(
|
||||
const Strides& i_strides,
|
||||
const Strides&,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
int64_t o_offset,
|
||||
const std::optional<array>& dynamic_i_offset,
|
||||
const std::optional<array>& dynamic_o_offset) {
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src,
|
||||
dst,
|
||||
@@ -124,7 +156,9 @@ void copy_general(
|
||||
i_strides,
|
||||
make_contiguous_strides(data_shape),
|
||||
i_offset,
|
||||
o_offset);
|
||||
o_offset,
|
||||
dynamic_i_offset,
|
||||
dynamic_o_offset);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
@@ -136,7 +170,9 @@ inline void copy_general(const array& src, array& dst) {
|
||||
src.strides(),
|
||||
make_contiguous_strides(src.shape()),
|
||||
0,
|
||||
0);
|
||||
0,
|
||||
std::nullopt,
|
||||
std::nullopt);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename... Args>
|
||||
@@ -259,35 +295,27 @@ inline void copy_inplace_dispatch(
|
||||
|
||||
} // namespace
|
||||
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype) {
|
||||
copy_inplace_dispatch(src, dst, ctype);
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
encoder.dispatch(
|
||||
[src = array::unsafe_weak_copy(src),
|
||||
dst = array::unsafe_weak_copy(dst),
|
||||
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
|
||||
}
|
||||
|
||||
void copy(const array& src, array& dst, CopyType ctype) {
|
||||
// Allocate the output
|
||||
switch (ctype) {
|
||||
case CopyType::Vector:
|
||||
if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
|
||||
dst.copy_shared_buffer(src);
|
||||
} else {
|
||||
auto size = src.data_size();
|
||||
dst.set_data(
|
||||
allocator::malloc_or_wait(size * dst.itemsize()),
|
||||
size,
|
||||
src.strides(),
|
||||
src.flags());
|
||||
}
|
||||
break;
|
||||
case CopyType::Scalar:
|
||||
case CopyType::General:
|
||||
case CopyType::GeneralGeneral:
|
||||
dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
|
||||
break;
|
||||
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
bool donated = set_copy_output_data(src, dst, ctype);
|
||||
if (donated && src.dtype() == dst.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_inplace(src, dst, ctype);
|
||||
copy_inplace(src, dst, ctype, stream);
|
||||
}
|
||||
|
||||
void copy_inplace(
|
||||
@@ -298,24 +326,51 @@ void copy_inplace(
|
||||
const Strides& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype) {
|
||||
switch (ctype) {
|
||||
case CopyType::General:
|
||||
case CopyType::GeneralGeneral:
|
||||
copy_inplace_dispatch(
|
||||
src,
|
||||
dst,
|
||||
ctype,
|
||||
data_shape,
|
||||
i_strides,
|
||||
o_strides,
|
||||
i_offset,
|
||||
o_offset);
|
||||
break;
|
||||
case CopyType::Scalar:
|
||||
case CopyType::Vector:
|
||||
copy_inplace_dispatch(src, dst, ctype);
|
||||
}
|
||||
CopyType ctype,
|
||||
Stream stream,
|
||||
const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
|
||||
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
auto weak_copy_if_set = [](auto x) -> std::optional<array> {
|
||||
if (x) {
|
||||
return array::unsafe_weak_copy(*x);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
};
|
||||
encoder.dispatch(
|
||||
[src = array::unsafe_weak_copy(src),
|
||||
dst = array::unsafe_weak_copy(dst),
|
||||
data_shape,
|
||||
i_strides,
|
||||
o_strides,
|
||||
i_offset,
|
||||
o_offset,
|
||||
ctype,
|
||||
dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
|
||||
dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
|
||||
switch (ctype) {
|
||||
case CopyType::General:
|
||||
case CopyType::GeneralGeneral:
|
||||
copy_inplace_dispatch(
|
||||
src,
|
||||
dst,
|
||||
ctype,
|
||||
data_shape,
|
||||
i_strides,
|
||||
o_strides,
|
||||
i_offset,
|
||||
o_offset,
|
||||
dynamic_i_offset,
|
||||
dynamic_o_offset);
|
||||
break;
|
||||
case CopyType::Scalar:
|
||||
case CopyType::Vector:
|
||||
copy_inplace_dispatch(src, dst, ctype);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -2,14 +2,16 @@

#pragma once

#include <optional>

#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
void copy(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);

void copy_inplace(
    const array& src,
@@ -19,6 +21,9 @@ void copy_inplace(
    const Strides& o_strides,
    int64_t i_offset,
    int64_t o_offset,
    CopyType ctype);
    CopyType ctype,
    Stream stream,
    const std::optional<array>& dynamic_i_offset = std::nullopt,
    const std::optional<array>& dynamic_o_offset = std::nullopt);

} // namespace mlx::core
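
Callers now pass the stream, and optionally run-time offsets, into the copy routines. The snippet below is a usage illustration only; the variable names are hypothetical, but the call shapes follow the declarations above.

// Hypothetical call sites showing the new signatures; example, in, out,
// offsets and s are placeholder names, not part of this diff.
void example(const array& in, array& out, const array& offsets, Stream s) {
  // Whole-array copy, now queued on the stream's CPU encoder.
  copy(in, out, CopyType::General, s);

  // Strided copy whose input offset is only known when the task runs: the
  // offset is read from the first element of offsets at execution time.
  copy_inplace(
      in,
      out,
      in.shape(),
      in.strides(),
      out.strides(),
      /* i_offset = */ 0,
      /* o_offset = */ 0,
      CopyType::GeneralGeneral,
      s,
      /* dynamic_i_offset = */ offsets,
      /* dynamic_o_offset = */ std::nullopt);
}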
mlx/backend/cpu/distributed.cpp (new file, 101 lines)
@@ -0,0 +1,101 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return {arr, false};
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General, stream);
|
||||
return {arr_copy, true};
|
||||
}
|
||||
};
|
||||
|
||||
void AllReduce::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto donate_or_copy = [s = stream()](const array& in, array& out) {
|
||||
if (in.flags().row_contiguous) {
|
||||
if (in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
return in;
|
||||
} else {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy(in, arr_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(arr_copy);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto in = donate_or_copy(inputs[0], outputs[0]);
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
distributed::detail::all_sum(group(), in, outputs[0], stream());
|
||||
break;
|
||||
case Max:
|
||||
distributed::detail::all_max(group(), in, outputs[0], stream());
|
||||
break;
|
||||
case Min:
|
||||
distributed::detail::all_min(group(), in, outputs[0], stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"Only all reduce sum, min and max are supported for now");
|
||||
}
|
||||
}
|
||||
|
||||
void AllGather::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
|
||||
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||
distributed::detail::all_gather(group(), in, outputs[0], stream());
|
||||
if (copied) {
|
||||
auto& enc = cpu::get_command_encoder(stream());
|
||||
enc.add_temporary(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Send::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
|
||||
distributed::detail::send(group(), in, dst_, stream());
|
||||
outputs[0].copy_shared_buffer(inputs[0]);
|
||||
if (copied) {
|
||||
auto& enc = cpu::get_command_encoder(stream());
|
||||
enc.add_temporary(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Recv::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 0);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||
distributed::detail::recv(group(), outputs[0], src_, stream());
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -16,59 +17,71 @@ void eigh_impl(
|
||||
array& vectors,
|
||||
array& values,
|
||||
const std::string& uplo,
|
||||
bool compute_eigenvectors) {
|
||||
bool compute_eigenvectors,
|
||||
Stream stream) {
|
||||
auto vec_ptr = vectors.data<T>();
|
||||
auto eig_ptr = values.data<T>();
|
||||
|
||||
char jobz = compute_eigenvectors ? 'V' : 'N';
|
||||
auto N = vectors.shape(-1);
|
||||
|
||||
// Work query
|
||||
int lwork = -1;
|
||||
int liwork = -1;
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
int iwork;
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
uplo.c_str(),
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
}
|
||||
|
||||
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
|
||||
for (size_t i = 0; i < vectors.size() / (N * N); ++i) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
uplo.c_str(),
|
||||
&N,
|
||||
vec_ptr,
|
||||
&N,
|
||||
eig_ptr,
|
||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
if (info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(vectors);
|
||||
encoder.set_output_array(values);
|
||||
encoder.dispatch([vec_ptr,
|
||||
eig_ptr,
|
||||
jobz,
|
||||
uplo = uplo[0],
|
||||
N = vectors.shape(-1),
|
||||
size = vectors.size()]() mutable {
|
||||
// Work query
|
||||
int lwork = -1;
|
||||
int liwork = -1;
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
int iwork;
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
}
|
||||
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vec_ptr,
|
||||
&N,
|
||||
eig_ptr,
|
||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
if (info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
});
|
||||
if (!compute_eigenvectors) {
|
||||
encoder.add_temporary(vectors);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,12 +97,13 @@ void Eigh::eval_cpu(
|
||||
? outputs[1]
|
||||
: array(a.shape(), a.dtype(), nullptr, {});
|
||||
|
||||
values.set_data(allocator::malloc_or_wait(values.nbytes()));
|
||||
values.set_data(allocator::malloc(values.nbytes()));
|
||||
|
||||
copy(
|
||||
a,
|
||||
vectors,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
stream());
|
||||
|
||||
if (compute_eigenvectors_) {
|
||||
// Set the strides and flags so the eigenvectors
|
||||
@@ -107,14 +121,15 @@ void Eigh::eval_cpu(
|
||||
flags.col_contiguous = true;
|
||||
}
|
||||
}
|
||||
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
|
||||
vectors.copy_shared_buffer(vectors, strides, flags, vectors.data_size());
|
||||
}
|
||||
switch (a.dtype()) {
|
||||
case float32:
|
||||
eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_);
|
||||
eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||
break;
|
||||
case float64:
|
||||
eigh_impl<double>(vectors, values, uplo_, compute_eigenvectors_);
|
||||
eigh_impl<double>(
|
||||
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
|
||||
mlx/backend/cpu/encoder.cpp (new file, 16 lines)
@@ -0,0 +1,16 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cpu/encoder.h"

namespace mlx::core::cpu {

CommandEncoder& get_command_encoder(Stream stream) {
  static std::unordered_map<int, CommandEncoder> encoder_map;
  auto it = encoder_map.find(stream.index);
  if (it == encoder_map.end()) {
    it = encoder_map.emplace(stream.index, stream).first;
  }
  return it->second;
}

} // namespace mlx::core::cpu
mlx/backend/cpu/encoder.h (new file, 67 lines)
@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <unordered_map>

#include "mlx/array.h"
#include "mlx/scheduler.h"

namespace mlx::core::cpu {

// Number of dispatches per scheduler task
constexpr int DISPATCHES_PER_TASK = 10;

struct CommandEncoder {
  CommandEncoder(Stream stream) : stream_(stream) {}

  CommandEncoder(const CommandEncoder&) = delete;
  CommandEncoder& operator=(const CommandEncoder&) = delete;
  CommandEncoder(CommandEncoder&&) = delete;
  CommandEncoder& operator=(CommandEncoder&&) = delete;

  void set_input_array(const array& a) {}
  void set_output_array(array& a) {}

  // Hold onto a temporary until any already scheduled tasks which use it as
  // an input are complete.
  void add_temporary(array arr) {
    temporaries_.push_back(std::move(arr));
  }

  void add_temporaries(std::vector<array> arrays) {
    temporaries_.insert(
        temporaries_.end(),
        std::make_move_iterator(arrays.begin()),
        std::make_move_iterator(arrays.end()));
  }

  std::vector<array>& temporaries() {
    return temporaries_;
  }

  template <class F, class... Args>
  void dispatch(F&& f, Args&&... args) {
    num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
    auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
    if (num_ops_ == 0) {
      scheduler::notify_new_task(stream_);
      auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
        task();
        scheduler::notify_task_completion(s);
      };
      scheduler::enqueue(stream_, std::move(task_wrap));
    } else {
      scheduler::enqueue(stream_, std::move(task));
    }
  }

 private:
  Stream stream_;
  std::vector<array> temporaries_;
  int num_ops_{0};
};

CommandEncoder& get_command_encoder(Stream stream);

} // namespace mlx::core::cpu
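The files above and below all follow the same pattern when talking to this encoder, so a single condensed illustration may help. It mirrors the Cholesky, Eigh and FFT changes in this diff; the function and variable names are placeholders, not code from the commit.

// Illustration of the typical eval_cpu structure after this refactor.
// example_eval_cpu, in, out and s are placeholder names.
void example_eval_cpu(const array& in, array& out, Stream s) {
  out.set_data(allocator::malloc(out.nbytes()));

  auto& encoder = cpu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  // Capture raw pointers and sizes by value; the lambda runs later on the
  // stream's worker thread, so it must not hold references to in or out.
  encoder.dispatch(
      [in_ptr = in.data<float>(), out_ptr = out.data<float>(), n = in.size()]() {
        for (size_t i = 0; i < n; ++i) {
          out_ptr[i] = in_ptr[i]; // placeholder computation
        }
      });
}

Note that set_input_array and set_output_array are empty on the CPU encoder for now; keeping the calls at every site presumably leaves room for dependency tracking later and keeps the structure parallel to the GPU command encoder.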
mlx/backend/cpu/eval.cpp (new file, 40 lines)
@@ -0,0 +1,40 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#include "mlx/backend/cpu/eval.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
void eval(array& arr) {
|
||||
auto s = arr.primitive().stream();
|
||||
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
arr.primitive().eval_cpu(arr.inputs(), outputs);
|
||||
}
|
||||
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
// Remove the output if it was donated to by an input
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
auto& encoder = cpu::get_command_encoder(s);
|
||||
encoder.dispatch([buffers = std::move(buffers),
|
||||
temps = std::move(encoder.temporaries())]() {});
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cpu
|
||||
mlx/backend/cpu/eval.h (new file, 12 lines)
@@ -0,0 +1,12 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/array.h"
#include "mlx/stream.h"

namespace mlx::core::cpu {

void eval(array& arr);

} // namespace mlx::core::cpu
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "mlx/3rdparty/pocketfft.h"
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -21,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
s *= out.itemsize();
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
std::vector<size_t> shape;
|
||||
if (out.dtype() == float32) {
|
||||
@@ -38,46 +39,78 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
});
|
||||
scale /= nelem;
|
||||
}
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
if (in.dtype() == complex64 && out.dtype() == complex64) {
|
||||
auto in_ptr =
|
||||
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
|
||||
auto out_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
|
||||
pocketfft::c2c(
|
||||
shape,
|
||||
strides_in,
|
||||
strides_out,
|
||||
axes_,
|
||||
!inverse_,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
scale);
|
||||
encoder.dispatch([shape = std::move(shape),
|
||||
strides_in = std::move(strides_in),
|
||||
strides_out = std::move(strides_out),
|
||||
axes = axes_,
|
||||
inverse = inverse_,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
scale]() {
|
||||
pocketfft::c2c(
|
||||
shape,
|
||||
strides_in,
|
||||
strides_out,
|
||||
axes,
|
||||
!inverse,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
scale);
|
||||
});
|
||||
} else if (in.dtype() == float32 && out.dtype() == complex64) {
|
||||
auto in_ptr = in.data<float>();
|
||||
auto out_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
|
||||
pocketfft::r2c(
|
||||
shape,
|
||||
strides_in,
|
||||
strides_out,
|
||||
axes_,
|
||||
!inverse_,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
scale);
|
||||
encoder.dispatch([shape = std::move(shape),
|
||||
strides_in = std::move(strides_in),
|
||||
strides_out = std::move(strides_out),
|
||||
axes = axes_,
|
||||
inverse = inverse_,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
scale]() {
|
||||
pocketfft::r2c(
|
||||
shape,
|
||||
strides_in,
|
||||
strides_out,
|
||||
axes,
|
||||
!inverse,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
scale);
|
||||
});
|
||||
} else if (in.dtype() == complex64 && out.dtype() == float32) {
|
||||
auto in_ptr =
|
||||
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
|
||||
auto out_ptr = out.data<float>();
|
||||
pocketfft::c2r(
|
||||
shape,
|
||||
strides_in,
|
||||
strides_out,
|
||||
axes_,
|
||||
!inverse_,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
scale);
|
||||
encoder.dispatch([shape = std::move(shape),
|
||||
strides_in = std::move(strides_in),
|
||||
strides_out = std::move(strides_out),
|
||||
axes = axes_,
|
||||
inverse = inverse_,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
scale]() {
|
||||
pocketfft::c2r(
|
||||
shape,
|
||||
strides_in,
|
||||
strides_out,
|
||||
axes,
|
||||
!inverse,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
scale);
|
||||
});
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[FFT] Received unexpected input and output type combination.");
|
||||
|
||||
@@ -7,14 +7,20 @@ namespace mlx::core {

template <typename T>
void matmul(
    const array& a,
    const array& b,
    array& out,
    const T* a,
    const T* b,
    T* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    float alpha,
    float beta);
    float beta,
    size_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides);

} // namespace mlx::core
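matmul now receives raw pointers plus explicit leading dimensions, batch size, shapes and strides instead of array objects. A hedged usage example for a single, non-transposed, row-major multiply follows; all names are local to the illustration and the casts assume the usual int shape and int64_t stride element types.

// Illustration only; assumes contiguous row-major float inputs.
void gemm_example(const float* a, const float* b, float* c,
                  size_t M, size_t N, size_t K) {
  Shape a_shape{static_cast<int>(M), static_cast<int>(K)};
  Shape b_shape{static_cast<int>(K), static_cast<int>(N)};
  Strides a_strides{static_cast<int64_t>(K), 1};
  Strides b_strides{static_cast<int64_t>(N), 1};
  matmul<float>(
      a, b, c,
      /* a_transposed = */ false,
      /* b_transposed = */ false,
      /* lda = */ K,
      /* ldb = */ N,
      /* ldc = */ N,
      /* alpha = */ 1.0f,
      /* beta = */ 0.0f,
      /* batch_size = */ 1,
      a_shape, a_strides, b_shape, b_strides);
}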
@@ -9,39 +9,46 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
|
||||
uint32_t size_bits = size_of(mlx_dtype) * 8;
|
||||
switch (kindof(mlx_dtype)) {
|
||||
case Dtype::Kind::b:
|
||||
return BNNSDataTypeBoolean;
|
||||
case Dtype::Kind::u:
|
||||
return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
|
||||
case Dtype::Kind::i:
|
||||
return BNNSDataType(BNNSDataTypeIntBit | size_bits);
|
||||
case Dtype::Kind::f:
|
||||
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
|
||||
case Dtype::Kind::V:
|
||||
return BNNSDataTypeBFloat16;
|
||||
case Dtype::Kind::c:
|
||||
throw std::invalid_argument("BNNS does not support complex types");
|
||||
}
|
||||
template <typename T>
|
||||
constexpr BNNSDataType to_bnns_dtype();
|
||||
|
||||
template <>
|
||||
constexpr BNNSDataType to_bnns_dtype<float>() {
|
||||
return BNNSDataType(BNNSDataTypeFloatBit | 32);
|
||||
}
|
||||
template <>
|
||||
constexpr BNNSDataType to_bnns_dtype<float16_t>() {
|
||||
return BNNSDataType(BNNSDataTypeFloatBit | 16);
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr BNNSDataType to_bnns_dtype<bfloat16_t>() {
|
||||
return BNNSDataTypeBFloat16;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void matmul_bnns(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
const T* a,
|
||||
const T* b,
|
||||
T* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
float alpha,
|
||||
float beta) {
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
auto ndim = a_shape.size();
|
||||
size_t M = a_shape[ndim - 2];
|
||||
size_t N = b_shape[ndim - 1];
|
||||
size_t K = a_shape[ndim - 1];
|
||||
|
||||
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
|
||||
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
@@ -115,14 +122,14 @@ void matmul_bnns(
|
||||
auto bnns_filter =
|
||||
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
BNNSFilterApplyTwoInput(
|
||||
bnns_filter,
|
||||
a.data<uint8_t>() +
|
||||
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
|
||||
b.data<uint8_t>() +
|
||||
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
|
||||
out.data<uint8_t>() + M * N * i * out.itemsize());
|
||||
reinterpret_cast<const uint8_t*>(
|
||||
a + elem_to_loc(M * K * i, a_shape, a_strides)),
|
||||
reinterpret_cast<const uint8_t*>(
|
||||
b + elem_to_loc(K * N * i, b_shape, b_strides)),
|
||||
reinterpret_cast<uint8_t*>(out + M * N * i));
|
||||
}
|
||||
|
||||
BNNSFilterDestroy(bnns_filter);
|
||||
@@ -131,30 +138,72 @@ void matmul_bnns(
|
||||
|
||||
template <>
|
||||
void matmul<float16_t>(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
const float16_t* a,
|
||||
const float16_t* b,
|
||||
float16_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
float alpha,
|
||||
float beta) {
|
||||
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
matmul_bnns(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
alpha,
|
||||
beta,
|
||||
batch_size,
|
||||
a_shape,
|
||||
a_strides,
|
||||
b_shape,
|
||||
b_strides);
|
||||
}
|
||||
|
||||
template <>
|
||||
void matmul<bfloat16_t>(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
const bfloat16_t* a,
|
||||
const bfloat16_t* b,
|
||||
bfloat16_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
float alpha,
|
||||
float beta) {
|
||||
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
matmul_bnns(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
alpha,
|
||||
beta,
|
||||
batch_size,
|
||||
a_shape,
|
||||
a_strides,
|
||||
b_shape,
|
||||
b_strides);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -8,20 +8,27 @@ namespace mlx::core {

template <>
void matmul<float>(
const array& a,
const array& b,
array& out,
const float* a,
const float* b,
float* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];

for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < batch_size; ++i) {
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
@@ -29,34 +36,40 @@ void matmul<float>(
M,
N,
K,
alpha, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
alpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
beta, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
beta,
out + M * N * i,
ldc);
}
}

template <>
void matmul<double>(
const array& a,
const array& b,
array& out,
const double* a,
const double* b,
double* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];

for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < batch_size; ++i) {
cblas_dgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
@@ -64,15 +77,14 @@ void matmul<double>(
M,
N,
K,
alpha, // alpha
a.data<double>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
alpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b.data<double>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
beta, // beta
out.data<double>() + M * N * i,
out.shape(-1) // ldc
);
beta,
out + M * N * i,
ldc);
}
}
@@ -1,21 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cpu/gemm.h"

namespace mlx::core {

template <>
void matmul<bfloat16_t>(
const array&,
const array&,
array&,
bool,
bool,
size_t,
size_t,
float,
float) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}

} // namespace mlx::core

@@ -1,21 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cpu/gemm.h"

namespace mlx::core {

template <>
void matmul<float16_t>(
const array&,
const array&,
array&,
bool,
bool,
size_t,
size_t,
float,
float) {
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}

} // namespace mlx::core
mlx/backend/cpu/gemms/simd_bf16.cpp (new file, 45 lines)
@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"

namespace mlx::core {

template <>
void matmul<bfloat16_t>(
const bfloat16_t* a,
const bfloat16_t* b,
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<bfloat16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}

} // namespace mlx::core
mlx/backend/cpu/gemms/simd_fp16.cpp (new file, 45 lines)
@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"

namespace mlx::core {

template <>
void matmul<float16_t>(
const float16_t* a,
const float16_t* b,
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<float16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}

} // namespace mlx::core
mlx/backend/cpu/gemms/simd_gemm.h (new file, 139 lines)
@@ -0,0 +1,139 @@
// Copyright © 2025 Apple Inc.
#pragma once

#include "mlx/backend/cpu/simd/simd.h"

namespace mlx::core {

inline int ceildiv(int a, int b) {
return (a + b - 1) / b;
}

template <int block_size, typename T, typename AccT>
void load_block(
const T* in,
AccT* out,
int M,
int N,
int i,
int j,
bool transpose) {
if (transpose) {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[jj * block_size + ii] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
} else {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[ii * block_size + jj] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
}
}

template <typename T, typename AccT>
void simd_gemm(
const T* a,
const T* b,
T* c,
bool a_trans,
bool b_trans,
int M,
int N,
int K,
float alpha,
float beta) {
constexpr int block_size = 16;
constexpr int simd_size = simd::max_size<AccT>;
static_assert(
(block_size % simd_size) == 0,
"Block size must be divisible by SIMD size");

int last_k_block_size = K - block_size * (K / block_size);
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
for (int i = 0; i < ceildiv(M, block_size); i++) {
for (int j = 0; j < ceildiv(N, block_size); j++) {
AccT c_block[block_size * block_size] = {0.0};
AccT a_block[block_size * block_size];
AccT b_block[block_size * block_size];

int k = 0;
for (; k < K / block_size; k++) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}

// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
for (int kk = 0; kk < block_size; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
}
}
}
if (last_k_block_size) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}

// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
int kk = 0;
for (; kk < last_k_simd_block; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
for (; kk < last_k_block_size; ++kk) {
c_block[ii * block_size + jj] +=
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
}
}
}
}

// Store
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
if (beta != 0) {
c[c_idx] = static_cast<T>(
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
} else {
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
}
}
}
}
}
}

} // namespace mlx::core
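For reference only, a sketch of how the `simd_gemm` helper defined above can be driven directly. The inputs here are invented test values, and instantiating it at `<float, float>` (rather than the half-precision types the new `matmul` specializations use) is just to keep the example short.

```cpp
#include "mlx/backend/cpu/gemms/simd_gemm.h"

void simd_gemm_example() {
  // c = 1.0 * a (2x3) @ b (3x2) + 0.0 * c, all row-major.
  float a[6] = {1, 2, 3, 4, 5, 6};
  float b[6] = {1, 0, 0, 1, 1, 1};
  float c[4] = {0, 0, 0, 0};
  mlx::core::simd_gemm<float, float>(
      a, b, c,
      /* a_trans = */ false, /* b_trans = */ false,
      /* M = */ 2, /* N = */ 2, /* K = */ 3,
      /* alpha = */ 1.0f, /* beta = */ 0.0f);
}
```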
@@ -4,16 +4,17 @@

#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"

namespace mlx::core {

// n = 2^k component
template <typename T>
void hadamard_n(array& out, int n, int m, float scale) {
for (int b = 0; b < out.size() / n; b++) {
void hadamard_n(T* out, int n, int m, float scale, size_t size) {
for (int b = 0; b < size / n; b++) {
size_t loc = b * n;
T* data_ptr = out.data<T>() + loc;
T* data_ptr = out + loc;
int h = 1;
int n_over_2 = n / 2;
while (h < n) {
@@ -36,7 +37,7 @@ void hadamard_n(array& out, int n, int m, float scale) {

// m component
template <typename T>
void hadamard_m(array& out, int n, int m, float scale) {
void hadamard_m(T* out, int n, int m, float scale, size_t size) {
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
auto start = 1;
@@ -51,9 +52,9 @@ void hadamard_m(array& out, int n, int m, float scale) {
end = matrix.find('\n', start);
}

for (int b = 0; b < out.size() / m / n; b++) {
for (int b = 0; b < size / m / n; b++) {
size_t loc = b * n * m;
T* data_ptr = out.data<T>() + loc;
T* data_ptr = out + loc;
for (int i = 0; i < n; i++) {
std::vector<float> out(m);
for (int j = 0; j < m; j++) {
@@ -74,12 +75,17 @@ void hadamard_m(array& out, int n, int m, float scale) {
}

template <typename T>
void hadamard(array& out, int n, int m, float scale) {
float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out, n, m, n_scale);
if (m > 1) {
hadamard_m<T>(out, n, m, scale);
}
void hadamard(array& out, int n, int m, float scale, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
auto out_ptr = out.data<T>();
encoder.dispatch([out_ptr, size = out.size(), n, m, scale]() {
float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out_ptr, n, m, n_scale, size);
if (m > 1) {
hadamard_m<T>(out_ptr, n, m, scale, size);
}
});
}

void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -87,18 +93,26 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];

// Copy input to output
copy(in, out, CopyType::General);
if (in.flags().row_contiguous && in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
copy(
in,
out,
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
}

int axis = out.ndim() - 1;
auto [n, m] = decompose_hadamard(out.shape(axis));

switch (in.dtype()) {
case float32:
return hadamard<float>(out, n, m, scale_);
return hadamard<float>(out, n, m, scale_, stream());
case float16:
return hadamard<float16_t>(out, n, m, scale_);
return hadamard<float16_t>(out, n, m, scale_, stream());
case bfloat16:
return hadamard<bfloat16_t>(out, n, m, scale_);
return hadamard<bfloat16_t>(out, n, m, scale_, stream());
default:
throw std::invalid_argument("[hadamard] Unsupported type.");
}
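The body of the `while (h < n)` loop in `hadamard_n` is elided by the hunks above; it is the standard radix-2 Walsh–Hadamard butterfly. The stand-alone sketch below illustrates that recurrence on a plain vector (the placement of the scale is an assumption of the sketch, not taken from the kernel).

```cpp
#include <vector>

// Illustrative in-place Walsh–Hadamard transform over a length-n (n = 2^k)
// vector; hadamard_n applies the same butterfly per length-n batch slice.
void fwht_sketch(std::vector<float>& x, float scale) {
  int n = static_cast<int>(x.size());
  for (int h = 1; h < n; h *= 2) {
    for (int i = 0; i < n; i += 2 * h) {
      for (int j = i; j < i + h; ++j) {
        float a = x[j];
        float b = x[j + h];
        x[j] = a + b;
        x[j + h] = a - b;
      }
    }
  }
  for (auto& v : x) {
    v *= scale;
  }
}
```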
@@ -8,6 +8,7 @@

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"

namespace mlx::core {

@@ -21,6 +22,40 @@ inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
}

struct None {
template <typename T>
void operator()(T x, T* y) {
(*y) = x;
}
};
struct Sum {
template <typename T>
void operator()(T x, T* y) {
(*y) += x;
}
};

struct Prod {
template <typename T>
void operator()(T x, T* y) {
(*y) *= x;
}
};

struct Max {
template <typename T>
void operator()(T x, T* y) {
(*y) = (*y > x) ? *y : x;
}
};

struct Min {
template <typename T>
void operator()(T x, T* y) {
(*y) = (*y < x) ? *y : x;
}
};

template <typename T, typename IdxT>
void gather(
const array& src,
@@ -73,13 +108,14 @@ void gather(
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
size_t out_idx = 0;

std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator src_it;
if (!can_copy && src.ndim() > 0) {
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
}

size_t out_idx = 0;
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) {
@@ -161,46 +197,59 @@ void dispatch_gather(
}

void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));

auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end());

if (inds.empty()) {
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
return;
std::vector<array> inds;
for (auto it = inputs.begin() + 1; it < inputs.end(); ++it) {
inds.push_back(array::unsafe_weak_copy(*it));
}

switch (inds[0].dtype()) {
case uint8:
auto& encoder = cpu::get_command_encoder(stream());
for (auto& in : inputs) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
encoder.dispatch([axes_ = axes_,
slice_sizes_ = slice_sizes_,
src = array::unsafe_weak_copy(src),
inds = std::move(inds),
out = array::unsafe_weak_copy(out)]() mutable {
if (inds.empty()) {
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint16:
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint32:
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint64:
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
break;
case int8:
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
break;
case int16:
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
break;
case int32:
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
break;
case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break;
default:
throw std::runtime_error(
"[Gather::eval_cpu] Cannot gather with indices type.");
break;
}
return;
}

switch (inds[0].dtype()) {
case uint8:
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint16:
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint32:
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint64:
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
break;
case int8:
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
break;
case int16:
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
break;
case int32:
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
break;
case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break;
default:
throw std::runtime_error(
"[Gather::eval_cpu] Cannot gather with indices type.");
break;
}
});
}
template <typename T, typename IdxT>
void gather_axis(
@@ -235,6 +284,7 @@ void gather_axis(
for (int i = axis + 1; i < ind.ndim(); ++i) {
size_post *= ind.shape(i);
}

size_t stride_pre = size_post * ind_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
@@ -304,39 +354,49 @@ void dispatch_gather_axis(
}

void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));

auto& src = inputs[0];
auto& inds = inputs[1];
switch (inds.dtype()) {
case uint8:
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
break;
case uint16:
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
break;
case uint32:
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
break;
case uint64:
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
break;
case int8:
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
break;
case int16:
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
break;
case int32:
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
break;
case int64:
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
break;
default:
throw std::runtime_error(
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
break;
}
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(src);
encoder.set_input_array(inds);
encoder.set_output_array(out);
encoder.dispatch([axis_ = axis_,
src = array::unsafe_weak_copy(src),
inds = array::unsafe_weak_copy(inds),
out = array::unsafe_weak_copy(out)]() mutable {
switch (inds.dtype()) {
case uint8:
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
break;
case uint16:
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
break;
case uint32:
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
break;
case uint64:
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
break;
case int8:
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
break;
case int16:
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
break;
case int32:
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
break;
case int64:
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
break;
default:
throw std::runtime_error(
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
break;
}
});
}
template <typename InT, typename IdxT, typename OpT>
@@ -344,8 +404,7 @@ void scatter(
const array& updates,
array& out,
const std::vector<array>& inds,
const std::vector<int>& axes,
const OpT& op) {
const std::vector<int>& axes) {
int nind = inds.size();
auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1;
@@ -361,9 +420,11 @@ void scatter(
ContiguousIterator update_it(updates);
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());

auto out_ptr = out.data<InT>();
auto upd_ptr = updates.data<InT>();
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < nind; ++j) {
for (int j = 0; j < inds.size(); ++j) {
auto ax = axes[j];
auto idx_loc = its[j].loc;
its[j].step();
@@ -373,8 +434,7 @@ void scatter(
}
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) {
op(updates.data<InT>()[update_it.loc],
out.data<InT>() + out_offset + out_it.loc);
OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
update_it.step();
out_it.step();
}
@@ -392,26 +452,19 @@ void dispatch_scatter_inds(
Scatter::ReduceType rtype) {
switch (rtype) {
case Scatter::None:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; });
scatter<InT, IdxT, None>(updates, out, indices, axes);
break;
case Scatter::Sum:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; });
scatter<InT, IdxT, Sum>(updates, out, indices, axes);
break;
case Scatter::Prod:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; });
scatter<InT, IdxT, Prod>(updates, out, indices, axes);
break;
case Scatter::Max:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
(*y) = (*y > x) ? *y : x;
});
scatter<InT, IdxT, Max>(updates, out, indices, axes);
break;
case Scatter::Min:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
(*y) = (*y < x) ? *y : x;
});
scatter<InT, IdxT, Min>(updates, out, indices, axes);
break;
}
}
@@ -463,67 +516,75 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2);

auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
auto& updates = inputs.back();

// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);
copy(src, out, ctype, stream());

switch (src.dtype()) {
case bool_:
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
break;
case uint8:
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint16:
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint32:
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint64:
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
break;
case int8:
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
break;
case int16:
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
break;
case int32:
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
break;
case int64:
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
break;
case float16:
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
break;
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case float64:
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
case complex64:
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
break;
auto& encoder = cpu::get_command_encoder(stream());
std::vector<array> inds;
for (auto it = inputs.begin() + 1; it < inputs.end() - 1; ++it) {
encoder.set_input_array(*it);
inds.push_back(array::unsafe_weak_copy(*it));
}
encoder.set_input_array(updates);
encoder.set_output_array(out);
encoder.dispatch([axes_ = axes_,
reduce_type_ = reduce_type_,
updates = array::unsafe_weak_copy(updates),
inds = std::move(inds),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
break;
case uint8:
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint16:
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint32:
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint64:
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
break;
case int8:
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
break;
case int16:
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
break;
case int32:
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
break;
case int64:
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
break;
case float16:
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
break;
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case float64:
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
case complex64:
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
break;
}
});
}
template <typename T, typename IdxT, typename OpT>
void scatter_axis(
array& out,
const array idx,
const array& upd,
int axis,
const OpT& op) {
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
@@ -557,8 +618,9 @@ void scatter_axis(
for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
OpT{}(
upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
}
idx_it.step();
upd_it.step();
@@ -576,12 +638,10 @@ void dispatch_scatter_axis_op(
ScatterAxis::ReduceType rtype) {
switch (rtype) {
case ScatterAxis::None:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; });
scatter_axis<InT, IdxT, None>(out, idx, updates, axis);
break;
case ScatterAxis::Sum:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; });
scatter_axis<InT, IdxT, Sum>(out, idx, updates, axis);
break;
}
}
@@ -634,53 +694,65 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);
copy(src, out, ctype, stream());

switch (src.dtype()) {
case bool_:
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
break;
case uint8:
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint16:
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint32:
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint64:
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
break;
case int8:
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
break;
case int16:
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
break;
case int32:
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
break;
case int64:
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
break;
case float16:
dispatch_scatter_axis<float16_t>(out, idx, updates, axis_, reduce_type_);
break;
case float32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
break;
case float64:
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
break;
case complex64:
dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_);
break;
}
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(idx);
encoder.set_input_array(updates);
encoder.set_output_array(out);
encoder.dispatch([axis_ = axis_,
reduce_type_ = reduce_type_,
idx = array::unsafe_weak_copy(idx),
updates = array::unsafe_weak_copy(updates),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
break;
case uint8:
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint16:
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint32:
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint64:
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
break;
case int8:
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
break;
case int16:
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
break;
case int32:
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
break;
case int64:
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
break;
case float16:
dispatch_scatter_axis<float16_t>(
out, idx, updates, axis_, reduce_type_);
break;
case float32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
break;
case float64:
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(
out, idx, updates, axis_, reduce_type_);
break;
case complex64:
dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_);
break;
}
});
}

} // namespace mlx::core
@@ -2,20 +2,21 @@

#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"

namespace mlx::core {

template <typename T>
void general_inv(array& inv, int N, int i) {
void general_inv(T* inv, int N) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
// Compute LU factorization.
getrf<T>(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<T>() + N * N * i,
/* a = */ inv,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
@@ -48,12 +49,12 @@ void general_inv(array& inv, int N, int i) {
}

const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};

// Compute inverse.
getri<T>(
/* m = */ &N,
/* a = */ inv.data<T>() + N * N * i,
/* a = */ inv,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
@@ -68,29 +69,28 @@ void general_inv(array& inv, int N, int i) {
}

template <typename T>
void tri_inv(array& inv, int N, int i, bool upper) {
void tri_inv(T* inv, int N, bool upper) {
const char uplo = upper ? 'L' : 'U';
const char diag = 'N';
T* data = inv.data<T>() + N * N * i;
int info;
trtri<T>(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ data,
/* a = */ inv,
/* lda = */ &N,
/* info = */ &info);

// zero out the other triangle
if (upper) {
for (int i = 0; i < N; i++) {
std::fill(data, data + i, 0.0f);
data += N;
std::fill(inv, inv + i, 0.0f);
inv += N;
}
} else {
for (int i = 0; i < N; i++) {
std::fill(data + i + 1, data + N, 0.0f);
data += N;
std::fill(inv + i + 1, inv + N, 0.0f);
inv += N;
}
}

@@ -103,34 +103,53 @@ void tri_inv(array& inv, int N, int i, bool upper) {
}

template <typename T>
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
void inverse_impl(
const array& a,
array& inv,
bool tri,
bool upper,
Stream stream) {
// Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see
// https://math.stackexchange.com/a/340234):
// (A⁻¹)ᵀ = (Aᵀ)⁻¹

// The inverse is computed in place, so just copy the input to the output.
copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(
a,
inv,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream);

const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);

for (int i = 0; i < num_matrices; i++) {
if (tri) {
tri_inv<T>(inv, N, i, upper);
} else {
general_inv<T>(inv, N, i);
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(inv);

auto inv_ptr = inv.data<T>();
if (tri) {
encoder.dispatch([inv_ptr, N, num_matrices, upper]() {
for (int i = 0; i < num_matrices; i++) {
tri_inv<T>(inv_ptr + N * N * i, N, upper);
}
});
} else {
encoder.dispatch([inv_ptr, N, num_matrices]() {
for (int i = 0; i < num_matrices; i++) {
general_inv<T>(inv_ptr + N * N * i, N);
}
});
}
}

void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
switch (inputs[0].dtype()) {
case float32:
inverse_impl<float>(inputs[0], output, tri_, upper_);
inverse_impl<float>(inputs[0], output, tri_, upper_, stream());
break;
case float64:
inverse_impl<double>(inputs[0], output, tri_, upper_);
inverse_impl<double>(inputs[0], output, tri_, upper_, stream());
break;
default:
throw std::runtime_error(
mlx/backend/cpu/logsumexp.cpp (new file, 140 lines)
@@ -0,0 +1,140 @@
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <cmath>

#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"

namespace mlx::core {

namespace {

using namespace mlx::core::simd;

template <typename T, typename AccT>
void logsumexp(const array& in, array& out, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);

const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();

int M = in.shape().back();
int L = in.data_size() / M;

encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
constexpr int N = std::min(max_size<AccT>, max_size<T>);

const T* current_in_ptr;

for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}

AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}

// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
normalizer += _exp;
current_in_ptr++;
}
// Normalize
*out_ptr = std::isinf(maximum)
? static_cast<T>(maximum)
: static_cast<T>(std::log(normalizer) + maximum);
}
});
}

} // namespace

void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);

// Make sure that the last dimension is contiguous
auto s = stream();
auto& encoder = cpu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};

auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}

switch (in.dtype()) {
case float32:
logsumexp<float, float>(in, out, stream());
break;
case float16:
logsumexp<float16_t, float>(in, out, stream());
break;
case bfloat16:
logsumexp<bfloat16_t, float>(in, out, stream());
break;
case float64:
logsumexp<double, double>(in, out, stream());
break;
default:
throw std::runtime_error(
"[logsumexp] only supports floating point types");
break;
}
}

} // namespace mlx::core
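As a plain-scalar reference for what the SIMD kernel in the new file computes per row: logsumexp(x) = max(x) + log(sum(exp(x - max(x)))), with the same special case for rows whose maximum is infinite. The sketch below is illustrative and is not the MLX implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

float logsumexp_row_sketch(const std::vector<float>& x) {
  float m = -std::numeric_limits<float>::infinity();
  for (float v : x) {
    m = std::max(m, v);  // first pass: running maximum
  }
  float normalizer = 0.0f;
  for (float v : x) {
    normalizer += std::exp(v - m);  // second pass: shifted exponentials
  }
  // Rows that are all -inf (or contain +inf) just propagate the maximum.
  return std::isinf(m) ? m : std::log(normalizer) + m;
}
```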
@@ -4,15 +4,22 @@

#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"

namespace mlx::core {

template <typename T>
void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) {
void luf_impl(
const array& a,
array& lu,
array& pivots,
array& row_indices,
Stream stream) {
int M = a.shape(-2);
int N = a.shape(-1);
int K = std::min(M, N);

// Copy a into lu and make it col contiguous
auto ndim = lu.ndim();
@@ -23,60 +30,74 @@ void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) {
auto strides = lu.strides();
strides[ndim - 1] = M;
strides[ndim - 2] = 1;
lu.set_data(
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace(
a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);
a,
lu,
a.shape(),
a.strides(),
strides,
0,
0,
CopyType::GeneralGeneral,
stream);

auto a_ptr = lu.data<T>();

pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
pivots.set_data(allocator::malloc(pivots.nbytes()));
row_indices.set_data(allocator::malloc(row_indices.nbytes()));
auto pivots_ptr = pivots.data<uint32_t>();
auto row_indices_ptr = row_indices.data<uint32_t>();

int info;
size_t num_matrices = a.size() / (M * N);
for (size_t i = 0; i < num_matrices; ++i) {
// Compute LU factorization of A
getrf<T>(
/* m */ &M,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &M,
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
/* info */ &info);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(lu);
encoder.set_output_array(pivots);
encoder.set_output_array(row_indices);

if (info != 0) {
std::stringstream ss;
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
<< ((info > 0) ? " because matrix is singular"
: " because argument had an illegal value");
throw std::runtime_error(ss.str());
}
encoder.dispatch(
[a_ptr, pivots_ptr, row_indices_ptr, num_matrices, M, N, K]() mutable {
int info;
for (size_t i = 0; i < num_matrices; ++i) {
// Compute LU factorization of A
getrf<T>(
/* m */ &M,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &M,
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
/* info */ &info);

// Subtract 1 to get 0-based index
int j = 0;
for (; j < pivots.shape(-1); ++j) {
pivots_ptr[j]--;
row_indices_ptr[j] = j;
}
for (; j < row_indices.shape(-1); ++j) {
row_indices_ptr[j] = j;
}
for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];
auto t2 = row_indices_ptr[j];
row_indices_ptr[j] = t1;
row_indices_ptr[piv] = t2;
}
if (info != 0) {
std::stringstream ss;
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
<< ((info > 0) ? " because matrix is singular"
: " because argument had an illegal value");
throw std::runtime_error(ss.str());
}

// Advance pointers to the next matrix
a_ptr += M * N;
pivots_ptr += pivots.shape(-1);
row_indices_ptr += pivots.shape(-1);
}
// Subtract 1 to get 0-based index
int j = 0;
for (; j < K; ++j) {
pivots_ptr[j]--;
row_indices_ptr[j] = j;
}
for (; j < M; ++j) {
row_indices_ptr[j] = j;
}
for (int j = K - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];
auto t2 = row_indices_ptr[j];
row_indices_ptr[j] = t1;
row_indices_ptr[piv] = t2;
}

// Advance pointers to the next matrix
a_ptr += M * N;
pivots_ptr += K;
row_indices_ptr += M;
}
});
}

void LUF::eval_cpu(
@@ -85,10 +106,10 @@ void LUF::eval_cpu(
assert(inputs.size() == 1);
switch (inputs[0].dtype()) {
case float32:
luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2], stream());
break;
case float64:
luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2], stream());
break;
default:
throw std::runtime_error(
@@ -5,6 +5,7 @@
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -58,42 +59,42 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(
|
||||
"[BlockMaskedMM::eval] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
|
||||
auto check_transpose =
|
||||
[](const array& arr, bool do_copy, bool expand_all = false) {
|
||||
[s = stream()](const array& arr, bool do_copy, bool expand_all = false) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::Vector);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
copy(arr, arr_copy, CopyType::Vector, s);
|
||||
return std::make_tuple(false, stx, arr_copy, true);
|
||||
}
|
||||
return std::make_tuple(false, stx, arr);
|
||||
return std::make_tuple(false, stx, arr, false);
|
||||
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::Vector);
|
||||
return std::make_tuple(true, sty, arr_copy);
|
||||
copy(arr, arr_copy, CopyType::Vector, s);
|
||||
return std::make_tuple(true, sty, arr_copy, true);
|
||||
}
|
||||
return std::make_tuple(true, sty, arr);
|
||||
return std::make_tuple(true, sty, arr, false);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
copy(arr, arr_copy, CopyType::General, s);
|
||||
int64_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
return std::make_tuple(false, stx, arr_copy, true);
|
||||
}
|
||||
};
|
||||
|
||||
bool has_op_mask = inputs.size() > 3;
|
||||
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
|
||||
auto [a_transposed, lda, a] =
|
||||
auto [a_transposed, lda, a, a_copied] =
|
||||
check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
|
||||
auto [b_transposed, ldb, b] =
|
||||
auto [b_transposed, ldb, b, b_copied] =
|
||||
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
|
||||
|
||||
size_t M = a.shape(-2);
|
||||
@@ -104,31 +105,39 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out_ptr = out.data<void>(), nbytes = out.nbytes()]() {
|
||||
std::memset(out_ptr, 0, nbytes);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
auto mask_array = [](const array& mask,
|
||||
auto mask_array = [](const void* mask,
|
||||
float* data,
|
||||
int block_size,
|
||||
int batch_idx,
|
||||
int X,
|
||||
int Y,
|
||||
size_t X_data_str,
|
||||
size_t Y_data_str) {
|
||||
size_t Y_data_str,
|
||||
const Shape& mask_shape,
|
||||
const Strides& mask_strides,
|
||||
bool is_bool) {
|
||||
auto ndim = mask_shape.size();
|
||||
auto mask_offset = elem_to_loc(
|
||||
mask.shape(-1) * mask.shape(-2) * batch_idx,
|
||||
mask.shape(),
|
||||
mask.strides());
|
||||
mask_shape[ndim - 1] * mask_shape[ndim - 2] * batch_idx,
|
||||
mask_shape,
|
||||
mask_strides);
|
||||
|
||||
auto X_mask_str = mask.strides()[mask.ndim() - 2];
|
||||
auto Y_mask_str = mask.strides()[mask.ndim() - 1];
|
||||
auto X_mask_str = mask_strides[ndim - 2];
|
||||
auto Y_mask_str = mask_strides[ndim - 1];
|
||||
|
||||
if (mask.dtype() == bool_) {
|
||||
if (is_bool) {
|
||||
return mask_matrix(
|
||||
data,
|
||||
mask.data<bool>(),
|
||||
static_cast<const bool*>(mask),
|
||||
block_size,
|
||||
X,
|
||||
Y,
|
||||
@@ -140,7 +149,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
return mask_matrix(
|
||||
data,
|
||||
mask.data<float>(),
|
||||
static_cast<const float*>(mask),
|
||||
block_size,
|
||||
X,
|
||||
Y,
|
||||
@@ -152,61 +161,155 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
};
|
||||
|
||||
for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
|
||||
// Adjust pointer
|
||||
float* ai =
|
||||
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
|
||||
float* bi =
|
||||
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides());
|
||||
float* ci = out.data<float>() + M * N * i;
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
const void* a_mask_ptr;
|
||||
const void* b_mask_ptr;
|
||||
const void* out_mask_ptr;
|
||||
Shape a_mask_shape;
|
||||
Shape b_mask_shape;
|
||||
Shape out_mask_shape;
|
||||
Strides a_mask_strides;
|
||||
Strides b_mask_strides;
|
||||
Strides out_mask_strides;
|
||||
bool a_mask_bool;
|
||||
bool b_mask_bool;
|
||||
bool out_mask_bool;
|
||||
if (has_op_mask) {
|
||||
auto& a_mask = inputs[inputs.size() - 2];
|
||||
auto& b_mask = inputs[inputs.size() - 1];
|
||||
a_mask_ptr = a_mask.data<void>();
|
||||
b_mask_ptr = b_mask.data<void>();
|
||||
a_mask_shape = a_mask.shape();
|
||||
b_mask_shape = b_mask.shape();
|
||||
a_mask_strides = a_mask.strides();
|
||||
b_mask_strides = b_mask.strides();
|
||||
a_mask_bool = (a_mask.dtype() == bool_);
|
||||
b_mask_bool = (b_mask.dtype() == bool_);
|
||||
encoder.set_input_array(a_mask);
|
||||
encoder.set_input_array(b_mask);
|
||||
}
|
||||
if (has_out_mask) {
|
||||
auto& out_mask = inputs[2];
|
||||
out_mask_ptr = out_mask.data<void>();
|
||||
out_mask_bool = (out_mask.dtype() == bool_);
|
||||
encoder.set_input_array(out_mask);
|
||||
out_mask_shape = out_mask.shape();
|
||||
out_mask_strides = out_mask.strides();
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
auto a_ptr = a.data<float>();
|
||||
auto b_ptr = b.data<float>();
|
||||
auto out_ptr = out.data<float>();
|
||||
size_t num_matrices = out.size() / (M * size_t(N));
|
||||
auto ldc = out.shape(-1);
|
||||
|
||||
// Zero out blocks in a and b if needed
|
||||
if (has_op_mask) {
|
||||
auto& a_mask = inputs[inputs.size() - 2];
|
||||
mask_array(
|
||||
a_mask,
|
||||
ai,
|
||||
block_size_,
|
||||
i,
|
||||
encoder.dispatch([a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
a_mask_ptr,
|
||||
b_mask_ptr,
|
||||
out_mask_ptr,
|
||||
has_op_mask,
|
||||
has_out_mask,
|
||||
block_size = block_size_,
|
||||
num_matrices,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a_transposed = a_transposed,
|
||||
b_transposed = b_transposed,
|
||||
lda = lda,
|
||||
ldb = ldb,
|
||||
ldc,
|
||||
a_shape = a.shape(),
|
||||
a_strides = a.strides(),
|
||||
b_shape = b.shape(),
|
||||
b_strides = b.strides(),
|
||||
a_mask_shape = std::move(a_mask_shape),
|
||||
b_mask_shape = std::move(b_mask_shape),
|
||||
out_mask_shape = std::move(out_mask_shape),
|
||||
a_mask_strides = std::move(a_mask_strides),
|
||||
b_mask_strides = std::move(b_mask_strides),
|
||||
out_mask_strides = std::move(out_mask_strides),
|
||||
mask_array,
|
||||
a_mask_bool,
|
||||
b_mask_bool,
|
||||
out_mask_bool]() {
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
// Adjust pointer
|
||||
float* ai = a_ptr + elem_to_loc(M * K * i, a_shape, a_strides);
|
||||
float* bi = b_ptr + elem_to_loc(K * N * i, b_shape, b_strides);
|
||||
float* ci = out_ptr + M * N * i;
|
||||
|
||||
// Zero out blocks in a and b if needed
|
||||
if (has_op_mask) {
|
||||
mask_array(
|
||||
a_mask_ptr,
|
||||
ai,
|
||||
block_size,
|
||||
i,
|
||||
M,
|
||||
K,
|
||||
a_transposed ? 1 : lda,
|
||||
a_transposed ? lda : 1,
|
||||
a_mask_shape,
|
||||
a_mask_strides,
|
||||
a_mask_bool);
|
||||
|
||||
mask_array(
|
||||
b_mask_ptr,
|
||||
bi,
|
||||
block_size,
|
||||
i,
|
||||
K,
|
||||
N,
|
||||
b_transposed ? 1 : ldb,
|
||||
b_transposed ? ldb : 1,
|
||||
b_mask_shape,
|
||||
b_mask_strides,
|
||||
b_mask_bool);
|
||||
}
|
||||
|
||||
// Do matmul
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
K,
|
||||
a_transposed ? 1 : lda,
|
||||
a_transposed ? lda : 1);
|
||||
|
||||
auto& b_mask = inputs[inputs.size() - 1];
|
||||
mask_array(
|
||||
b_mask,
|
||||
bi,
|
||||
block_size_,
|
||||
i,
|
||||
K,
|
||||
N,
|
||||
b_transposed ? 1 : ldb,
|
||||
b_transposed ? ldb : 1);
|
||||
}
|
||||
K,
|
||||
1.0, // alpha
|
||||
ai,
|
||||
lda,
|
||||
bi,
|
||||
ldb,
|
||||
0.0, // beta
|
||||
ci,
|
||||
ldc);
|
||||
|
||||
// Do matmul
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
1.0, // alpha
|
||||
ai,
|
||||
lda,
|
||||
bi,
|
||||
ldb,
|
||||
0.0, // beta
|
||||
ci,
|
||||
out.shape(-1) // ldc
|
||||
);
|
||||
|
||||
// Zero out blocks in out
|
||||
if (has_out_mask) {
|
||||
mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
|
||||
// Zero out blocks in out
|
||||
if (has_out_mask) {
|
||||
mask_array(
|
||||
out_mask_ptr,
|
||||
ci,
|
||||
block_size,
|
||||
i,
|
||||
M,
|
||||
N,
|
||||
N,
|
||||
1,
|
||||
out_mask_shape,
|
||||
out_mask_strides,
|
||||
out_mask_bool);
|
||||
}
|
||||
}
|
||||
});
|
||||
if (a_copied) {
|
||||
encoder.add_temporary(a);
|
||||
}
|
||||
if (b_copied) {
|
||||
encoder.add_temporary(b);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,12 +318,13 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[GatherMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));

auto& a_pre = inputs[0];
auto& b_pre = inputs[1];

auto check_transpose = [](const array& arr) {
std::vector<array> temps;
auto check_transpose = [s = stream(), &temps](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
@@ -228,10 +332,10 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, s);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
return std::make_tuple(false, stx, temps.back());
}
};

@@ -246,8 +350,12 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}

auto& encoder = cpu::get_command_encoder(stream());
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<float>(), size = out.size()]() {
std::fill_n(out_ptr, size, 0);
});
return;
}

@@ -272,29 +380,61 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {

const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
auto ldc = out.shape(-1);

for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
encoder.dispatch([a_ptr = a.data<float>(),
b_ptr = b.data<float>(),
out_ptr = out.data<float>(),
M,
N,
K,
lda = lda,
ldb = ldb,
a_transposed = a_transposed,
b_transposed = b_transposed,
ldc,
lhs_indices_ptr,
rhs_indices_ptr,
lhs_indices_shape = lhs_indices.shape(),
lhs_indices_strides = lhs_indices.strides(),
rhs_indices_shape = rhs_indices.shape(),
rhs_indices_strides = rhs_indices.strides(),
batch_size_out,
matrix_stride_out,
batch_shape_A = std::move(batch_shape_A),
batch_shape_B = std::move(batch_shape_B),
batch_strides_A = std::move(batch_strides_A),
batch_strides_B = std::move(batch_strides_B)]() {
for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(
i, lhs_indices_shape, lhs_indices_strides)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(
i, rhs_indices_shape, rhs_indices_strides)];

cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out.data<float>() + matrix_stride_out * i,
out.shape(-1) // ldc
);
}
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a_ptr + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b_ptr + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out_ptr + matrix_stride_out * i,
ldc);
}
});
encoder.add_temporaries(std::move(temps));
}

} // namespace mlx::core

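Both GatherMM above and the matmul path below reuse the same check_transpose idea: inspect the last two strides to decide whether a matrix can be handed to BLAS directly (possibly flagged as transposed) or must be copied into a contiguous temporary first. A standalone sketch of that test, using plain vectors instead of mlx arrays (the names here are illustrative, not the library's API):

```cpp
#include <cstdint>
#include <tuple>
#include <vector>

// Returns {transposed, leading_dimension, usable_without_copy} for the last
// two axes of a tensor described by its shape and strides (in elements).
std::tuple<bool, int64_t, bool> check_transpose(
    const std::vector<int64_t>& shape, const std::vector<int64_t>& strides) {
  size_t n = shape.size();
  int64_t stx = strides[n - 2];
  int64_t sty = strides[n - 1];
  if (stx == shape[n - 1] && sty == 1) {
    return std::make_tuple(false, stx, true); // row-contiguous, ld = stx
  } else if (stx == 1 && sty == shape[n - 2]) {
    return std::make_tuple(true, sty, true); // column-contiguous, ld = sty
  }
  // Any other layout needs a general copy to a contiguous temporary first,
  // which is what the temps vector in the hunks above holds.
  return std::make_tuple(false, shape[n - 1], false);
}
```
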
@@ -3,18 +3,76 @@
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/primitives.h"

namespace mlx::core {

template <typename T>
void matmul_dispatch(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta,
Stream stream) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
T* out_ptr = out.data<T>();
size_t ldc = out.shape(-1);
size_t batch_size = a.size() / (a.shape(-2) * a.shape(-1));
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a_ptr,
b_ptr,
out_ptr,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape = a.shape(),
a_strides = a.strides(),
b_shape = b.shape(),
b_strides = b.strides()]() {
matmul<T>(
a_ptr,
b_ptr,
out_ptr,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape,
a_strides,
b_shape,
b_strides);
});
}

void matmul_general(
const array& a_pre,
const array& b_pre,
array& out,
Stream stream,
float alpha = 1.0f,
float beta = 0.0f) {
auto check_transpose = [](const array& arr) {
std::vector<array> temps;
auto check_transpose = [stream, &temps](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
@@ -22,10 +80,10 @@ void matmul_general(
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, stream);
stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
return std::make_tuple(false, stx, temps.back());
}
};

@@ -39,28 +97,34 @@ void matmul_general(
}

if (out.dtype() == float32) {
matmul<float>(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<float>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == float16) {
matmul<float16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<float16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == bfloat16) {
matmul<bfloat16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<bfloat16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == float64) {
matmul<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}
cpu::get_command_encoder(stream).add_temporaries(std::move(temps));
}

void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
if (inputs[0].shape(-1) == 0) {
std::memset(out.data<void>(), 0, out.nbytes());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<void>(), nbytes = out.nbytes()]() {
std::memset(out_ptr, 0, nbytes);
});
return;
}
return matmul_general(inputs[0], inputs[1], out);
matmul_general(inputs[0], inputs[1], out, stream());
}

void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -74,9 +138,9 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(c, out, ctype);
copy(c, out, ctype, stream());

return matmul_general(inputs[0], inputs[1], out, alpha_, beta_);
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
}

} // namespace mlx::core

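The recurring change in these files is that eval_cpu no longer does its work inline: it registers inputs and outputs with a per-stream command encoder and hands the computation to encoder.dispatch as a closure over raw pointers, shapes, and strides rather than over the arrays themselves, with any copies registered as temporaries so they outlive the closure. A toy sketch of that pattern follows; the ToyEncoder here is an illustrative stand-in and not the real mlx::core::cpu encoder.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <queue>
#include <utility>

// Queues closures and runs them later, mimicking deferred CPU dispatch.
struct ToyEncoder {
  std::queue<std::function<void()>> tasks;
  void dispatch(std::function<void()> f) { tasks.push(std::move(f)); }
  void run_all() {
    while (!tasks.empty()) {
      tasks.front()();
      tasks.pop();
    }
  }
};

// Example: defer an output fill, as the K == 0 and empty-matmul paths above
// do. Only the raw pointer and size are captured, not an array object.
void fill_later(ToyEncoder& enc, float* out_ptr, size_t size, float value) {
  enc.dispatch([out_ptr, size, value]() { std::fill_n(out_ptr, size, value); });
}
```

Capturing pointers plus copied shape and stride vectors is what makes the later elem_to_loc(i, shape, strides) calls safe inside the dispatched closures.
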
@@ -7,11 +7,11 @@
#include <sstream>

#include "mlx/allocator.h"
#include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/arange.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/threefry.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -21,40 +21,59 @@ namespace mlx::core {
void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
out.set_data(allocator::malloc(out.nbytes()));
copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else {
shared_buffer_reshape(in, out_strides, out);
}
}

int64_t compute_dynamic_offset(
static std::pair<array, bool> compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes) {
auto compute_offset = [&strides, &axes](const auto* indices) {
int64_t offset = 0;
for (int i = 0; i < axes.size(); ++i) {
offset += indices[i] * strides[axes[i]];
}
return offset;
};
const std::vector<int>& axes,
Stream stream) {
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc(offset.itemsize()));
}

auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(indices);
encoder.set_output_array(offset);
auto compute_offset =
[strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
int64_t offset_ = 0;
for (int i = 0; i < axes.size(); ++i) {
offset_ += indices[i] * strides[axes[i]];
}
offset[0] = offset_;
};
switch (indices.dtype()) {
case int8:
case uint8:
return compute_offset(indices.data<uint8_t>());
encoder.dispatch(compute_offset, indices.data<uint8_t>());
break;
case int16:
case uint16:
return compute_offset(indices.data<uint16_t>());
encoder.dispatch(compute_offset, indices.data<uint16_t>());
break;
case int32:
case uint32:
return compute_offset(indices.data<uint32_t>());
encoder.dispatch(compute_offset, indices.data<uint32_t>());
break;
case int64:
case uint64:
return compute_offset(indices.data<uint64_t>());
encoder.dispatch(compute_offset, indices.data<uint64_t>());
break;
default:
throw std::runtime_error("Invalid indices type.");
}
return {offset, donate};
}

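The new compute_dynamic_offset returns a one-element int64 array (possibly donated from the indices buffer) whose value is filled in later by the dispatched closure; the arithmetic itself is unchanged. For reference, a minimal standalone version of that arithmetic, assuming plain vectors for strides and a host-side index buffer:

```cpp
#include <cstdint>
#include <vector>

// Offset (in elements) selected by one index per sliced axis, matching the
// loop inside the compute_offset closure above.
int64_t dynamic_offset(const std::vector<int64_t>& indices,
                       const std::vector<int64_t>& strides,
                       const std::vector<int>& axes) {
  int64_t offset = 0;
  for (size_t i = 0; i < axes.size(); ++i) {
    offset += indices[i] * strides[axes[i]];
  }
  return offset;
}
```
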
void AsStrided::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -104,14 +123,59 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
arange(inputs, out, start_, step_);
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
throw std::runtime_error("Bool type unsupported for arange.");
|
||||
break;
|
||||
case uint8:
|
||||
arange<uint8_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case uint16:
|
||||
arange<uint16_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case uint32:
|
||||
arange<uint32_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case uint64:
|
||||
arange<uint64_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case int8:
|
||||
arange<int8_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case int16:
|
||||
arange<int16_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case int32:
|
||||
arange<int32_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case int64:
|
||||
arange<int64_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case float16:
|
||||
arange<float16_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case float32:
|
||||
arange<float>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case float64:
|
||||
arange<double>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case bfloat16:
|
||||
arange<bfloat16_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case complex64:
|
||||
arange<complex64_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype);
|
||||
copy(in, out, ctype, stream());
|
||||
}
|
||||
|
||||
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -122,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto strides = out.strides();
|
||||
auto flags = out.flags();
|
||||
@@ -134,18 +198,20 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
size_t data_offset = strides[axis_] * sizes[i];
|
||||
out_slice.copy_shared_buffer(
|
||||
out, strides, flags, out_slice.size(), data_offset);
|
||||
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral);
|
||||
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy(in, out, CopyType::General);
|
||||
copy(in, out, CopyType::General, stream());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,14 +235,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy(in, out, ctype);
|
||||
}
|
||||
|
||||
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
load(out, offset_, reader_, swap_endianness_);
|
||||
copy(in, out, ctype, stream());
|
||||
}
|
||||
|
||||
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -192,7 +251,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||
|
||||
// Fill output with val
|
||||
copy(val, out, CopyType::Scalar);
|
||||
copy(val, out, CopyType::Scalar, stream());
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
@@ -207,7 +266,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
|
||||
copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
|
||||
}
|
||||
|
||||
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -219,43 +278,53 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
size_t elems_per_key = out.size() / num_keys;
|
||||
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto kptr = inputs[0].data<uint32_t>();
|
||||
auto cptr = out.data<char>();
|
||||
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
|
||||
auto half_size = out_skip / 2;
|
||||
bool even = out_skip % 2 == 0;
|
||||
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
|
||||
auto ptr = reinterpret_cast<uint32_t*>(cptr);
|
||||
// Get ith key
|
||||
auto kidx = 2 * i;
|
||||
auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides());
|
||||
auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides());
|
||||
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(inputs[0]);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([kptr,
|
||||
cptr,
|
||||
bytes_per_key,
|
||||
num_keys,
|
||||
kshape = keys.shape(),
|
||||
kstrides = keys.strides()]() mutable {
|
||||
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
|
||||
auto half_size = out_skip / 2;
|
||||
bool even = out_skip % 2 == 0;
|
||||
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
|
||||
auto ptr = reinterpret_cast<uint32_t*>(cptr);
|
||||
// Get ith key
|
||||
auto kidx = 2 * i;
|
||||
auto k1_elem = elem_to_loc(kidx, kshape, kstrides);
|
||||
auto k2_elem = elem_to_loc(kidx + 1, kshape, kstrides);
|
||||
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
|
||||
|
||||
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
|
||||
for (; count.first + 1 < half_size; count.first++, count.second++) {
|
||||
std::tie(ptr[count.first], ptr[count.second]) =
|
||||
random::threefry2x32_hash(key, count);
|
||||
}
|
||||
if (count.first < half_size) {
|
||||
auto rb = random::threefry2x32_hash(key, count);
|
||||
ptr[count.first++] = rb.first;
|
||||
if (bytes_per_key % 4 > 0) {
|
||||
std::copy(
|
||||
reinterpret_cast<char*>(&rb.second),
|
||||
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
|
||||
cptr + 4 * count.second);
|
||||
} else {
|
||||
ptr[count.second] = rb.second;
|
||||
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
|
||||
for (; count.first + 1 < half_size; count.first++, count.second++) {
|
||||
std::tie(ptr[count.first], ptr[count.second]) =
|
||||
random::threefry2x32_hash(key, count);
|
||||
}
|
||||
if (count.first < half_size) {
|
||||
auto rb = random::threefry2x32_hash(key, count);
|
||||
ptr[count.first++] = rb.first;
|
||||
if (bytes_per_key % 4 > 0) {
|
||||
std::copy(
|
||||
reinterpret_cast<char*>(&rb.second),
|
||||
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
|
||||
cptr + 4 * count.second);
|
||||
} else {
|
||||
ptr[count.second] = rb.second;
|
||||
}
|
||||
}
|
||||
if (!even) {
|
||||
count.second = 0;
|
||||
ptr[half_size] = random::threefry2x32_hash(key, count).first;
|
||||
}
|
||||
}
|
||||
if (!even) {
|
||||
count.second = 0;
|
||||
ptr[half_size] = random::threefry2x32_hash(key, count).first;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -268,17 +337,24 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_);
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto [in_offset, donated] =
|
||||
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
|
||||
copy_inplace(
|
||||
/* const array& src = */ in,
|
||||
/* array& dst = */ out,
|
||||
/* const Shape& data_shape = */ out.shape(),
|
||||
/* const Strides& i_strides = */ in.strides(),
|
||||
/* const Strides& o_strides = */ out.strides(),
|
||||
/* int64_t i_offset = */ i_offset,
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
stream(),
|
||||
/* const std::optional<array>& dynamic_i_offset = */ in_offset,
|
||||
/* const std::optional<array>& dynamic_o_offset = */ std::nullopt);
|
||||
if (!donated) {
|
||||
cpu::get_command_encoder(stream()).add_temporary(std::move(in_offset));
|
||||
}
|
||||
}
|
||||
|
||||
void DynamicSliceUpdate::eval_cpu(
|
||||
@@ -296,9 +372,10 @@ void DynamicSliceUpdate::eval_cpu(
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
|
||||
auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_);
|
||||
auto [out_offset, donated] =
|
||||
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
|
||||
copy_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
@@ -306,8 +383,14 @@ void DynamicSliceUpdate::eval_cpu(
|
||||
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
|
||||
/* const std::vector<stride_t>& o_strides = */ out.strides(),
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ o_offset,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
stream(),
|
||||
/* const std::optional<array>& dynamic_i_offset = */ std::nullopt,
|
||||
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
|
||||
if (!donated) {
|
||||
cpu::get_command_encoder(stream()).add_temporary(std::move(out_offset));
|
||||
}
|
||||
}
|
||||
|
||||
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -329,7 +412,7 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, out_strides] =
|
||||
@@ -344,7 +427,8 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
/* const std::vector<stride_t>& o_strides = */ out_strides,
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ data_offset,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
stream());
|
||||
}
|
||||
|
||||
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -368,13 +452,13 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
auto tmp = array(
|
||||
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
|
||||
tmp.set_data(allocator::malloc(tmp.nbytes()));
|
||||
if (in.dtype() == bool_) {
|
||||
auto in_tmp = array(in.shape(), uint8, nullptr, {});
|
||||
in_tmp.copy_shared_buffer(in);
|
||||
copy_inplace(in_tmp, tmp, CopyType::General);
|
||||
copy_inplace(in_tmp, tmp, CopyType::General, stream());
|
||||
} else {
|
||||
copy_inplace(in, tmp, CopyType::General);
|
||||
copy_inplace(in, tmp, CopyType::General, stream());
|
||||
}
|
||||
|
||||
auto flags = out.flags();
|
||||
@@ -382,7 +466,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
flags.row_contiguous = true;
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,20 +2,18 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
void qrf_impl(const array& a, array& q, array& r) {
|
||||
void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
const int M = a.shape(-2);
|
||||
const int N = a.shape(-1);
|
||||
const int lda = M;
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
int num_reflectors = std::min(M, N);
|
||||
auto tau =
|
||||
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
|
||||
|
||||
// Copy A to inplace input and make it col-contiguous
|
||||
array in(a.shape(), a.dtype(), nullptr, {});
|
||||
@@ -27,95 +25,107 @@ void qrf_impl(const array& a, array& q, array& r) {
|
||||
auto strides = in.strides();
|
||||
strides[in.ndim() - 2] = 1;
|
||||
strides[in.ndim() - 1] = M;
|
||||
in.set_data(
|
||||
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
|
||||
copy_inplace(a, in, CopyType::GeneralGeneral);
|
||||
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
|
||||
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
q.set_data(allocator::malloc(q.nbytes()));
|
||||
r.set_data(allocator::malloc(r.nbytes()));
|
||||
|
||||
T optimal_work;
|
||||
int lwork = -1;
|
||||
int info;
|
||||
auto in_ptr = in.data<T>();
|
||||
auto r_ptr = r.data<T>();
|
||||
auto q_ptr = q.data<T>();
|
||||
|
||||
// Compute workspace size
|
||||
geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(q);
|
||||
encoder.set_output_array(r);
|
||||
encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
|
||||
int num_reflectors = std::min(M, N);
|
||||
auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);
|
||||
|
||||
// Update workspace size
|
||||
lwork = optimal_work;
|
||||
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
|
||||
T optimal_work;
|
||||
int lwork = -1;
|
||||
int info;
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
// Solve
|
||||
geqrf<T>(
|
||||
&M,
|
||||
&N,
|
||||
in.data<T>() + M * N * i,
|
||||
&lda,
|
||||
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
|
||||
static_cast<T*>(work.raw_ptr()),
|
||||
&lwork,
|
||||
&info);
|
||||
}
|
||||
allocator::free(work);
|
||||
// Compute workspace size
|
||||
geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
|
||||
|
||||
r.set_data(allocator::malloc_or_wait(r.nbytes()));
|
||||
// Update workspace size
|
||||
lwork = optimal_work;
|
||||
auto work = allocator::malloc(sizeof(T) * lwork);
|
||||
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
/// num_reflectors x N
|
||||
for (int j = 0; j < r.shape(-2); ++j) {
|
||||
for (int k = 0; k < j; ++k) {
|
||||
r.data<T>()[i * N * num_reflectors + j * N + k] = 0;
|
||||
}
|
||||
for (int k = j; k < r.shape(-1); ++k) {
|
||||
r.data<T>()[i * N * num_reflectors + j * N + k] =
|
||||
in.data<T>()[i * N * M + j + k * M];
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
// Solve
|
||||
geqrf<T>(
|
||||
&M,
|
||||
&N,
|
||||
in_ptr + M * N * i,
|
||||
&lda,
|
||||
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
|
||||
static_cast<T*>(work.raw_ptr()),
|
||||
&lwork,
|
||||
&info);
|
||||
}
|
||||
allocator::free(work);
|
||||
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
/// num_reflectors x N
|
||||
for (int j = 0; j < num_reflectors; ++j) {
|
||||
for (int k = 0; k < j; ++k) {
|
||||
r_ptr[i * N * num_reflectors + j * N + k] = 0;
|
||||
}
|
||||
for (int k = j; k < N; ++k) {
|
||||
r_ptr[i * N * num_reflectors + j * N + k] =
|
||||
in_ptr[i * N * M + j + k * M];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get work size
|
||||
lwork = -1;
|
||||
orgqr<T>(
|
||||
&M,
|
||||
&num_reflectors,
|
||||
&num_reflectors,
|
||||
nullptr,
|
||||
&lda,
|
||||
nullptr,
|
||||
&optimal_work,
|
||||
&lwork,
|
||||
&info);
|
||||
lwork = optimal_work;
|
||||
work = allocator::malloc_or_wait(sizeof(T) * lwork);
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
// Compute Q
|
||||
// Get work size
|
||||
lwork = -1;
|
||||
orgqr<T>(
|
||||
&M,
|
||||
&num_reflectors,
|
||||
&num_reflectors,
|
||||
in.data<T>() + M * N * i,
|
||||
nullptr,
|
||||
&lda,
|
||||
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
|
||||
static_cast<T*>(work.raw_ptr()),
|
||||
nullptr,
|
||||
&optimal_work,
|
||||
&lwork,
|
||||
&info);
|
||||
}
|
||||
lwork = optimal_work;
|
||||
work = allocator::malloc(sizeof(T) * lwork);
|
||||
|
||||
q.set_data(allocator::malloc_or_wait(q.nbytes()));
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
// M x num_reflectors
|
||||
for (int j = 0; j < q.shape(-2); ++j) {
|
||||
for (int k = 0; k < q.shape(-1); ++k) {
|
||||
q.data<T>()[i * M * num_reflectors + j * num_reflectors + k] =
|
||||
in.data<T>()[i * N * M + j + k * M];
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
// Compute Q
|
||||
orgqr<T>(
|
||||
&M,
|
||||
&num_reflectors,
|
||||
&num_reflectors,
|
||||
in_ptr + M * N * i,
|
||||
&lda,
|
||||
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
|
||||
static_cast<T*>(work.raw_ptr()),
|
||||
&lwork,
|
||||
&info);
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
// M x num_reflectors
|
||||
for (int j = 0; j < M; ++j) {
|
||||
for (int k = 0; k < num_reflectors; ++k) {
|
||||
q_ptr[i * M * num_reflectors + j * num_reflectors + k] =
|
||||
in_ptr[i * N * M + j + k * M];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
allocator::free(work);
|
||||
allocator::free(tau);
|
||||
// Cleanup
|
||||
allocator::free(work);
|
||||
allocator::free(tau);
|
||||
});
|
||||
encoder.add_temporary(in);
|
||||
}
|
||||
|
||||
void QRF::eval_cpu(
|
||||
@@ -123,10 +133,10 @@ void QRF::eval_cpu(
|
||||
std::vector<array>& outputs) {
|
||||
switch (inputs[0].dtype()) {
|
||||
case float32:
|
||||
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
|
||||
qrf_impl<float>(inputs[0], outputs[0], outputs[1], stream());
|
||||
break;
|
||||
case float64:
|
||||
qrf_impl<double>(inputs[0], outputs[0], outputs[1]);
|
||||
qrf_impl<double>(inputs[0], outputs[0], outputs[1], stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -316,7 +317,8 @@ void _qmm_dispatch_typed(
|
||||
}
|
||||
}
|
||||
|
||||
void _qmm_dispatch(
|
||||
template <typename T>
|
||||
void _qmm_dispatch_typed(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
@@ -328,63 +330,61 @@ void _qmm_dispatch(
|
||||
int K = x.shape(-1);
|
||||
int M = x.ndim() > 1 ? x.shape(-2) : 1;
|
||||
int N = out.shape(-1);
|
||||
|
||||
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
|
||||
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
|
||||
|
||||
int batch_size = x.size() / (K * M);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
auto scales_ptr = scales.data<T>();
|
||||
auto biases_ptr = biases.data<T>();
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_qmm_dispatch_typed<float>(
|
||||
out.data<float>() + i * M * N,
|
||||
x.data<float>() + elem_to_loc(i * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
|
||||
scales.data<float>() + elem_to_loc(i * g_els, scales),
|
||||
biases.data<float>() + elem_to_loc(i * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_dispatch_typed<float16_t>(
|
||||
out.data<float16_t>() + i * M * N,
|
||||
x.data<float16_t>() + elem_to_loc(i * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
|
||||
scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
|
||||
biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_dispatch_typed<bfloat16_t>(
|
||||
out.data<bfloat16_t>() + i * M * N,
|
||||
x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
|
||||
scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
|
||||
biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
}
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
|
||||
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
|
||||
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
|
||||
biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
}
|
||||
|
||||
void _bs_qmm_dispatch(
|
||||
void _qmm_dispatch(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_qmm_dispatch_typed<float>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_dispatch_typed<float16_t>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_dispatch_typed<bfloat16_t>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void _bs_qmm_dispatch_typed(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
@@ -402,60 +402,90 @@ void _bs_qmm_dispatch(
|
||||
int w_els = w.shape(-1) * w.shape(-2);
|
||||
int g_els = scales.shape(-1) * scales.shape(-2);
|
||||
|
||||
const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
|
||||
const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
auto scales_ptr = scales.data<T>();
|
||||
auto biases_ptr = biases.data<T>();
|
||||
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
|
||||
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
|
||||
|
||||
for (int i = 0; i < lhs_indices.size(); i++) {
|
||||
int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
|
||||
int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
|
||||
int x_idx = lhs_indices_ptr[elem_to_loc(
|
||||
i, lhs_indices.shape(), lhs_indices.strides())];
|
||||
int w_idx = rhs_indices_ptr[elem_to_loc(
|
||||
i, rhs_indices.shape(), rhs_indices.strides())];
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
|
||||
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
|
||||
scales_ptr +
|
||||
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
|
||||
biases_ptr +
|
||||
elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
}
|
||||
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_qmm_dispatch_typed<float>(
|
||||
out.data<float>() + i * M * N,
|
||||
x.data<float>() + elem_to_loc(x_idx * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
|
||||
scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
|
||||
biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_dispatch_typed<float16_t>(
|
||||
out.data<float16_t>() + i * M * N,
|
||||
x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
|
||||
scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
|
||||
biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_dispatch_typed<bfloat16_t>(
|
||||
out.data<bfloat16_t>() + i * M * N,
|
||||
x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
|
||||
scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
|
||||
biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
}
|
||||
void _bs_qmm_dispatch(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
const array& lhs_indices,
|
||||
const array& rhs_indices,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_bs_qmm_dispatch_typed<float>(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
_bs_qmm_dispatch_typed<float16_t>(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
_bs_qmm_dispatch_typed<bfloat16_t>(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -469,13 +499,14 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& scales_pre = inputs[2];
|
||||
auto& biases_pre = inputs[3];
|
||||
|
||||
auto ensure_row_contiguous = [](const array& arr) {
|
||||
std::vector<array> temps;
|
||||
auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
return arr_copy;
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy(arr, temps.back(), CopyType::General, s);
|
||||
return temps.back();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -484,8 +515,25 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto scales = ensure_row_contiguous(scales_pre);
|
||||
auto biases = ensure_row_contiguous(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
});
|
||||
}
|
||||
|
||||
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -498,15 +546,17 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& lhs_indices = inputs[4];
|
||||
auto& rhs_indices = inputs[5];
|
||||
|
||||
auto ensure_row_contiguous_last_dims = [](const array& arr) {
|
||||
std::vector<array> temps;
|
||||
auto ensure_row_contiguous_last_dims = [s = stream(),
|
||||
&temps](const array& arr) {
|
||||
auto stride_0 = arr.strides()[arr.ndim() - 2];
|
||||
auto stride_1 = arr.strides()[arr.ndim() - 1];
|
||||
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
return arr_copy;
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy(arr, temps.back(), CopyType::General, s);
|
||||
return temps.back();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -515,42 +565,59 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto scales = ensure_row_contiguous_last_dims(scales_pre);
|
||||
auto biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
_bs_qmm_dispatch(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
group_size_,
|
||||
bits_,
|
||||
transpose_);
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_input_array(lhs_indices);
|
||||
encoder.set_input_array(rhs_indices);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
lhs_indices = array::unsafe_weak_copy(lhs_indices),
|
||||
rhs_indices = array::unsafe_weak_copy(rhs_indices),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_bs_qmm_dispatch(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
group_size_,
|
||||
bits_,
|
||||
transpose_);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
void quantize(
|
||||
const array& w_,
|
||||
array& out_,
|
||||
array& scales_,
|
||||
array& biases_,
|
||||
const T* w,
|
||||
U* out,
|
||||
T* scales,
|
||||
T* biases,
|
||||
int bits,
|
||||
int group_size) {
|
||||
const T* w = w_.data<T>();
|
||||
|
||||
auto out = out_.data<U>();
|
||||
T* scales = scales_.data<T>();
|
||||
T* biases = biases_.data<T>();
|
||||
|
||||
int group_size,
|
||||
size_t w_size) {
|
||||
float n_bins = (1 << bits) - 1;
|
||||
float eps = 1e-7;
|
||||
|
||||
bool power_of_2_bits = is_power_of_2(bits);
|
||||
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
|
||||
int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||
int int_per_group = group_size * bytes_per_pack / el_per_int;
|
||||
size_t n_groups = w_.size() / group_size;
|
||||
size_t n_groups = w_size / group_size;
|
||||
|
||||
for (size_t i = 0; i < n_groups; ++i) {
|
||||
size_t w_idx = i * group_size;
|
||||
@@ -593,50 +660,86 @@ void quantize(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
void dispatch_quantize(
|
||||
const array& w,
|
||||
array& out,
|
||||
array& scales,
|
||||
array& biases,
|
||||
int bits,
|
||||
int group_size) {
|
||||
auto w_ptr = w.data<T>();
|
||||
auto out_ptr = out.data<U>();
|
||||
auto scales_ptr = scales.data<T>();
|
||||
auto biases_ptr = biases.data<T>();
|
||||
quantize<T, U>(
|
||||
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto ensure_row_contiguous = [](const array& arr) {
|
||||
auto ensure_row_contiguous = [s = stream()](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
return std::make_pair(arr, false);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
return arr_copy;
|
||||
copy(arr, arr_copy, CopyType::General, s);
|
||||
return std::make_pair(arr_copy, true);
|
||||
}
|
||||
};
|
||||
auto w = ensure_row_contiguous(inputs[0]);
|
||||
|
||||
auto [w, copied] = ensure_row_contiguous(inputs[0]);
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
|
||||
if (w.dtype() == float16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == bfloat16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
quantize<bfloat16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == float32) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
if (copied) {
|
||||
encoder.add_temporary(w);
|
||||
}
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([w = array::unsafe_weak_copy(w),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_]() mutable {
|
||||
if (w.dtype() == float16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<float16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
dispatch_quantize<float16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == bfloat16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<bfloat16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
dispatch_quantize<bfloat16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == float32) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<float, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
dispatch_quantize<float, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <limits>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -139,25 +140,22 @@ void reduction_op(
|
||||
const array& x,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
U init,
|
||||
Op op) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
U init) {
|
||||
ReductionPlan plan = get_reduction_plan(x, axes);
|
||||
|
||||
auto in_ptr = x.data<T>();
|
||||
auto out_ptr = out.data<U>();
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
U* out_ptr = out.data<U>();
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
|
||||
contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
|
||||
int reduction_size = plan.shape[0];
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
|
||||
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -166,8 +164,6 @@ void reduction_op(
|
||||
int reduction_size = plan.shape.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
// Unrolling the following loop (and implementing it in order for
|
||||
// ContiguousReduce) should hold extra performance boost.
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
@@ -175,7 +171,7 @@ void reduction_op(
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);
|
||||
contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
@@ -184,10 +180,10 @@ void reduction_op(
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
contiguous_reduce(
|
||||
x_ptr + offset + extra_offset,
|
||||
in_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
op,
|
||||
Op{},
|
||||
init);
|
||||
},
|
||||
plan.shape,
|
||||
@@ -202,12 +198,10 @@ void reduction_op(
|
||||
size_t reduction_stride = plan.strides.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
|
||||
x_ptr += reduction_stride * reduction_size;
|
||||
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
|
||||
in_ptr += reduction_stride * reduction_size;
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
return;
|
||||
@@ -219,15 +213,14 @@ void reduction_op(
|
||||
size_t reduction_stride = plan.strides.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(
|
||||
x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
|
||||
in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
} else {
|
||||
@@ -237,11 +230,11 @@ void reduction_op(
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
strided_reduce(
|
||||
x_ptr + offset + extra_offset,
|
||||
in_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
op);
|
||||
Op{});
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
@@ -252,15 +245,14 @@ void reduction_op(
|
||||
}
|
||||
|
||||
if (plan.type == GeneralReduce) {
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
U val = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
val = op(val, *(x_ptr + offset + extra_offset));
|
||||
val = Op{}(val, *(in_ptr + offset + extra_offset));
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
@@ -396,9 +388,9 @@ void reduce_dispatch_and_or(
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::And) {
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
reduction_op<InT, bool, AndReduce>(in, out, axes, true);
} else {
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
reduction_op<InT, bool, OrReduce>(in, out, axes, false);
}
}

@@ -410,15 +402,15 @@ void reduce_dispatch_sum_prod(
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0);
} else {
reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
reduction_op<InT, InT, SumReduce>(in, out, axes, 0);
}
} else {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1);
} else {
reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1);
}
}
}
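Passing the reduction functor as a template parameter (reduction_op<InT, bool, AndReduce>) rather than as a value means the operator can be default-constructed as Op{} wherever it is applied, so nothing stateful has to be captured or forwarded into the dispatched work. A small sketch of the same pattern, with hypothetical names:

```cpp
#include <cstddef>

struct SumOp {
  float operator()(float acc, float x) const { return acc + x; }
};

// The op type travels in the signature; Op{} is materialized at the use site.
template <typename T, typename U, typename Op>
U reduce_all(const T* x, size_t n, U init) {
  U acc = init;
  for (size_t i = 0; i < n; ++i) {
    acc = Op{}(acc, static_cast<U>(x[i]));
  }
  return acc;
}

// Usage: float total = reduce_all<float, float, SumOp>(ptr, n, 0.0f);
```
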
@@ -431,132 +423,141 @@ void reduce_dispatch_min_max(
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Max) {
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
|
||||
reduction_op<InT, InT, MaxReduce>(in, out, axes, init);
|
||||
} else {
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
|
||||
reduction_op<InT, InT, MinReduce>(in, out, axes, init);
|
||||
}
|
||||
}
|
||||
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
case Reduce::Or: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
case float16:
|
||||
case bfloat16:
|
||||
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
case int32:
|
||||
case float32:
|
||||
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
case int64:
|
||||
case float64:
|
||||
case complex64:
|
||||
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
reduce_type_ = reduce_type_,
axes_ = axes_]() mutable {
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
case Reduce::Or: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
case float16:
|
||||
case bfloat16:
|
||||
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
case int32:
|
||||
case float32:
|
||||
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
case int64:
|
||||
case float64:
|
||||
case complex64:
|
||||
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Sum:
|
||||
case Reduce::Prod: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case Reduce::Sum:
|
||||
case Reduce::Prod: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Max:
|
||||
case Reduce::Min: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case Reduce::Max:
|
||||
case Reduce::Min: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
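`Reduce::eval_cpu` now allocates the output, registers the input and output with the stream's command encoder, and packages the whole dtype switch into a lambda passed to `encoder.dispatch`. Everything the lambda needs is captured by value (`array::unsafe_weak_copy` of the arrays plus copies of `reduce_type_` and `axes_`), presumably so the closure stays valid when the encoder runs it later on the stream. A rough sketch of that shape with a hypothetical encoder type (the real `cpu::CommandEncoder` is more involved):

#include <functional>
#include <queue>
#include <utility>

struct EncoderSketch {
  std::queue<std::function<void()>> tasks;

  template <typename F>
  void dispatch(F&& f) {
    // The work is queued, not run eagerly; it executes when the stream's
    // worker drains the queue, so the lambda must own its state by value.
    tasks.emplace(std::forward<F>(f));
  }

  void run_all() {
    while (!tasks.empty()) {
      tasks.front()();
      tasks.pop();
    }
  }
};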
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/binary_ops.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -153,33 +155,31 @@ void strided_scan(
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void scan_op(
|
||||
const array& input,
|
||||
array& output,
|
||||
const array& in,
|
||||
array& out,
|
||||
int axis,
|
||||
bool reverse,
|
||||
bool inclusive,
|
||||
const Op& op,
|
||||
U init) {
|
||||
output.set_data(allocator::malloc_or_wait(output.nbytes()));
|
||||
|
||||
if (input.flags().row_contiguous) {
|
||||
if (input.strides()[axis] == 1) {
|
||||
if (in.flags().row_contiguous) {
|
||||
if (in.strides()[axis] == 1) {
|
||||
contiguous_scan(
|
||||
input.data<T>(),
|
||||
output.data<U>(),
|
||||
input.size() / input.shape(axis),
|
||||
input.shape(axis),
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
in.size() / in.shape(axis),
|
||||
in.shape(axis),
|
||||
reverse,
|
||||
inclusive,
|
||||
op,
|
||||
init);
|
||||
} else {
|
||||
strided_scan(
|
||||
input.data<T>(),
|
||||
output.data<U>(),
|
||||
input.size() / input.shape(axis) / input.strides()[axis],
|
||||
input.shape(axis),
|
||||
input.strides()[axis],
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
in.size() / in.shape(axis) / in.strides()[axis],
|
||||
in.shape(axis),
|
||||
in.strides()[axis],
|
||||
reverse,
|
||||
inclusive,
|
||||
op,
|
||||
@@ -193,8 +193,8 @@ void scan_op(
|
||||
template <typename T, typename U>
|
||||
void scan_dispatch(
|
||||
Scan::ReduceType rtype,
|
||||
const array& input,
|
||||
array& output,
|
||||
const array& in,
|
||||
array& out,
|
||||
int axis,
|
||||
bool reverse,
|
||||
bool inclusive) {
|
||||
@@ -202,29 +202,39 @@ void scan_dispatch(
|
||||
case Scan::Sum: {
|
||||
auto op = [](U y, T x) { return y + x; };
|
||||
auto init = static_cast<U>(0);
|
||||
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
case Scan::Prod: {
|
||||
auto op = [](U y, T x) { return y * x; };
|
||||
auto init = static_cast<U>(1);
|
||||
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
case Scan::Min: {
|
||||
auto op = [](U y, T x) { return x < y ? x : y; };
|
||||
auto init = (issubdtype(input.dtype(), floating))
|
||||
auto init = (issubdtype(in.dtype(), floating))
|
||||
? static_cast<U>(std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::max();
|
||||
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
case Scan::Max: {
|
||||
auto op = [](U y, T x) { return x < y ? y : x; };
|
||||
auto init = (issubdtype(input.dtype(), floating))
|
||||
auto init = (issubdtype(in.dtype(), floating))
|
||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::min();
|
||||
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
case Scan::LogAddExp: {
auto op = [](U a, T b) {
return detail::LogAddExp{}(a, static_cast<U>(b));
};
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
}
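The new `Scan::LogAddExp` case accumulates in log space with `-inf` as the identity for floating types. A reference sketch of the accumulation it relies on (illustrative; `detail::LogAddExp` in MLX may handle edge cases differently):

#include <algorithm>
#include <cmath>

float logaddexp_sketch(float a, float b) {
  // log(exp(a) + exp(b)) without overflowing exp().
  float hi = std::max(a, b);
  float lo = std::min(a, b);
  if (std::isinf(hi) && hi < 0) {
    return hi; // both inputs are -inf, the identity element
  }
  return hi + std::log1p(std::exp(lo - hi));
}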
@@ -235,82 +245,95 @@ void scan_dispatch(
|
||||
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
|
||||
// Ensure contiguity
|
||||
auto in = inputs[0];
|
||||
if (!in.flags().row_contiguous) {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy(in, arr_copy, CopyType::General);
|
||||
copy(in, arr_copy, CopyType::General, stream());
|
||||
in = arr_copy;
|
||||
encoder.add_temporary(arr_copy);
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_: {
|
||||
// We could do a full dtype x dtype switch but this is the only case
|
||||
// where we accumulate in a different type, for now.
|
||||
//
|
||||
// TODO: If we add the option to accumulate floats in higher precision
|
||||
// floats perhaps we should add the full all-to-all dispatch.
|
||||
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
|
||||
scan_dispatch<bool, int32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
} else {
|
||||
scan_dispatch<bool, bool>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_,
|
||||
reduce_type_ = reduce_type_,
|
||||
reverse_ = reverse_,
|
||||
inclusive_ = inclusive_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_: {
|
||||
// We could do a full dtype x dtype switch but this is the only case
|
||||
// where we accumulate in a different type, for now.
|
||||
//
|
||||
// TODO: If we add the option to accumulate floats in higher precision
|
||||
// floats perhaps we should add the full all-to-all dispatch.
|
||||
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
|
||||
scan_dispatch<bool, int32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
} else {
|
||||
scan_dispatch<bool, bool>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
}
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case uint8:
|
||||
scan_dispatch<uint8_t, uint8_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint16:
|
||||
scan_dispatch<uint16_t, uint16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint32:
|
||||
scan_dispatch<uint32_t, uint32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint64:
|
||||
scan_dispatch<uint64_t, uint64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int8:
|
||||
scan_dispatch<int8_t, int8_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int16:
|
||||
scan_dispatch<int16_t, int16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int32:
|
||||
scan_dispatch<int32_t, int32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int64:
|
||||
scan_dispatch<int64_t, int64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float16:
|
||||
scan_dispatch<float16_t, float16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float32:
|
||||
scan_dispatch<float, float>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float64:
|
||||
scan_dispatch<double, double>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case bfloat16:
|
||||
scan_dispatch<bfloat16_t, bfloat16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case complex64:
|
||||
throw std::runtime_error("Scan ops do not support complex types yet");
|
||||
break;
|
||||
}
|
||||
case uint8:
|
||||
scan_dispatch<uint8_t, uint8_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint16:
|
||||
scan_dispatch<uint16_t, uint16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint32:
|
||||
scan_dispatch<uint32_t, uint32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint64:
|
||||
scan_dispatch<uint64_t, uint64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int8:
|
||||
scan_dispatch<int8_t, int8_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int16:
|
||||
scan_dispatch<int16_t, int16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int32:
|
||||
scan_dispatch<int32_t, int32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int64:
|
||||
scan_dispatch<int64_t, int64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float16:
|
||||
scan_dispatch<float16_t, float16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float32:
|
||||
scan_dispatch<float, float>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float64:
|
||||
scan_dispatch<double, double>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case bfloat16:
|
||||
scan_dispatch<bfloat16_t, bfloat16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case complex64:
|
||||
throw std::runtime_error("Scan ops do not support complex types yet");
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
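`Scan::eval_cpu` copies non-row-contiguous inputs first, so the kernels only ever see a scan axis with unit stride. For reference, an inclusive scan over the innermost axis of such a buffer looks like this (simplified; the real `contiguous_scan` also handles the reverse/exclusive variants and uses SIMD):

template <typename T, typename U, typename Op>
void inclusive_scan_last_axis(
    const T* in, U* out, int n_rows, int axis_size, Op op, U init) {
  for (int r = 0; r < n_rows; ++r) {
    U acc = init;
    for (int i = 0; i < axis_size; ++i) {
      acc = op(acc, in[r * axis_size + i]);
      out[r * axis_size + i] = acc; // inclusive: write after accumulating
    }
  }
}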
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -16,51 +16,70 @@ void select_op(
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
Op op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
|
||||
break;
|
||||
case int8:
|
||||
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
|
||||
break;
|
||||
case int16:
|
||||
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
|
||||
break;
|
||||
case int32:
|
||||
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
|
||||
break;
|
||||
case int64:
|
||||
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
|
||||
break;
|
||||
case float16:
|
||||
ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
|
||||
break;
|
||||
case float32:
|
||||
ternary_op<bool, float, float, float>(a, b, c, out, op);
|
||||
break;
|
||||
case float64:
|
||||
ternary_op<bool, double, double, double>(a, b, c, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
|
||||
break;
|
||||
}
|
||||
Op op,
|
||||
Stream stream) {
|
||||
TernaryOpType topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
c = array::unsafe_weak_copy(c),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op,
|
||||
topt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
ternary_op<bool, bool, bool, bool>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case uint8:
|
||||
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case uint16:
|
||||
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case uint32:
|
||||
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case uint64:
|
||||
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case int8:
|
||||
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case int16:
|
||||
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case int32:
|
||||
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case int64:
|
||||
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case float16:
|
||||
ternary_op<bool, float16_t, float16_t, float16_t>(
|
||||
a, b, c, out, op, topt);
|
||||
break;
|
||||
case float32:
|
||||
ternary_op<bool, float, float, float>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case float64:
|
||||
ternary_op<bool, double, double, double>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case bfloat16:
|
||||
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(
|
||||
a, b, c, out, op, topt);
|
||||
break;
|
||||
case complex64:
|
||||
ternary_op<bool, complex64_t, complex64_t, complex64_t>(
|
||||
a, b, c, out, op, topt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -70,7 +89,7 @@ void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
const auto& condition = inputs[0];
|
||||
const auto& a = inputs[1];
|
||||
const auto& b = inputs[2];
|
||||
select_op(condition, a, b, out, detail::Select());
|
||||
select_op(condition, a, b, out, detail::Select(), stream());
|
||||
}
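`select_op` now classifies the operands with `get_ternary_op_type` before dispatching, and `ternary_op` receives the resulting `topt` instead of recomputing it inside the encoder lambda. A rough sketch of the kind of classification involved; the enumerator names other than `ScalarScalarScalar` (which appears later in ternary.h) are assumptions here:

enum class TernaryOpTypeSketch { ScalarScalarScalar, VectorVectorVector, General };

TernaryOpTypeSketch classify(
    bool a_scalar, bool b_scalar, bool c_scalar, bool all_contiguous) {
  if (a_scalar && b_scalar && c_scalar) {
    return TernaryOpTypeSketch::ScalarScalarScalar; // one base-op call suffices
  }
  if (all_contiguous) {
    return TernaryOpTypeSketch::VectorVectorVector; // flat elementwise loop
  }
  return TernaryOpTypeSketch::General; // strided iteration over the shapes
}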
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
#endif

template <>
static constexpr int max_size<float16_t> = N;
inline constexpr int max_size<float16_t> = N;

#define SIMD_FP16_DEFAULT_UNARY(op) \
template <> \

@@ -83,25 +83,25 @@ struct Simd {
|
||||
// Values chosen based on benchmarks on M3 Max
|
||||
// TODO: consider choosing these more optimally
|
||||
template <>
|
||||
static constexpr int max_size<int8_t> = 16;
|
||||
inline constexpr int max_size<int8_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<int16_t> = 16;
|
||||
inline constexpr int max_size<int16_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<int> = 8;
|
||||
inline constexpr int max_size<int> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<int64_t> = 4;
|
||||
inline constexpr int max_size<int64_t> = 4;
|
||||
template <>
|
||||
static constexpr int max_size<uint8_t> = 16;
|
||||
inline constexpr int max_size<uint8_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<uint16_t> = 16;
|
||||
inline constexpr int max_size<uint16_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<uint32_t> = 8;
|
||||
inline constexpr int max_size<uint32_t> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<uint64_t> = 4;
|
||||
inline constexpr int max_size<uint64_t> = 4;
|
||||
template <>
|
||||
static constexpr int max_size<float> = 8;
|
||||
inline constexpr int max_size<float> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<double> = 4;
|
||||
inline constexpr int max_size<double> = 4;
|
||||
|
||||
#define SIMD_DEFAULT_UNARY(name, op) \
|
||||
template <typename T, int N> \
|
||||
|
||||
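The `static constexpr` to `inline constexpr` change on these `max_size` specializations matters because they are explicit specializations of a variable template living in a header: `static` is rejected by some compilers on explicit specializations and would otherwise give each translation unit its own copy, while a C++17 inline variable has exactly one definition program-wide. A minimal illustration of the header-safe form (illustrative names, not the MLX header):

template <typename T>
inline constexpr int max_size_sketch = 1;

template <>
inline constexpr int max_size_sketch<float> = 8; // one definition across all TUs

static_assert(max_size_sketch<float> == 8);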
@@ -87,7 +87,6 @@ DEFAULT_UNARY(cosh, std::cosh)
|
||||
DEFAULT_UNARY(expm1, std::expm1)
|
||||
DEFAULT_UNARY(floor, std::floor)
|
||||
DEFAULT_UNARY(log, std::log)
|
||||
DEFAULT_UNARY(log2, std::log2)
|
||||
DEFAULT_UNARY(log10, std::log10)
|
||||
DEFAULT_UNARY(log1p, std::log1p)
|
||||
DEFAULT_UNARY(sinh, std::sinh)
|
||||
@@ -95,6 +94,17 @@ DEFAULT_UNARY(sqrt, std::sqrt)
|
||||
DEFAULT_UNARY(tan, std::tan)
|
||||
DEFAULT_UNARY(tanh, std::tanh)
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> log2(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
auto out = std::log(in.value);
|
||||
auto scale = decltype(out.real())(M_LN2);
|
||||
return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
|
||||
} else {
|
||||
return Simd<T, 1>{std::log2(in.value)};
|
||||
}
|
||||
}
|
||||
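The new scalar `log2` specialization handles complex inputs through the natural log, using log2(z) = log(z) / ln 2 and dividing both the real and imaginary parts of `std::log(in.value)` by `M_LN2`. An equivalent standalone form, for reference (a sketch, not the MLX code):

#include <cmath>
#include <complex>

std::complex<float> log2_complex(std::complex<float> z) {
  // Same as scaling the real and imaginary parts of log(z) by 1 / ln 2.
  return std::log(z) / std::log(2.0f);
}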
|
||||
template <typename T>
|
||||
Simd<T, 1> operator~(Simd<T, 1> in) {
|
||||
return ~in.value;
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <cmath>
|
||||
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/types/limits.h"
|
||||
@@ -15,92 +16,100 @@ namespace {
|
||||
using namespace mlx::core::simd;
|
||||
|
||||
template <typename T, typename AccT>
|
||||
void softmax(const array& in, array& out) {
|
||||
constexpr bool same_t = std::is_same_v<T, AccT>;
|
||||
constexpr int N = std::min(max_size<AccT>, max_size<T>);
|
||||
void softmax(const array& in, array& out, Stream stream) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
|
||||
int M = in.shape().back();
|
||||
int L = in.data_size() / M;
|
||||
const T* current_in_ptr;
|
||||
T* current_out_ptr;
|
||||
|
||||
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
|
||||
vmaximum = maximum(vals, vmaximum);
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
|
||||
constexpr bool same_t = std::is_same_v<T, AccT>;
|
||||
constexpr int N = std::min(max_size<AccT>, max_size<T>);
|
||||
|
||||
AccT maximum = max(vmaximum);
|
||||
while (s-- > 0) {
|
||||
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
|
||||
current_in_ptr++;
|
||||
}
|
||||
const T* current_in_ptr;
|
||||
T* current_out_ptr;
|
||||
|
||||
// Compute the normalizer and the exponentials
|
||||
Simd<AccT, N> vnormalizer(0.0);
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
|
||||
vexp = exp(vexp - maximum);
|
||||
if constexpr (same_t) {
|
||||
store(current_out_ptr, vexp);
|
||||
}
|
||||
vnormalizer = vnormalizer + vexp;
|
||||
current_in_ptr += N;
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT normalizer = sum(vnormalizer);
|
||||
while (s-- > 0) {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
if constexpr (same_t) {
|
||||
*current_out_ptr = _exp;
|
||||
}
|
||||
normalizer += _exp;
|
||||
current_in_ptr++;
|
||||
current_out_ptr++;
|
||||
}
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Normalize
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
if constexpr (same_t) {
|
||||
store(
|
||||
current_out_ptr,
|
||||
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
|
||||
} else {
|
||||
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
|
||||
vexp = exp(vexp - maximum) * normalizer;
|
||||
store(current_out_ptr, Simd<T, N>(vexp));
|
||||
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
|
||||
vmaximum = maximum(vals, vmaximum);
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
if constexpr (same_t) {
|
||||
*current_out_ptr *= normalizer;
|
||||
} else {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = static_cast<T>(_exp * normalizer);
|
||||
|
||||
AccT maximum = max(vmaximum);
|
||||
while (s-- > 0) {
|
||||
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
|
||||
current_in_ptr++;
|
||||
}
|
||||
current_out_ptr++;
|
||||
|
||||
// Compute the normalizer and the exponentials
|
||||
Simd<AccT, N> vnormalizer(0.0);
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
|
||||
vexp = exp(vexp - maximum);
|
||||
if constexpr (same_t) {
|
||||
store(current_out_ptr, vexp);
|
||||
}
|
||||
vnormalizer = vnormalizer + vexp;
|
||||
current_in_ptr += N;
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT normalizer = sum(vnormalizer);
|
||||
while (s-- > 0) {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
if constexpr (same_t) {
|
||||
*current_out_ptr = _exp;
|
||||
}
|
||||
normalizer += _exp;
|
||||
current_in_ptr++;
|
||||
current_out_ptr++;
|
||||
}
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Normalize
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
if constexpr (same_t) {
|
||||
store(
|
||||
current_out_ptr,
|
||||
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
|
||||
} else {
|
||||
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
|
||||
vexp = exp(vexp - maximum) * normalizer;
|
||||
store(current_out_ptr, Simd<T, N>(vexp));
|
||||
current_in_ptr += N;
|
||||
}
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
if constexpr (same_t) {
|
||||
*current_out_ptr *= normalizer;
|
||||
} else {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = static_cast<T>(_exp * normalizer);
|
||||
current_in_ptr++;
|
||||
}
|
||||
current_out_ptr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
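The softmax kernel wrapped in the dispatch above still makes the same three passes per row: find the row maximum, compute exponentials of the shifted values while summing them, then scale by the reciprocal of the sum. Subtracting the maximum first keeps `exp` from overflowing. A scalar reference version of those passes (illustrative; the kernel fuses them with SIMD loads and stores):

#include <algorithm>
#include <cmath>

void softmax_row_sketch(const float* x, float* y, int n) {
  float m = *std::max_element(x, x + n); // pass 1: row maximum (n > 0 assumed)
  float denom = 0.0f;
  for (int i = 0; i < n; ++i) {          // pass 2: shifted exponentials + sum
    y[i] = std::exp(x[i] - m);
    denom += y[i];
  }
  float inv = 1.0f / denom;
  for (int i = 0; i < n; ++i) {          // pass 3: normalize
    y[i] *= inv;
  }
}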
|
||||
} // namespace
|
||||
@@ -109,67 +118,52 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto check_input = [](array x) {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
auto set_output = [s = stream(), &out](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy(x, x_copy, CopyType::General);
|
||||
copy(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
array in = check_input(std::move(inputs[0]));
|
||||
if (in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
|
||||
auto in = set_output(inputs[0]);
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case uint16:
|
||||
case uint32:
|
||||
case uint64:
|
||||
case int8:
|
||||
case int16:
|
||||
case int32:
|
||||
case int64:
|
||||
throw std::runtime_error(
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
softmax<float, float>(in, out);
|
||||
softmax<float, float>(in, out, stream());
|
||||
break;
|
||||
case float16:
|
||||
if (precise_) {
|
||||
softmax<float16_t, float>(in, out);
|
||||
softmax<float16_t, float>(in, out, stream());
|
||||
} else {
|
||||
softmax<float16_t, float16_t>(in, out);
|
||||
softmax<float16_t, float16_t>(in, out, stream());
|
||||
}
|
||||
break;
|
||||
case bfloat16:
|
||||
if (precise_) {
|
||||
softmax<bfloat16_t, float>(in, out);
|
||||
softmax<bfloat16_t, float>(in, out, stream());
|
||||
} else {
|
||||
softmax<bfloat16_t, bfloat16_t>(in, out);
|
||||
softmax<bfloat16_t, bfloat16_t>(in, out, stream());
|
||||
}
|
||||
break;
|
||||
case float64:
|
||||
softmax<double, double>(in, out);
|
||||
softmax<double, double>(in, out, stream());
|
||||
break;
|
||||
case complex64:
|
||||
throw std::invalid_argument(
|
||||
"[Softmax] Not yet implemented for complex64");
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[softmax] Only defined for floating point types.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -103,16 +104,12 @@ struct StridedIterator {
|
||||
T* ptr_;
|
||||
};
|
||||
|
||||
template <typename T, typename IdxT = uint32_t>
|
||||
void sort(const array& in, array& out, int axis) {
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype);
|
||||
|
||||
template <typename T>
|
||||
void sort(array& out, int axis) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
|
||||
size_t n_rows = in_size / in.shape(axis);
|
||||
axis = axis < 0 ? axis + out.ndim() : axis;
|
||||
size_t in_size = out.size();
|
||||
size_t n_rows = in_size / out.shape(axis);
|
||||
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
@@ -126,8 +123,9 @@ void sort(const array& in, array& out, int axis) {
|
||||
// Perform sorting in place
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
auto out_ptr = out.data<T>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out.data<T>() + src_it.loc;
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
@@ -139,9 +137,6 @@ void sort(const array& in, array& out, int axis) {
|
||||
|
||||
template <typename T, typename IdxT = uint32_t>
|
||||
void argsort(const array& in, array& out, int axis) {
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
@@ -167,9 +162,12 @@ void argsort(const array& in, array& out, int axis) {
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
auto in_ptr = in.data<T>();
|
||||
auto out_ptr = out.data<IdxT>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in.data<T>() + in_it.loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
|
||||
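`argsort` fills each output lane with indices and sorts them by comparing the corresponding input values through the strided iterators. Per lane, that is equivalent to the following (a sketch on materialized vectors; the real code walks strided memory in place):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<uint32_t> argsort_lane(const std::vector<float>& data) {
  std::vector<uint32_t> idx(data.size());
  std::iota(idx.begin(), idx.end(), 0u);
  std::stable_sort(idx.begin(), idx.end(), [&](uint32_t a, uint32_t b) {
    return data[a] < data[b];
  });
  return idx;
}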
@@ -191,33 +189,30 @@ void argsort(const array& in, array& out, int axis) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT = uint32_t>
|
||||
void partition(const array& in, array& out, int axis, int kth) {
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype);
|
||||
|
||||
template <typename T>
|
||||
void partition(array& out, int axis, int kth) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
|
||||
size_t n_rows = in_size / in.shape(axis);
|
||||
axis = axis < 0 ? axis + out.ndim() : axis;
|
||||
size_t in_size = out.size();
|
||||
size_t n_rows = in_size / out.shape(axis);
|
||||
|
||||
auto remaining_shape = in.shape();
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
|
||||
auto remaining_strides = in.strides();
|
||||
auto remaining_strides = out.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
|
||||
auto axis_stride = in.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
auto axis_stride = out.strides()[axis];
|
||||
int axis_size = out.shape(axis);
|
||||
|
||||
kth = kth < 0 ? kth + axis_size : kth;
|
||||
|
||||
// Perform partition in place
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
auto out_ptr = out.data<T>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out.data<T>() + src_it.loc;
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
src_it.step();
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
@@ -230,9 +225,6 @@ void partition(const array& in, array& out, int axis, int kth) {
|
||||
|
||||
template <typename T, typename IdxT = uint32_t>
|
||||
void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
@@ -260,9 +252,13 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
|
||||
auto in_ptr = in.data<T>();
|
||||
auto out_ptr = out.data<IdxT>();
|
||||
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in.data<T>() + in_it.loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
|
||||
@@ -291,144 +287,184 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argsort<bool>(in, out, axis_);
|
||||
case uint8:
|
||||
return argsort<uint8_t>(in, out, axis_);
|
||||
case uint16:
|
||||
return argsort<uint16_t>(in, out, axis_);
|
||||
case uint32:
|
||||
return argsort<uint32_t>(in, out, axis_);
|
||||
case uint64:
|
||||
return argsort<uint64_t>(in, out, axis_);
|
||||
case int8:
|
||||
return argsort<int8_t>(in, out, axis_);
|
||||
case int16:
|
||||
return argsort<int16_t>(in, out, axis_);
|
||||
case int32:
|
||||
return argsort<int32_t>(in, out, axis_);
|
||||
case int64:
|
||||
return argsort<int64_t>(in, out, axis_);
|
||||
case float32:
|
||||
return argsort<float>(in, out, axis_);
|
||||
case float64:
|
||||
return argsort<double>(in, out, axis_);
|
||||
case float16:
|
||||
return argsort<float16_t>(in, out, axis_);
|
||||
case bfloat16:
|
||||
return argsort<bfloat16_t>(in, out, axis_);
|
||||
case complex64:
|
||||
return argsort<complex64_t>(in, out, axis_);
|
||||
}
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argsort<bool>(in, out, axis_);
|
||||
case uint8:
|
||||
return argsort<uint8_t>(in, out, axis_);
|
||||
case uint16:
|
||||
return argsort<uint16_t>(in, out, axis_);
|
||||
case uint32:
|
||||
return argsort<uint32_t>(in, out, axis_);
|
||||
case uint64:
|
||||
return argsort<uint64_t>(in, out, axis_);
|
||||
case int8:
|
||||
return argsort<int8_t>(in, out, axis_);
|
||||
case int16:
|
||||
return argsort<int16_t>(in, out, axis_);
|
||||
case int32:
|
||||
return argsort<int32_t>(in, out, axis_);
|
||||
case int64:
|
||||
return argsort<int64_t>(in, out, axis_);
|
||||
case float32:
|
||||
return argsort<float>(in, out, axis_);
|
||||
case float64:
|
||||
return argsort<double>(in, out, axis_);
|
||||
case float16:
|
||||
return argsort<float16_t>(in, out, axis_);
|
||||
case bfloat16:
|
||||
return argsort<bfloat16_t>(in, out, axis_);
|
||||
case complex64:
|
||||
return argsort<complex64_t>(in, out, axis_);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return sort<bool>(in, out, axis_);
|
||||
case uint8:
|
||||
return sort<uint8_t>(in, out, axis_);
|
||||
case uint16:
|
||||
return sort<uint16_t>(in, out, axis_);
|
||||
case uint32:
|
||||
return sort<uint32_t>(in, out, axis_);
|
||||
case uint64:
|
||||
return sort<uint64_t>(in, out, axis_);
|
||||
case int8:
|
||||
return sort<int8_t>(in, out, axis_);
|
||||
case int16:
|
||||
return sort<int16_t>(in, out, axis_);
|
||||
case int32:
|
||||
return sort<int32_t>(in, out, axis_);
|
||||
case int64:
|
||||
return sort<int64_t>(in, out, axis_);
|
||||
case float32:
|
||||
return sort<float>(in, out, axis_);
|
||||
case float64:
|
||||
return sort<double>(in, out, axis_);
|
||||
case float16:
|
||||
return sort<float16_t>(in, out, axis_);
|
||||
case bfloat16:
|
||||
return sort<bfloat16_t>(in, out, axis_);
|
||||
case complex64:
|
||||
return sort<complex64_t>(in, out, axis_);
|
||||
}
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch(
|
||||
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
return sort<bool>(out, axis_);
|
||||
case uint8:
|
||||
return sort<uint8_t>(out, axis_);
|
||||
case uint16:
|
||||
return sort<uint16_t>(out, axis_);
|
||||
case uint32:
|
||||
return sort<uint32_t>(out, axis_);
|
||||
case uint64:
|
||||
return sort<uint64_t>(out, axis_);
|
||||
case int8:
|
||||
return sort<int8_t>(out, axis_);
|
||||
case int16:
|
||||
return sort<int16_t>(out, axis_);
|
||||
case int32:
|
||||
return sort<int32_t>(out, axis_);
|
||||
case int64:
|
||||
return sort<int64_t>(out, axis_);
|
||||
case float32:
|
||||
return sort<float>(out, axis_);
|
||||
case float64:
|
||||
return sort<double>(out, axis_);
|
||||
case float16:
|
||||
return sort<float16_t>(out, axis_);
|
||||
case bfloat16:
|
||||
return sort<bfloat16_t>(out, axis_);
|
||||
case complex64:
|
||||
return sort<complex64_t>(out, axis_);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argpartition<bool>(in, out, axis_, kth_);
|
||||
case uint8:
|
||||
return argpartition<uint8_t>(in, out, axis_, kth_);
|
||||
case uint16:
|
||||
return argpartition<uint16_t>(in, out, axis_, kth_);
|
||||
case uint32:
|
||||
return argpartition<uint32_t>(in, out, axis_, kth_);
|
||||
case uint64:
|
||||
return argpartition<uint64_t>(in, out, axis_, kth_);
|
||||
case int8:
|
||||
return argpartition<int8_t>(in, out, axis_, kth_);
|
||||
case int16:
|
||||
return argpartition<int16_t>(in, out, axis_, kth_);
|
||||
case int32:
|
||||
return argpartition<int32_t>(in, out, axis_, kth_);
|
||||
case int64:
|
||||
return argpartition<int64_t>(in, out, axis_, kth_);
|
||||
case float32:
|
||||
return argpartition<float>(in, out, axis_, kth_);
|
||||
case float64:
|
||||
return argpartition<double>(in, out, axis_, kth_);
|
||||
case float16:
|
||||
return argpartition<float16_t>(in, out, axis_, kth_);
|
||||
case bfloat16:
|
||||
return argpartition<bfloat16_t>(in, out, axis_, kth_);
|
||||
case complex64:
|
||||
return argpartition<complex64_t>(in, out, axis_, kth_);
|
||||
}
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_,
|
||||
kth_ = kth_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argpartition<bool>(in, out, axis_, kth_);
|
||||
case uint8:
|
||||
return argpartition<uint8_t>(in, out, axis_, kth_);
|
||||
case uint16:
|
||||
return argpartition<uint16_t>(in, out, axis_, kth_);
|
||||
case uint32:
|
||||
return argpartition<uint32_t>(in, out, axis_, kth_);
|
||||
case uint64:
|
||||
return argpartition<uint64_t>(in, out, axis_, kth_);
|
||||
case int8:
|
||||
return argpartition<int8_t>(in, out, axis_, kth_);
|
||||
case int16:
|
||||
return argpartition<int16_t>(in, out, axis_, kth_);
|
||||
case int32:
|
||||
return argpartition<int32_t>(in, out, axis_, kth_);
|
||||
case int64:
|
||||
return argpartition<int64_t>(in, out, axis_, kth_);
|
||||
case float32:
|
||||
return argpartition<float>(in, out, axis_, kth_);
|
||||
case float64:
|
||||
return argpartition<double>(in, out, axis_, kth_);
|
||||
case float16:
|
||||
return argpartition<float16_t>(in, out, axis_, kth_);
|
||||
case bfloat16:
|
||||
return argpartition<bfloat16_t>(in, out, axis_, kth_);
|
||||
case complex64:
|
||||
return argpartition<complex64_t>(in, out, axis_, kth_);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return partition<bool>(in, out, axis_, kth_);
|
||||
case uint8:
|
||||
return partition<uint8_t>(in, out, axis_, kth_);
|
||||
case uint16:
|
||||
return partition<uint16_t>(in, out, axis_, kth_);
|
||||
case uint32:
|
||||
return partition<uint32_t>(in, out, axis_, kth_);
|
||||
case uint64:
|
||||
return partition<uint64_t>(in, out, axis_, kth_);
|
||||
case int8:
|
||||
return partition<int8_t>(in, out, axis_, kth_);
|
||||
case int16:
|
||||
return partition<int16_t>(in, out, axis_, kth_);
|
||||
case int32:
|
||||
return partition<int32_t>(in, out, axis_, kth_);
|
||||
case int64:
|
||||
return partition<int64_t>(in, out, axis_, kth_);
|
||||
case float32:
|
||||
return partition<float>(in, out, axis_, kth_);
|
||||
case float64:
|
||||
return partition<double>(in, out, axis_, kth_);
|
||||
case float16:
|
||||
return partition<float16_t>(in, out, axis_, kth_);
|
||||
case bfloat16:
|
||||
return partition<bfloat16_t>(in, out, axis_, kth_);
|
||||
case complex64:
|
||||
return partition<complex64_t>(in, out, axis_, kth_);
|
||||
}
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_,
|
||||
kth_ = kth_]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
return partition<bool>(out, axis_, kth_);
|
||||
case uint8:
|
||||
return partition<uint8_t>(out, axis_, kth_);
|
||||
case uint16:
|
||||
return partition<uint16_t>(out, axis_, kth_);
|
||||
case uint32:
|
||||
return partition<uint32_t>(out, axis_, kth_);
|
||||
case uint64:
|
||||
return partition<uint64_t>(out, axis_, kth_);
|
||||
case int8:
|
||||
return partition<int8_t>(out, axis_, kth_);
|
||||
case int16:
|
||||
return partition<int16_t>(out, axis_, kth_);
|
||||
case int32:
|
||||
return partition<int32_t>(out, axis_, kth_);
|
||||
case int64:
|
||||
return partition<int64_t>(out, axis_, kth_);
|
||||
case float32:
|
||||
return partition<float>(out, axis_, kth_);
|
||||
case float64:
|
||||
return partition<double>(out, axis_, kth_);
|
||||
case float16:
|
||||
return partition<float16_t>(out, axis_, kth_);
|
||||
case bfloat16:
|
||||
return partition<bfloat16_t>(out, axis_, kth_);
|
||||
case complex64:
|
||||
return partition<complex64_t>(out, axis_, kth_);
|
||||
}
|
||||
});
|
||||
}
|
||||
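`Partition::eval_cpu` copies the input into the output and then partitions each lane in place around `kth_`. Per lane this has the semantics of `std::nth_element` (sketch; the kernel operates through strided iterators rather than on a vector):

#include <algorithm>
#include <vector>

void partition_lane(std::vector<float>& lane, int kth) {
  // Afterwards lane[kth] holds the value a full sort would put there, with
  // no larger element before it and no smaller element after it.
  std::nth_element(lane.begin(), lane.begin() + kth, lane.end());
}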
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -2,13 +2,18 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
|
||||
void svd_impl(
|
||||
const array& a,
|
||||
std::vector<array>& outputs,
|
||||
bool compute_uv,
|
||||
Stream stream) {
|
||||
// Lapack uses the column-major convention. To avoid having to transpose
|
||||
// the input and then transpose the outputs, we swap the indices/sizes of the
|
||||
// matrices and take advantage of the following identity (see
|
||||
@@ -22,75 +27,80 @@ void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
|
||||
const int N = a.shape(-1);
|
||||
const int K = std::min(M, N);
|
||||
|
||||
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
|
||||
const int lda = N;
|
||||
// U of shape M x M. (N x N in lapack).
|
||||
const int ldu = N;
|
||||
// Vᵀ of shape N x N. (M x M in lapack).
|
||||
const int ldvt = M;
|
||||
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
|
||||
// lapack clobbers the input, so we have to make a copy.
|
||||
array in(a.shape(), a.dtype(), nullptr, {});
|
||||
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
copy(
|
||||
a,
|
||||
in,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
stream);
|
||||
|
||||
auto job_u = (u_data && vt_data) ? "V" : "N";
|
||||
auto job_vt = (u_data && vt_data) ? "V" : "N";
|
||||
static constexpr auto range = "A";
|
||||
// Allocate outputs.
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
auto in_ptr = in.data<T>();
|
||||
T* u_ptr;
|
||||
T* s_ptr;
|
||||
T* vt_ptr;
|
||||
|
||||
// Will contain the number of singular values after the call has returned.
|
||||
int ns = 0;
|
||||
T workspace_dimension = 0;
|
||||
if (compute_uv) {
|
||||
array& u = outputs[0];
|
||||
array& s = outputs[1];
|
||||
array& vt = outputs[2];
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not used
|
||||
// here but required by lapack).
|
||||
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
|
||||
u.set_data(allocator::malloc(u.nbytes()));
|
||||
s.set_data(allocator::malloc(s.nbytes()));
|
||||
vt.set_data(allocator::malloc(vt.nbytes()));
|
||||
|
||||
static const int lwork_query = -1;
|
||||
encoder.set_output_array(u);
|
||||
encoder.set_output_array(s);
|
||||
encoder.set_output_array(vt);
|
||||
|
||||
static const int ignored_int = 0;
|
||||
static const T ignored_float = 0;
|
||||
static T ignored_output = 0;
|
||||
s_ptr = s.data<T>();
|
||||
u_ptr = u.data<T>();
|
||||
vt_ptr = vt.data<T>();
|
||||
} else {
|
||||
array& s = outputs[0];
|
||||
|
||||
int info;
|
||||
s.set_data(allocator::malloc(s.nbytes()));
|
||||
|
||||
// Compute workspace size.
|
||||
gesvdx<T>(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ nullptr,
|
||||
/* u = */ nullptr,
|
||||
/* ldu = */ &ldu,
|
||||
/* vt = */ nullptr,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ &workspace_dimension,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
encoder.set_output_array(s);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
s_ptr = s.data<T>();
|
||||
u_ptr = nullptr;
|
||||
vt_ptr = nullptr;
|
||||
}
|
||||
|
||||
const int lwork = workspace_dimension;
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
|
||||
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
|
||||
const int lda = N;
|
||||
// U of shape M x M. (N x N in lapack).
|
||||
const int ldu = N;
|
||||
// Vᵀ of shape N x N. (M x M in lapack).
|
||||
const int ldvt = M;
|
||||
|
||||
// Loop over matrices.
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
auto job_u = (u_ptr) ? "V" : "N";
|
||||
auto job_vt = (u_ptr) ? "V" : "N";
|
||||
static constexpr auto range = "A";
|
||||
|
||||
// Will contain the number of singular values after the call has returned.
|
||||
int ns = 0;
|
||||
T workspace_dimension = 0;
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not
|
||||
// used here but required by lapack).
|
||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
|
||||
|
||||
static const int lwork_query = -1;
|
||||
|
||||
static const int ignored_int = 0;
|
||||
static const T ignored_float = 0;
|
||||
|
||||
int info;
|
||||
|
||||
// Compute workspace size.
|
||||
gesvdx<T>(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
@@ -98,70 +108,93 @@ void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ in.data<T>() + M * N * i,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ s_data + K * i,
|
||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||
/* u = */ vt_data ? vt_data + N * N * i : &ignored_output,
|
||||
/* s = */ nullptr,
|
||||
/* u = */ nullptr,
|
||||
/* ldu = */ &ldu,
|
||||
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||
/* vt = */ u_data ? u_data + M * M * i : &ignored_output,
|
||||
/* vt = */ nullptr,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* work = */ &workspace_dimension,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "[SVD::eval_cpu] failed with code " << info;
|
||||
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
if (ns != K) {
|
||||
std::stringstream ss;
|
||||
ss << "[SVD::eval_cpu] expected " << K << " singular values, but " << ns
|
||||
<< " were computed.";
|
||||
throw std::runtime_error(ss.str());
|
||||
const int lwork = workspace_dimension;
|
||||
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
|
||||
// Loop over matrices.
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
gesvdx<T>(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ in_ptr + M * N * i,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ s_ptr + K * i,
|
||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
|
||||
/* ldu = */ &ldu,
|
||||
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
if (ns != K) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: expected " << K << " singular values, but " << ns
|
||||
<< " were computed.";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
encoder.add_temporary(in);
|
||||
}
|
||||
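Two conventions drive the `svd_impl` rewrite above. First, LAPACK is column-major, so the row-major A is handed over as if it were Aᵀ; since A = U Σ Vᵀ implies Aᵀ = V Σ Uᵀ, M and N are swapped and the u/vt output pointers trade places, as the in-code comments note. Second, `gesvdx` is called twice: once with `lwork = -1` so it only reports the optimal workspace size, and once for real with a scratch buffer of that size. A sketch of that workspace handshake (the `gesvdx` call itself is elided):

#include <vector>

// `query` stands in for the first gesvdx call made with lwork = -1.
template <typename T, typename QueryFn>
std::vector<T> allocate_workspace(QueryFn&& query) {
  const int lwork_query = -1;
  T workspace_dimension = T(0);
  query(&workspace_dimension, &lwork_query); // LAPACK reports the size in work[0]
  return std::vector<T>(static_cast<size_t>(workspace_dimension));
}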
|
||||
template <typename T>
|
||||
void compute_svd(const array& a, bool compute_uv, std::vector<array>& outputs) {
|
||||
if (compute_uv) {
|
||||
array& u = outputs[0];
|
||||
array& s = outputs[1];
|
||||
array& vt = outputs[2];
|
||||
|
||||
u.set_data(allocator::malloc_or_wait(u.nbytes()));
|
||||
s.set_data(allocator::malloc_or_wait(s.nbytes()));
|
||||
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
|
||||
|
||||
svd_impl<T>(a, u.data<T>(), s.data<T>(), vt.data<T>());
|
||||
} else {
|
||||
array& s = outputs[0];
|
||||
|
||||
s.set_data(allocator::malloc_or_wait(s.nbytes()));
|
||||
|
||||
svd_impl<T>(a, nullptr, s.data<T>(), nullptr);
|
||||
}
|
||||
}
|
||||
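For reference, the workspace handling in the gesvdx calls above follows the standard two-call LAPACK convention: call once with lwork = -1 so the routine only reports the optimal scratch size, allocate that much, then call again to do the real factorization. The standalone sketch below illustrates that pattern with the plain reference dgesvd interface rather than MLX's gesvdx<T> wrapper; the extern "C" prototype, helper name, and error strings are illustrative assumptions, and the exact symbol/integer width depend on the LAPACK you link against (e.g. Accelerate).

#include <algorithm>
#include <stdexcept>
#include <string>
#include <vector>

// Reference-LAPACK prototype (assumed; link against a LAPACK implementation).
extern "C" void dgesvd_(
    const char* jobu, const char* jobvt, const int* m, const int* n,
    double* a, const int* lda, double* s, double* u, const int* ldu,
    double* vt, const int* ldvt, double* work, const int* lwork, int* info);

// Singular values only (jobu = jobvt = 'N'); a is column-major, m x n,
// and is overwritten by the routine.
std::vector<double> singular_values(std::vector<double> a, int m, int n) {
  int lda = m, ldu = 1, ldvt = 1, info = 0;
  double u_dummy = 0, vt_dummy = 0;
  std::vector<double> s(std::min(m, n));

  // Pass 1: workspace query. lwork == -1 asks LAPACK to write the optimal
  // scratch size into work[0] without doing any factorization.
  double work_query = 0;
  int lwork = -1;
  dgesvd_("N", "N", &m, &n, a.data(), &lda, s.data(), &u_dummy, &ldu,
          &vt_dummy, &ldvt, &work_query, &lwork, &info);
  if (info != 0) {
    throw std::runtime_error("dgesvd workspace query failed");
  }

  // Pass 2: allocate the scratch buffer once, then run the decomposition.
  lwork = static_cast<int>(work_query);
  std::vector<double> work(lwork);
  dgesvd_("N", "N", &m, &n, a.data(), &lda, s.data(), &u_dummy, &ldu,
          &vt_dummy, &ldvt, work.data(), &lwork, &info);
  if (info != 0) {
    throw std::runtime_error("dgesvd failed with code " + std::to_string(info));
  }
  return s;
}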
void compute_svd(
|
||||
const array& a,
|
||||
bool compute_uv,
|
||||
std::vector<array>& outputs,
|
||||
Stream stream) {}
|
||||
|
||||
void SVD::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
switch (inputs[0].dtype()) {
|
||||
case float32:
|
||||
compute_svd<float>(inputs[0], compute_uv_, outputs);
|
||||
svd_impl<float>(inputs[0], outputs, compute_uv_, stream());
|
||||
break;
|
||||
case float64:
|
||||
compute_svd<double>(inputs[0], compute_uv_, outputs);
|
||||
svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/ternary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -53,22 +53,18 @@ void ternary_op_dims(
|
||||
|
||||
template <typename T1, typename T2, typename T3, typename U, typename Op>
|
||||
void ternary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
Op op) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
|
||||
const T1* a_ptr,
|
||||
const T2* b_ptr,
|
||||
const T3* c_ptr,
|
||||
U* out_ptr,
|
||||
Op op,
|
||||
size_t size,
|
||||
Shape& shape,
|
||||
std::vector<Strides>& strides) {
|
||||
const auto& a_strides = strides[0];
|
||||
const auto& b_strides = strides[1];
|
||||
const auto& c_strides = strides[2];
|
||||
const auto& out_strides = strides[3];
|
||||
|
||||
const T1* a_ptr = a.data<T1>();
|
||||
const T2* b_ptr = b.data<T2>();
|
||||
const T3* c_ptr = c.data<T3>();
|
||||
U* out_ptr = out.data<U>();
|
||||
int ndim = shape.size();
|
||||
switch (ndim) {
|
||||
case 1:
|
||||
@@ -105,7 +101,7 @@ void ternary_op_dispatch_dims(
|
||||
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
||||
ContiguousIterator c_it(shape, c_strides, ndim - 2);
|
||||
auto stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
for (size_t elem = 0; elem < size; elem += stride) {
|
||||
ternary_op_dims<T1, T2, T3, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
@@ -130,18 +126,16 @@ void ternary_op(
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
Op op) {
|
||||
TernaryOpType topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt);
|
||||
Op op,
|
||||
TernaryOpType topt) {
|
||||
const T1* a_ptr = a.data<T1>();
|
||||
const T2* b_ptr = b.data<T2>();
|
||||
const T3* c_ptr = c.data<T3>();
|
||||
U* out_ptr = out.data<U>();
|
||||
|
||||
// The full computation is scalar-scalar-scalar so we call the base op once.
|
||||
if (topt == TernaryOpType::ScalarScalarScalar) {
|
||||
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
|
||||
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
|
||||
} else if (topt == TernaryOpType::VectorVectorVector) {
|
||||
const T1* a_ptr = a.data<T1>();
|
||||
const T2* b_ptr = b.data<T2>();
|
||||
const T3* c_ptr = c.data<T3>();
|
||||
U* out_ptr = out.data<U>();
|
||||
for (size_t i = 0; i < out.size(); ++i) {
|
||||
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
|
||||
a_ptr++;
|
||||
@@ -150,7 +144,10 @@ void ternary_op(
|
||||
out_ptr++;
|
||||
}
|
||||
} else {
|
||||
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
|
||||
ternary_op_dispatch_dims<T1, T2, T3, U>(
|
||||
a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
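The ternary refactor above mostly threads raw pointers and pre-collapsed shapes through the helpers; the VectorVectorVector fast path itself is just a single linear walk over all three inputs. A minimal standalone sketch of that path for a select-style op (names here are illustrative, not MLX's):

#include <cstddef>

// Elementwise out[i] = cond[i] ? x[i] : y[i], assuming all four buffers are
// contiguous and the same length -- the "VectorVectorVector" case.
template <typename T>
void select_contiguous(const bool* cond, const T* x, const T* y, T* out,
                       std::size_t size) {
  for (std::size_t i = 0; i < size; ++i) {
    out[i] = cond[i] ? x[i] : y[i];
  }
}

Strided inputs fall through to the dims-based dispatch instead, which is why the refactor passes the collapsed shape and strides into ternary_op_dispatch_dims explicitly rather than recomputing them inside.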
@@ -1,5 +1,8 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// Required for using M_LN2 in MSVC.
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/cpu/unary.h"
|
||||
@@ -14,88 +17,57 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// No-op for unsigned types
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
auto op = detail::Abs{};
|
||||
switch (out.dtype()) {
|
||||
case int8:
|
||||
unary_op<int8_t>(in, out, op);
|
||||
break;
|
||||
case int16:
|
||||
unary_op<int16_t>(in, out, op);
|
||||
break;
|
||||
case int32:
|
||||
unary_op<int32_t>(in, out, op);
|
||||
break;
|
||||
case int64:
|
||||
unary_op<int64_t>(in, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(in, out, op);
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(in, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
unary_op<complex64_t>(in, out, op);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[Abs] Called on unsigned type");
|
||||
}
|
||||
unary_signed(in, out, detail::Abs(), stream());
|
||||
}
|
||||
}
|
||||
|
||||
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcCos());
|
||||
unary_fp(in, out, detail::ArcCos(), stream());
|
||||
}
|
||||
|
||||
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcCosh());
|
||||
unary_fp(in, out, detail::ArcCosh(), stream());
|
||||
}
|
||||
|
||||
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcSin());
|
||||
unary_fp(in, out, detail::ArcSin(), stream());
|
||||
}
|
||||
|
||||
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcSinh());
|
||||
unary_fp(in, out, detail::ArcSinh(), stream());
|
||||
}
|
||||
|
||||
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcTan());
|
||||
unary_fp(in, out, detail::ArcTan(), stream());
|
||||
}
|
||||
|
||||
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcTanh());
|
||||
unary_fp(in, out, detail::ArcTanh(), stream());
|
||||
}
|
||||
|
||||
void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_int(in, out, detail::BitwiseInvert());
|
||||
unary_int(in, out, detail::BitwiseInvert(), stream());
|
||||
}
|
||||
|
||||
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Ceil());
|
||||
unary_fp(in, out, detail::Ceil(), stream());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
@@ -104,84 +76,50 @@ void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
unary_op<complex64_t>(inputs[0], out, detail::Conjugate());
|
||||
unary_complex(inputs[0], out, detail::Conjugate(), stream());
|
||||
}
|
||||
|
||||
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Cos());
|
||||
unary_fp(in, out, detail::Cos(), stream());
|
||||
}
|
||||
|
||||
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Cosh());
|
||||
unary_fp(in, out, detail::Cosh(), stream());
|
||||
}
|
||||
|
||||
void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::Erf());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(in, out, detail::Erf());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf] Error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
unary_real_fp(in, out, detail::Erf(), stream());
|
||||
}
|
||||
|
||||
void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf_inv] Inverse error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
unary_real_fp(in, out, detail::ErfInv(), stream());
|
||||
}
|
||||
|
||||
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Exp());
|
||||
unary_fp(in, out, detail::Exp(), stream());
|
||||
}
|
||||
|
||||
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Expm1());
|
||||
unary_fp(in, out, detail::Expm1(), stream());
|
||||
}
|
||||
|
||||
void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Floor());
|
||||
unary_fp(in, out, detail::Floor(), stream());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
@@ -189,7 +127,7 @@ void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
|
||||
unary_complex_to_float(inputs[0], out, detail::Imag(), stream());
|
||||
}
|
||||
|
||||
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -197,13 +135,13 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
const auto& in = inputs[0];
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_fp(in, out, detail::Log());
|
||||
unary_fp(in, out, detail::Log(), stream());
|
||||
break;
|
||||
case Base::two:
|
||||
unary_fp(in, out, detail::Log2());
|
||||
unary_fp(in, out, detail::Log2(), stream());
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_fp(in, out, detail::Log10());
|
||||
unary_fp(in, out, detail::Log10(), stream());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -211,30 +149,30 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Log1p());
|
||||
unary_fp(in, out, detail::Log1p(), stream());
|
||||
}
|
||||
|
||||
void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::LogicalNot());
|
||||
unary(in, out, detail::LogicalNot(), stream());
|
||||
}
|
||||
|
||||
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Negative());
|
||||
unary(in, out, detail::Negative(), stream());
|
||||
}
|
||||
|
||||
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
|
||||
unary_complex_to_float(inputs[0], out, detail::Real(), stream());
|
||||
}
|
||||
|
||||
void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Round());
|
||||
unary_fp(in, out, detail::Round(), stream());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
@@ -244,7 +182,7 @@ void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sigmoid());
|
||||
unary_fp(in, out, detail::Sigmoid(), stream());
|
||||
}
|
||||
|
||||
void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -253,48 +191,48 @@ void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (in.dtype() == bool_) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
unary(in, out, detail::Sign());
|
||||
unary(in, out, detail::Sign(), stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sin());
|
||||
unary_fp(in, out, detail::Sin(), stream());
|
||||
}
|
||||
|
||||
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sinh());
|
||||
unary_fp(in, out, detail::Sinh(), stream());
|
||||
}
|
||||
|
||||
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Square());
|
||||
unary(in, out, detail::Square(), stream());
|
||||
}
|
||||
|
||||
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (recip_) {
|
||||
unary_fp(in, out, detail::Rsqrt());
|
||||
unary_fp(in, out, detail::Rsqrt(), stream());
|
||||
} else {
|
||||
unary_fp(in, out, detail::Sqrt());
|
||||
unary_fp(in, out, detail::Sqrt(), stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Tan());
|
||||
unary_fp(in, out, detail::Tan(), stream());
|
||||
}
|
||||
|
||||
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Tanh());
|
||||
unary_fp(in, out, detail::Tanh(), stream());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
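The net effect of the refactor above is that each eval_cpu collapses to a one-line forward into a shared helper that owns the dtype switch and the stream dispatch. The standalone sketch below mirrors that shape with a toy dispatcher; the Tensor/Dtype types and function names are illustrative stand-ins, not MLX's:

#include <cmath>
#include <cstddef>
#include <stdexcept>

enum class Dtype { f32, f64 };

struct Tensor {
  Dtype dtype;
  void* data;
  std::size_t size;
};

// One shared helper owns the dtype switch...
template <typename Op>
void unary_fp_dispatch(const Tensor& in, Tensor& out, Op op) {
  switch (in.dtype) {
    case Dtype::f32: {
      auto* src = static_cast<const float*>(in.data);
      auto* dst = static_cast<float*>(out.data);
      for (std::size_t i = 0; i < in.size; ++i) dst[i] = op(src[i]);
      break;
    }
    case Dtype::f64: {
      auto* src = static_cast<const double*>(in.data);
      auto* dst = static_cast<double*>(out.data);
      for (std::size_t i = 0; i < in.size; ++i) dst[i] = op(src[i]);
      break;
    }
    default:
      throw std::runtime_error("unsupported dtype");
  }
}

// ...so each op's eval becomes a one-liner, as in the diff above.
void eval_cos(const Tensor& in, Tensor& out) {
  unary_fp_dispatch(in, out, [](auto x) { return std::cos(x); });
}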
@@ -5,174 +5,296 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void set_unary_output_data(const array& in, array& out) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
if (in.flags().contiguous) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
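set_unary_output_data above now splits on contiguity: contiguous inputs either donate their buffer to the output or get an allocation of data_size() elements with the same strides, while non-contiguous inputs fall back to a plain nbytes() allocation. A rough standalone sketch of the donation decision, with simplified types standing in for MLX's array (is_donatable in MLX also checks ownership flags and itemsize; this only shows the core reference-count idea):

#include <cstddef>
#include <memory>

struct FakeArray {
  std::shared_ptr<char[]> buffer;  // stand-in for the real data buffer
  bool contiguous = true;
};

// An input can donate its buffer when nothing else holds a reference to it.
bool can_donate(const FakeArray& in) {
  return in.buffer && in.buffer.use_count() == 1;
}

void set_output(const FakeArray& in, FakeArray& out, std::size_t nbytes) {
  if (in.contiguous && can_donate(in)) {
    out.buffer = in.buffer;  // reuse the input allocation in place
  } else {
    out.buffer = std::shared_ptr<char[]>(new char[nbytes]);
  }
  out.contiguous = in.contiguous;
}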
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
|
||||
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
|
||||
for (size_t i = 0; i < shape; i += 1) {
|
||||
out[i] = op(*a);
|
||||
out[i] = Op{}(*a);
|
||||
a += stride;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const array& a, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
void unary_op(const array& a, array& out, Op) {
|
||||
const T* src = a.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
auto ndim = a.ndim();
|
||||
if (a.flags().contiguous) {
|
||||
set_unary_output_data(a, out);
|
||||
U* dst = out.data<U>();
|
||||
auto size = a.data_size();
|
||||
constexpr int N = simd::max_size<T>;
|
||||
size_t size = a.data_size();
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a_ptr)));
|
||||
simd::store(dst, Op{}(simd::load<T, N>(src)));
|
||||
size -= N;
|
||||
a_ptr += N;
|
||||
src += N;
|
||||
dst += N;
|
||||
}
|
||||
while (size > 0) {
|
||||
*dst = op(*a_ptr);
|
||||
*dst = Op{}(*src);
|
||||
size--;
|
||||
dst++;
|
||||
a_ptr++;
|
||||
src++;
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
U* dst = out.data<U>();
|
||||
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
|
||||
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
|
||||
if (a.ndim() <= 1) {
|
||||
unary_op(a_ptr, dst, op, shape, stride);
|
||||
size_t shape = ndim > 0 ? a.shape().back() : 1;
|
||||
size_t stride = ndim > 0 ? a.strides().back() : 1;
|
||||
if (ndim <= 1) {
|
||||
unary_op<T, U, Op>(src, dst, shape, stride);
|
||||
return;
|
||||
}
|
||||
ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
|
||||
auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
|
||||
for (size_t elem = 0; elem < a.size(); elem += shape) {
|
||||
unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
|
||||
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
|
||||
it.step();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary(const array& a, array& out, Op op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
unary_op<bool>(a, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
unary_op<uint8_t>(a, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
unary_op<uint16_t>(a, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
unary_op<uint32_t>(a, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
unary_op<uint64_t>(a, out, op);
|
||||
break;
|
||||
case int8:
|
||||
unary_op<int8_t>(a, out, op);
|
||||
break;
|
||||
case int16:
|
||||
unary_op<int16_t>(a, out, op);
|
||||
break;
|
||||
case int32:
|
||||
unary_op<int32_t>(a, out, op);
|
||||
break;
|
||||
case int64:
|
||||
unary_op<int64_t>(a, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(a, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(a, out, op);
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(a, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(a, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
unary_op<complex64_t>(a, out, op);
|
||||
break;
|
||||
}
|
||||
void unary(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
unary_op<bool>(a, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
unary_op<uint8_t>(a, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
unary_op<uint16_t>(a, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
unary_op<uint32_t>(a, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
unary_op<uint64_t>(a, out, op);
|
||||
break;
|
||||
case int8:
|
||||
unary_op<int8_t>(a, out, op);
|
||||
break;
|
||||
case int16:
|
||||
unary_op<int16_t>(a, out, op);
|
||||
break;
|
||||
case int32:
|
||||
unary_op<int32_t>(a, out, op);
|
||||
break;
|
||||
case int64:
|
||||
unary_op<int64_t>(a, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(a, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(a, out, op);
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(a, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(a, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
unary_op<complex64_t>(a, out, op);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_fp(const array& a, array& out, Op op) {
|
||||
switch (out.dtype()) {
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(a, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(a, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(a, out, op);
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(a, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
unary_op<complex64_t>(a, out, op);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream err;
|
||||
err << "[unary_fp] Does not support " << out.dtype();
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
void unary_real_fp(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(a, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(a, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(a, out, op);
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(a, out, op);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream err;
|
||||
err << "[unary_real] Does not support " << out.dtype();
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
});
|
||||
}
|
||||
template <typename Op>
|
||||
void unary_fp(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(a, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(a, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(a, out, op);
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(a, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
unary_op<complex64_t>(a, out, op);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream err;
|
||||
err << "[unary_fp] Does not support " << out.dtype();
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_int(const array& a, array& out, Op op) {
|
||||
switch (out.dtype()) {
|
||||
case uint8:
|
||||
unary_op<uint8_t>(a, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
unary_op<uint16_t>(a, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
unary_op<uint32_t>(a, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
unary_op<uint64_t>(a, out, op);
|
||||
break;
|
||||
case int8:
|
||||
unary_op<int8_t>(a, out, op);
|
||||
break;
|
||||
case int16:
|
||||
unary_op<int16_t>(a, out, op);
|
||||
break;
|
||||
case int32:
|
||||
unary_op<int32_t>(a, out, op);
|
||||
break;
|
||||
case int64:
|
||||
unary_op<int64_t>(a, out, op);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream err;
|
||||
err << "[unary_int] Does not support " << out.dtype();
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
void unary_signed(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case int8:
|
||||
unary_op<int8_t>(a, out, op);
|
||||
break;
|
||||
case int16:
|
||||
unary_op<int16_t>(a, out, op);
|
||||
break;
|
||||
case int32:
|
||||
unary_op<int32_t>(a, out, op);
|
||||
break;
|
||||
case int64:
|
||||
unary_op<int64_t>(a, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(a, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(a, out, op);
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(a, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(a, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
unary_op<complex64_t>(a, out, op);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[Abs] Called on unsigned type");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_complex(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable { unary_op<complex64_t>(a, out, op); });
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch(
|
||||
[a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable { unary_op<complex64_t, float>(a, out, op); });
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_int(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case uint8:
|
||||
unary_op<uint8_t>(a, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
unary_op<uint16_t>(a, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
unary_op<uint32_t>(a, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
unary_op<uint64_t>(a, out, op);
|
||||
break;
|
||||
case int8:
|
||||
unary_op<int8_t>(a, out, op);
|
||||
break;
|
||||
case int16:
|
||||
unary_op<int16_t>(a, out, op);
|
||||
break;
|
||||
case int32:
|
||||
unary_op<int32_t>(a, out, op);
|
||||
break;
|
||||
case int64:
|
||||
unary_op<int64_t>(a, out, op);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream err;
|
||||
err << "[unary_int] Does not support " << out.dtype();
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
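The recurring pattern in the new unary helpers above is: set up the output buffer, register inputs and outputs with the stream's command encoder, then hand the encoder a lambda that owns (weak) copies of the arrays and performs the dtype switch only when the task actually runs. A stripped-down standalone model of that deferred-dispatch shape; the encoder here is a toy task queue, not MLX's cpu::CommandEncoder:

#include <cstddef>
#include <functional>
#include <queue>
#include <vector>

// Toy stand-in for a per-stream command encoder: it only queues tasks.
struct ToyEncoder {
  std::queue<std::function<void()>> tasks;
  void dispatch(std::function<void()> f) { tasks.push(std::move(f)); }
  void run_all() {
    while (!tasks.empty()) { tasks.front()(); tasks.pop(); }
  }
};

// The helper does its setup eagerly, but the elementwise loop runs later,
// which is why the lambda captures copies of everything it needs.
void negate_async(ToyEncoder& enc, std::vector<float> in,
                  std::vector<float>& out) {
  out.resize(in.size());
  enc.dispatch([in = std::move(in), &out]() {
    for (std::size_t i = 0; i < in.size(); ++i) out[i] = -in[i];
  });
}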
@@ -86,13 +86,14 @@ struct Sign {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
auto z = Simd<T, N>{0};
auto o = Simd<T, N>{1};
auto m = Simd<T, N>{-1};
if constexpr (std::is_unsigned_v<T>) {
return x != z;
return simd::select(x == z, z, o);
} else if constexpr (std::is_same_v<T, complex64_t>) {
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
} else {
return simd::select(
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
return simd::select(x < z, m, simd::select(x > z, o, z));
}
}
SINGLE()
|
||||
|
||||
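The Sign rewrite above just hoists the 0/1/-1 constants and uses simd::select for every branch. For reference, the same branchless identity for scalars (plain C++, not the simd types used above):

#include <cassert>

// sign(x) = (x > 0) - (x < 0): yields -1, 0 or +1 without branching.
template <typename T>
int sign_of(T x) {
  return (x > T(0)) - (x < T(0));
}

int main() {
  assert(sign_of(-3.5) == -1);
  assert(sign_of(0) == 0);
  assert(sign_of(42) == 1);
}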
@@ -47,6 +47,7 @@ if(MLX_METAL_JIT)
|
||||
make_jit_source(binary)
|
||||
make_jit_source(binary_two)
|
||||
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
|
||||
make_jit_source(logsumexp)
|
||||
make_jit_source(ternary)
|
||||
make_jit_source(softmax)
|
||||
make_jit_source(scan)
|
||||
@@ -60,6 +61,7 @@ if(MLX_METAL_JIT)
|
||||
kernels/steel/gemm/transforms.h)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_gather)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
|
||||
make_jit_source(
|
||||
steel/conv/conv
|
||||
@@ -95,6 +97,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
#include "mlx/memory.h"
|
||||
|
||||
#include <mach/vm_page_size.h>
|
||||
#include <unistd.h>
|
||||
@@ -20,6 +21,9 @@ Allocator& allocator() {
|
||||
}
|
||||
|
||||
void* Buffer::raw_ptr() {
|
||||
if (!ptr_) {
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<MTL::Buffer*>(ptr_)->contents();
|
||||
}
|
||||
|
||||
@@ -29,8 +33,11 @@ namespace metal {
|
||||
|
||||
namespace {
|
||||
|
||||
BufferCache::BufferCache(MTL::Device* device)
|
||||
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
|
||||
BufferCache::BufferCache(ResidencySet& residency_set)
|
||||
: head_(nullptr),
|
||||
tail_(nullptr),
|
||||
pool_size_(0),
|
||||
residency_set_(residency_set) {}
|
||||
|
||||
BufferCache::~BufferCache() {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
@@ -41,6 +48,9 @@ int BufferCache::clear() {
|
||||
int n_release = 0;
|
||||
for (auto& [size, holder] : buffer_pool_) {
|
||||
if (holder->buf) {
|
||||
if (!holder->buf->heap()) {
|
||||
residency_set_.erase(holder->buf);
|
||||
}
|
||||
holder->buf->release();
|
||||
n_release++;
|
||||
}
|
||||
@@ -98,6 +108,9 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||
if (tail_->buf) {
|
||||
total_bytes_freed += tail_->buf->length();
|
||||
if (!tail_->buf->heap()) {
|
||||
residency_set_.erase(tail_->buf);
|
||||
}
|
||||
tail_->buf->release();
|
||||
tail_->buf = nullptr;
|
||||
n_release++;
|
||||
@@ -152,7 +165,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
|
||||
MetalAllocator::MetalAllocator()
|
||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||
residency_set_(device_),
|
||||
buffer_cache_(device_) {
|
||||
buffer_cache_(residency_set_) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto memsize = std::get<size_t>(device_info().at("memory_size"));
|
||||
auto max_rec_size =
|
||||
@@ -189,16 +202,19 @@ size_t MetalAllocator::set_cache_limit(size_t limit) {
|
||||
return limit;
|
||||
};
|
||||
|
||||
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
|
||||
size_t MetalAllocator::set_memory_limit(size_t limit) {
|
||||
std::unique_lock lk(mutex_);
|
||||
std::swap(limit, block_limit_);
|
||||
relaxed_ = relaxed;
|
||||
gc_limit_ = std::min(
|
||||
block_limit_,
|
||||
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
|
||||
return limit;
|
||||
};
|
||||
|
||||
size_t MetalAllocator::get_memory_limit() {
|
||||
return block_limit_;
|
||||
}
|
||||
|
||||
size_t MetalAllocator::set_wired_limit(size_t limit) {
|
||||
std::unique_lock lk(mutex_);
|
||||
std::swap(limit, wired_limit_);
|
||||
@@ -206,7 +222,7 @@ size_t MetalAllocator::set_wired_limit(size_t limit) {
|
||||
return limit;
|
||||
};
|
||||
|
||||
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
Buffer MetalAllocator::malloc(size_t size) {
|
||||
// Metal doesn't like empty buffers
|
||||
if (size == 0) {
|
||||
return Buffer{nullptr};
|
||||
@@ -233,11 +249,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
if (!buf) {
|
||||
size_t mem_required = get_active_memory() + get_cache_memory() + size;
|
||||
|
||||
// If there is too much memory pressure, fail (likely causes a wait).
|
||||
if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
|
||||
return Buffer{nullptr};
|
||||
}
|
||||
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
@@ -261,9 +272,13 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
if (!buf) {
|
||||
buf = device_->newBuffer(size, resource_options);
|
||||
}
|
||||
if (!buf) {
|
||||
return Buffer{nullptr};
|
||||
}
|
||||
lk.lock();
|
||||
if (buf) {
|
||||
num_resources_++;
|
||||
num_resources_++;
|
||||
if (!buf->heap()) {
|
||||
residency_set_.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -277,10 +292,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
if (!buf->heap()) {
|
||||
residency_set_.insert(buf);
|
||||
}
|
||||
|
||||
return Buffer{static_cast<void*>(buf)};
|
||||
}
|
||||
|
||||
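The allocator change above threads the residency set through BufferCache and simplifies malloc to: try the size-bucketed cache, otherwise allocate from the device, and return a null Buffer if that fails. The size-bucket reuse itself is essentially a multimap lookup; a self-contained sketch of just that part, with malloc/free standing in for Metal calls and no bound on how oversized a reused block may be (the real cache is more careful):

#include <cstddef>
#include <cstdlib>
#include <map>

// Size-bucketed free list: reuse the smallest cached block that fits.
class ToyBufferCache {
 public:
  void* reuse_or_allocate(std::size_t size) {
    auto it = pool_.lower_bound(size);  // smallest cached block >= size
    if (it != pool_.end()) {
      void* buf = it->second;
      pool_.erase(it);
      return buf;
    }
    return std::malloc(size);           // cache miss: real allocation
  }
  void recycle(std::size_t size, void* buf) { pool_.insert({size, buf}); }
  ~ToyBufferCache() {
    for (auto& [size, buf] : pool_) std::free(buf);
  }

 private:
  std::multimap<std::size_t, void*> pool_;
};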
@@ -296,14 +307,14 @@ void MetalAllocator::free(Buffer buffer) {
|
||||
return;
|
||||
}
|
||||
std::unique_lock lk(mutex_);
|
||||
if (!buf->heap()) {
|
||||
residency_set_.erase(buf);
|
||||
}
|
||||
active_memory_ -= buf->length();
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
num_resources_--;
|
||||
if (!buf->heap()) {
|
||||
residency_set_.erase(buf);
|
||||
}
|
||||
lk.unlock();
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
buf->release();
|
||||
@@ -322,37 +333,40 @@ MetalAllocator& allocator() {
|
||||
return *allocator_;
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
|
||||
size_t set_cache_limit(size_t limit) {
|
||||
return allocator().set_cache_limit(limit);
|
||||
return metal::allocator().set_cache_limit(limit);
|
||||
}
|
||||
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
|
||||
return allocator().set_memory_limit(limit, relaxed);
|
||||
size_t set_memory_limit(size_t limit) {
|
||||
return metal::allocator().set_memory_limit(limit);
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return metal::allocator().get_memory_limit();
|
||||
}
|
||||
size_t set_wired_limit(size_t limit) {
|
||||
if (limit >
|
||||
std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
|
||||
if (limit > std::get<size_t>(metal::device_info().at(
|
||||
"max_recommended_working_set_size"))) {
|
||||
throw std::invalid_argument(
|
||||
"[metal::set_wired_limit] Setting a wired limit larger than "
|
||||
"the maximum working set size is not allowed.");
|
||||
}
|
||||
return allocator().set_wired_limit(limit);
|
||||
return metal::allocator().set_wired_limit(limit);
|
||||
}
|
||||
size_t get_active_memory() {
|
||||
return allocator().get_active_memory();
|
||||
return metal::allocator().get_active_memory();
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return allocator().get_peak_memory();
|
||||
return metal::allocator().get_peak_memory();
|
||||
}
|
||||
void reset_peak_memory() {
|
||||
allocator().reset_peak_memory();
|
||||
metal::allocator().reset_peak_memory();
|
||||
}
|
||||
size_t get_cache_memory() {
|
||||
return allocator().get_cache_memory();
|
||||
return metal::allocator().get_cache_memory();
|
||||
}
|
||||
void clear_cache() {
|
||||
return allocator().clear_cache();
|
||||
return metal::allocator().clear_cache();
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
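After this change the limit setters lose the relaxed flag and simply forward to the Metal allocator, with the free functions living in mlx::core (mlx/memory.h) rather than the metal namespace. A usage sketch from application code, assuming the post-refactor signatures shown above; the byte values are arbitrary examples:

#include <iostream>

#include "mlx/memory.h"

int main() {
  namespace mx = mlx::core;
  // Cap how much the allocator may hold active, and how much it may cache.
  auto previous_limit = mx::set_memory_limit(4UL << 30);  // 4 GiB
  mx::set_cache_limit(1UL << 30);                         // 1 GiB of cache

  // ... run some work ...

  std::cout << "active: " << mx::get_active_memory()
            << " cached: " << mx::get_cache_memory() << std::endl;
  mx::clear_cache();
  // Restore the previous limit when done.
  mx::set_memory_limit(previous_limit);
}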
@@ -18,7 +18,7 @@ namespace {
|
||||
|
||||
class BufferCache {
|
||||
public:
|
||||
BufferCache(MTL::Device* device);
|
||||
BufferCache(ResidencySet& residency_set);
|
||||
~BufferCache();
|
||||
|
||||
MTL::Buffer* reuse_from_cache(size_t size);
|
||||
@@ -42,13 +42,11 @@ class BufferCache {
|
||||
void add_at_head(BufferHolder* to_add);
|
||||
void remove_from_list(BufferHolder* to_remove);
|
||||
|
||||
MTL::Device* device_;
|
||||
MTL::Heap* heap_{nullptr};
|
||||
|
||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||
BufferHolder* head_;
|
||||
BufferHolder* tail_;
|
||||
size_t pool_size_;
|
||||
ResidencySet& residency_set_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -56,7 +54,7 @@ class BufferCache {
|
||||
class MetalAllocator : public allocator::Allocator {
|
||||
/** Allocator for Metal GPUs. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual Buffer malloc(size_t size) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
size_t get_active_memory() {
|
||||
@@ -73,7 +71,8 @@ class MetalAllocator : public allocator::Allocator {
|
||||
return buffer_cache_.cache_size();
|
||||
};
|
||||
size_t set_cache_limit(size_t limit);
|
||||
size_t set_memory_limit(size_t limit, bool relaxed);
|
||||
size_t set_memory_limit(size_t limit);
|
||||
size_t get_memory_limit();
|
||||
size_t set_wired_limit(size_t limit);
|
||||
void clear_cache();
|
||||
|
||||
|
||||
@@ -102,16 +102,9 @@ void binary_op_gpu_inplace(
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// - If a is donated it goes to the first output
|
||||
// - If b is donated it goes to the first output if a was not donated
|
||||
// otherwise it goes to the second output.
|
||||
// - If there is only one output only one of a and b will be donated.
|
||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||
bool donate_b = b.data_shared_ptr() == nullptr;
|
||||
int arg_idx = 0;
|
||||
compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++);
|
||||
compute_encoder.set_input_array(
|
||||
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, arg_idx++);
|
||||
compute_encoder.set_input_array(a, arg_idx++);
|
||||
compute_encoder.set_input_array(b, arg_idx++);
|
||||
compute_encoder.set_output_array(outputs[0], arg_idx++);
|
||||
if (outputs.size() == 2) {
|
||||
compute_encoder.set_output_array(outputs[1], arg_idx++);
|
||||
@@ -164,8 +157,8 @@ void binary_op_gpu(
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt, true);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt, true);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
binary_op_gpu_inplace(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
@@ -195,7 +188,7 @@ void binary_op_gpu(
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt, true);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
binary_op_gpu_inplace(inputs, out, op, s);
|
||||
}
|
||||
|
||||
|
||||
@@ -457,7 +457,7 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, true);
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous);
|
||||
|
||||
// Put the outputs in
|
||||
for (auto& x : outputs) {
|
||||
|
||||
@@ -37,7 +37,7 @@ void explicit_gemm_conv_ND_gpu(
|
||||
Shape unfolded_shape{implicit_M, implicit_K};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
@@ -115,7 +115,7 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
// Prepare unfolding array
|
||||
Shape unfolded_shape{implicit_M, implicit_K * groups};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
@@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
|
||||
// Do filter transform
|
||||
Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
|
||||
array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
|
||||
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
|
||||
filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
|
||||
copies_w.push_back(filt_wg);
|
||||
{
|
||||
int bc = 32;
|
||||
@@ -640,7 +640,7 @@ void winograd_conv_2D_gpu(
|
||||
// Do input transform
|
||||
Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
|
||||
array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
|
||||
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
|
||||
inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
|
||||
copies_w.push_back(inp_wg);
|
||||
{
|
||||
int bc = 32;
|
||||
@@ -667,7 +667,7 @@ void winograd_conv_2D_gpu(
|
||||
// Do batched gemm
|
||||
Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
|
||||
array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
|
||||
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
|
||||
out_wg.set_data(allocator::malloc(out_wg.nbytes()));
|
||||
copies_w.push_back(out_wg);
|
||||
{
|
||||
std::vector<array> empty_copies;
|
||||
@@ -712,6 +712,65 @@ void winograd_conv_2D_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
void depthwise_conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
std::ostringstream kname;
|
||||
kname << "depthwise_conv_2d_" << type_to_name(out);
|
||||
std::string base_name = kname.str();
|
||||
|
||||
const int N = conv_params.N;
|
||||
const int ker_h = conv_params.wS[0];
|
||||
const int ker_w = conv_params.wS[1];
|
||||
const int str_h = conv_params.str[0];
|
||||
const int str_w = conv_params.str[1];
|
||||
const int tc = 8;
|
||||
const int tw = 8;
|
||||
const int th = 4;
|
||||
const bool do_flip = conv_params.flip;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&ker_h, MTL::DataType::DataTypeInt, 00},
|
||||
{&ker_w, MTL::DataType::DataTypeInt, 01},
|
||||
{&str_h, MTL::DataType::DataTypeInt, 10},
|
||||
{&str_w, MTL::DataType::DataTypeInt, 11},
|
||||
{&th, MTL::DataType::DataTypeInt, 100},
|
||||
{&tw, MTL::DataType::DataTypeInt, 101},
|
||||
{&do_flip, MTL::DataType::DataTypeBool, 200},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_ker_h_" << ker_h
|
||||
<< "_ker_w_" << ker_w
|
||||
<< "_str_h_" << str_h
|
||||
<< "_str_w_" << str_w
|
||||
<< "_tgp_h_" << th
|
||||
<< "_tgp_w_" << tw
|
||||
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tc, tw, th);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
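The threadgroup shape in depthwise_conv_2D_gpu is fixed at (tc, tw, th) = (8, 8, 4), which is why conv_2D_gpu below only takes this path when the output spatial dims tile by 8 and the channel count by 16. A small standalone sketch of the launch-geometry arithmetic used above (struct and function names are illustrative):

#include <cstdio>

// Threadgroups of tc x tw x th threads tile
// channels x output-width x (output-height * batch).
struct Launch { int x, y, z; };

Launch depthwise_grid(int N, int C, int out_h, int out_w,
                      int tc = 8, int tw = 8, int th = 4) {
  return Launch{C / tc, out_w / tw, (out_h / th) * N};
}

int main() {
  // Example: batch 2, 64 channels, 32x32 output -> 8 x 4 x 16 threadgroups,
  // each running 8 * 8 * 4 = 256 threads.
  Launch g = depthwise_grid(/*N=*/2, /*C=*/64, /*out_h=*/32, /*out_w=*/32);
  std::printf("%d x %d x %d\n", g.x, g.y, g.z);
}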
void conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@@ -754,11 +813,20 @@ void conv_2D_gpu(
|
||||
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
|
||||
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
|
||||
|
||||
if (groups > 1) {
|
||||
if (is_idil_one && groups > 1) {
|
||||
const int C_per_group = conv_params.C / groups;
|
||||
const int O_per_group = conv_params.O / groups;
|
||||
|
||||
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
|
||||
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
|
||||
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
|
||||
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
|
||||
conv_params.wt_strides[1] == conv_params.wS[1] &&
|
||||
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
|
||||
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
} else {
|
||||
@@ -855,7 +923,7 @@ void conv_3D_gpu(
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
|
||||
@@ -14,25 +14,11 @@ namespace mlx::core {
|
||||
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
if (ctype == CopyType::Vector) {
|
||||
// If the input is donateable, we are doing a vector copy and the types
|
||||
// have the same size, then the input buffer can hold the output.
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.move_shared_buffer(in);
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
if (in.dtype() == out.dtype()) {
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
bool donated = set_copy_output_data(in, out, ctype);
|
||||
if (donated && in.dtype() == out.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
@@ -216,7 +202,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
bool large = out.data_size() > UINT32_MAX;
|
||||
auto& d = metal::device(s.device);
|
||||
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
|
||||
|
||||
@@ -19,7 +19,7 @@ void CustomKernel::eval_gpu(
|
||||
copies.emplace_back(init_value_.value(), out.dtype());
|
||||
fill_gpu(copies.back(), out, s);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,9 +19,6 @@ namespace mlx::core::metal {
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO nicer way to set this or possibly expose as an environment variable
|
||||
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||
|
||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
auto get_metal_version() {
|
||||
@@ -58,7 +55,10 @@ std::pair<MTL::Library*, NS::Error*> load_library_from_path(
|
||||
}
|
||||
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
|
||||
MTL::Library* try_load_bundle(
|
||||
MTL::Device* device,
|
||||
NS::URL* url,
|
||||
const std::string& lib_name) {
|
||||
std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" +
|
||||
SWIFTPM_BUNDLE + ".bundle";
|
||||
auto bundle = NS::Bundle::alloc()->init(
|
||||
@@ -66,8 +66,8 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
|
||||
if (bundle != nullptr) {
|
||||
std::string resource_path =
|
||||
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
|
||||
"default.metallib";
|
||||
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
|
||||
lib_name + ".metallib";
auto [lib, error] =
|
||||
load_library_from_path(device, resource_path.c_str());
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
@@ -76,51 +76,124 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
|
||||
}
|
||||
#endif
|
||||
|
||||
// Firstly, search for the metallib in the same path as this binary
|
||||
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
|
||||
MTL::Device* device,
|
||||
const std::string& lib_name) {
|
||||
std::string lib_path = get_colocated_mtllib_path(lib_name);
|
||||
if (lib_path.size() != 0) {
|
||||
return load_library_from_path(device, lib_path.c_str());
|
||||
}
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
MTL::Device* device,
|
||||
const std::string& lib_name) {
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
MTL::Library* library =
|
||||
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name);
|
||||
if (library != nullptr) {
|
||||
return {library, nullptr};
|
||||
}
|
||||
auto bundles = NS::Bundle::allBundles();
|
||||
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
|
||||
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
|
||||
library = try_load_bundle(device, bundle->resourceURL(), lib_name);
|
||||
if (library != nullptr) {
|
||||
return {library, nullptr};
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
MTL::Library* load_default_library(MTL::Device* device) {
|
||||
NS::Error *error1, *error2, *error3;
|
||||
MTL::Library* lib;
|
||||
// First try the colocated mlx.metallib
|
||||
std::tie(lib, error1) = load_colocated_library(device, "mlx");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Then try default.metallib in a SwiftPM bundle if we have one
|
||||
std::tie(lib, error2) = load_swiftpm_library(device, "default");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Finally try default_mtllib_path
|
||||
std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the default metallib. ";
|
||||
if (error1 != nullptr) {
|
||||
msg << error1->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
if (error2 != nullptr) {
|
||||
msg << error2->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
if (error3 != nullptr) {
|
||||
msg << error3->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
|
||||
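load_default_library above is a three-stage fallback: the colocated mlx.metallib, then a SwiftPM default.metallib, then the compile-time default path, with each intermediate error retained so the final exception can report everything that was tried. The shape of that pattern in isolation (generic C++, no Metal; the loader signature is an illustrative assumption):

#include <functional>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Try each loader in order; remember every failure message so the final
// error explains all the places we looked, not just the last one.
using Loader = std::function<std::optional<std::string>(std::string& error)>;

std::string load_with_fallbacks(const std::vector<Loader>& loaders) {
  std::vector<std::string> errors;
  for (const auto& loader : loaders) {
    std::string error;
    if (auto result = loader(error)) {
      return *result;
    }
    if (!error.empty()) {
      errors.push_back(error);
    }
  }
  std::ostringstream msg;
  msg << "Failed to load the default library. ";
  for (const auto& e : errors) {
    msg << e << " ";
  }
  throw std::runtime_error(msg.str());
}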
MTL::Library* load_library(
|
||||
MTL::Device* device,
|
||||
const std::string& lib_name = "mlx",
|
||||
const char* lib_path = default_mtllib_path) {
|
||||
// Firstly, search for the metallib in the same path as this binary
|
||||
std::string first_path = get_colocated_mtllib_path(lib_name);
|
||||
if (first_path.size() != 0) {
|
||||
auto [lib, error] = load_library_from_path(device, first_path.c_str());
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path) {
|
||||
// We have been given a path that ends in metallib so try to load it
|
||||
if (lib_path.size() > 9 &&
|
||||
std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) {
|
||||
auto [lib, error] = load_library_from_path(device, lib_path.c_str());
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the metallib from <" << lib_path << "> with error "
|
||||
<< error->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
}
|
||||
|
||||
// We have been given a path so try to load from lib_path / lib_name.metallib
|
||||
if (lib_path.size() > 0) {
|
||||
std::string full_path = lib_path + "/" + lib_name + ".metallib";
|
||||
auto [lib, error] = load_library_from_path(device, full_path.c_str());
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the metallib from <" << full_path
|
||||
<< "> with error " << error->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
}
|
||||
|
||||
// Try to load the colocated library
|
||||
{
|
||||
auto [lib, error] = load_colocated_library(device, lib_name);
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
// try to load from a swiftpm resource bundle -- scan the available bundles to
|
||||
// find one that contains the named bundle
|
||||
// Try to load the library from swiftpm
|
||||
{
|
||||
MTL::Library* library =
|
||||
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL());
|
||||
if (library != nullptr) {
|
||||
return library;
|
||||
}
|
||||
auto bundles = NS::Bundle::allBundles();
|
||||
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
|
||||
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
|
||||
library = try_load_bundle(device, bundle->resourceURL());
|
||||
if (library != nullptr) {
|
||||
return library;
|
||||
}
|
||||
auto [lib, error] = load_swiftpm_library(device, lib_name);
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Couldn't find it so let's load it from default_mtllib_path
|
||||
{
|
||||
auto [lib, error] = load_library_from_path(device, lib_path);
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << error->localizedDescription()->utf8String() << "\n"
|
||||
<< "Failed to load device library from <" << lib_path << ">"
|
||||
<< " or <" << first_path << ">.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the metallib " << lib_name << ".metallib. "
|
||||
<< "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
|
||||
<< ">";
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
msg << " and from the Swift PM bundle.";
|
||||
#endif
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
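With the new signature, register_library(lib_name, lib_path) can take an explicit .metallib file, a directory that contains lib_name.metallib, or an empty path to fall back to the colocated/SwiftPM/default search in load_library. A hedged usage sketch from code that ships its own kernels; the library name and paths are hypothetical, and the call assumes you build against MLX's internal Metal headers and the device accessor shown elsewhere in this diff:

#include "mlx/backend/metal/device.h"

void register_custom_kernels() {
  auto& d = mlx::core::metal::device(mlx::core::Device::gpu);

  // Either point straight at a metallib file...
  d.register_library("my_kernels", "/opt/my_app/kernels/my_kernels.metallib");

  // ...or pass only the name and rely on the colocated-library search
  // (lib_path defaults to "" in the new header).
  d.register_library("my_kernels");
}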
@@ -168,9 +241,10 @@ void CommandEncoder::set_output_array(
|
||||
register_output_array(a);
|
||||
}
|
||||
|
||||
void CommandEncoder::register_output_array(array& a) {
|
||||
void CommandEncoder::register_output_array(const array& a) {
|
||||
all_outputs_.insert(a.buffer().ptr());
|
||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||
|
||||
auto buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (concurrent_) {
|
||||
concurrent_outputs_.insert(buf);
|
||||
} else {
|
||||
@@ -212,7 +286,7 @@ void CommandEncoder::barrier() {
Device::Device() {
  auto pool = new_scoped_memory_pool();
  device_ = load_device();
  library_map_ = {{"mlx", load_library(device_)}};
  library_map_ = {{"mlx", load_default_library(device_)}};
  arch_ = std::string(device_->architecture()->name()->utf8String());
  auto arch = arch_.back();
  switch (arch) {
@@ -255,7 +329,7 @@ Device::~Device() {

void Device::new_queue(int index) {
  auto thread_pool = metal::new_scoped_memory_pool();
  auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
  auto q = device_->newCommandQueue();
  debug_set_stream_queue_label(q, index);
  if (!q) {
    throw std::runtime_error(

@@ -62,7 +62,7 @@ struct CommandEncoder {

  void set_input_array(const array& a, int idx, int64_t offset = 0);
  void set_output_array(array& a, int idx, int64_t offset = 0);
  void register_output_array(array& a);
  void register_output_array(const array& a);
  void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
  void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
  void maybeInsertBarrier();
@@ -189,15 +189,7 @@ class Device {

  void register_library(
      const std::string& lib_name,
      const std::string& lib_path);

  // Note, this should remain in the header so that it is not dynamically
  // linked
  void register_library(const std::string& lib_name) {
    if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
      register_library(lib_name, get_colocated_mtllib_path(lib_name));
    }
  }
      const std::string& lib_path = "");

  MTL::Library* get_library(
      const std::string& name,

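Reading the second hunk without its +/- markers is awkward, so one plausible reading for clarity: the two register_library declarations and the header-inline register-if-absent helper collapse into a single declaration with a defaulted path. A hypothetical stand-in for the resulting interface, not the actual Device header:

#include <string>

class DeviceSketch {
 public:
  // An empty lib_path means "work out the path yourself"; previously the
  // inline overload substituted the colocated metallib path explicitly.
  void register_library(
      const std::string& lib_name,
      const std::string& lib_path = "");
};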
@@ -4,149 +4,30 @@

#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/backend/metal/fence.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/fence.h"
#include "mlx/scheduler.h"

namespace mlx::core::distributed {

void signal_and_wait(const Event& e_signal, const Event& e_wait) {
  if (e_signal.valid()) {
    encode_signal(e_signal);
  }
  encode_wait(e_wait);
void AllReduce::eval_gpu(const std::vector<array>&, std::vector<array>&) {
  throw std::runtime_error("[AllReduce::eval_gpu] has no GPU implementation.");
}

void AllReduce::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& in = inputs[0];

  Fence f{stream()};

  if (in.event().valid()) {
    f.update_gpu(in);
  }

  auto& out = outputs[0];
  if (in.is_donatable()) {
    out.move_shared_buffer(in);
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }
  f.wait_gpu(out);

  auto task = [in = in,
               out = unsafe_weak_copy(out),
               f = std::move(f),
               reduce_type = reduce_type_,
               group = group()]() mutable {
    if (in.event().valid()) {
      f.wait();
    }
    switch (reduce_type) {
      case Sum:
        distributed::detail::all_sum(
            group, in.data_shared_ptr() == nullptr ? out : in, out);
        break;
      default:
        throw std::runtime_error("Only all reduce sum is supported for now");
    }
    f.update();
  };
  scheduler::enqueue(detail::communication_stream(), std::move(task));
void AllGather::eval_gpu(const std::vector<array>&, std::vector<array>&) {
  throw std::runtime_error("[AllGather::eval_gpu] has no GPU implementation.");
}

void AllGather::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);
  auto& in = inputs[0];
  auto& out = outputs[0];

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  Fence f{stream()};

  if (in.event().valid()) {
    f.update_gpu(in);
  }
  f.wait_gpu(out);

  auto task = [in = in,
               out = unsafe_weak_copy(out),
               f = std::move(f),
               group = group()]() mutable {
    if (in.event().valid()) {
      f.wait();
    }
    distributed::detail::all_gather(group, in, out);
    f.update();
  };
  scheduler::enqueue(detail::communication_stream(), std::move(task));
void Send::eval_gpu(const std::vector<array>&, std::vector<array>&) {
  throw std::runtime_error("[Send::eval_gpu] has no GPU implementation.");
}

void Send::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& in = inputs[0];

  // Encode a signal event for the input
  Fence f{stream()};
  if (in.event().valid()) {
    f.update_gpu(in);
  }

  auto& out = outputs[0];
  move_or_copy(in, out);

  // Schedule an async send on the comm stream
  auto task = [in = in,
               out = unsafe_weak_copy(out),
               f = std::move(f),
               group = group(),
               dst = dst_]() mutable {
    if (in.event().valid()) {
      f.wait();
    }
    distributed::detail::send(group, out, dst);
  };
  scheduler::enqueue(detail::communication_stream(), std::move(task));
}

void Recv::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 0);
  assert(outputs.size() == 1);

  auto& out = outputs[0];

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  Fence f{stream()};
  f.wait_gpu(out);

  // Schedule an async recv on the comm stream
  auto task = [out = unsafe_weak_copy(out),
               f = std::move(f),
               group = group(),
               src = src_]() mutable {
    distributed::detail::recv(group, out, src);
    f.update();
  };
  scheduler::enqueue(detail::communication_stream(), std::move(task));
void Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {
  throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation.");
}

} // namespace mlx::core::distributed

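The implementations removed by this hunk all follow the same shape: signal a fence once the GPU has produced the input, enqueue the collective as a task on the communication stream, and update the fence when the result is ready so downstream work can proceed. A minimal sketch of that pattern with stubbed types follows; the real code uses mlx's Fence, array, scheduler::enqueue, and distributed::detail helpers, so this only illustrates the shape.

#include <functional>
#include <utility>
#include <vector>

// Placeholder fence: the real Fence coordinates the GPU and CPU sides.
struct FenceSketch {
  void wait() {}   // block until the producer has signaled
  void update() {} // signal that the result is ready
};

// Placeholder for scheduler::enqueue(detail::communication_stream(), task).
void enqueue_on_comm_stream(std::function<void()> task) { task(); }

// Placeholder collective, standing in for distributed::detail::all_sum.
void all_sum(std::vector<float>& data) {}

void all_reduce_like(std::vector<float> in, FenceSketch f) {
  // Mirrors the removed AllReduce::eval_gpu: wait for the input to be ready,
  // run the collective off the main stream, then mark the output as ready.
  enqueue_on_comm_stream([in = std::move(in), f]() mutable {
    f.wait();
    all_sum(in);
    f.update();
  });
}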