Mirror of https://github.com/ml-explore/mlx.git
Synced 2025-12-16 01:49:05 +08:00

Compare commits: 4fda5fbdf9...interrupt_ (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 688e421184 |  |
|  | 9ffe88841c |  |
File: .circleci/config.yml

@@ -24,8 +24,8 @@ jobs:
         type: boolean
         default: false
     macos:
-      xcode: "16.2.0"
+      xcode: "15.2.0"
-    resource_class: m2pro.medium
+    resource_class: macos.m1.medium.gen1
     steps:
       - checkout
       - run:

@@ -89,14 +89,15 @@ jobs:
             pip install numpy
             sudo apt-get update
             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
-            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
       - run:
           name: Install Python package
           command: |
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
+            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \
+            CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
             CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
             python3 setup.py build_ext --inplace
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
+            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \
+            CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
             CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
             python3 setup.py develop
       - run:

@@ -109,8 +110,6 @@ jobs:
          name: Run Python tests
          command: |
            python3 -m unittest discover python/tests -v
-           mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
-           mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
      - run:
          name: Build CPP only
          command: |

@@ -125,15 +124,10 @@ jobs:
    parameters:
      xcode_version:
        type: string
-       default: "16.2.0"
+       default: "15.2.0"
-     macosx_deployment_target:
-       type: string
-       default: ""
    macos:
      xcode: << parameters.xcode_version >>
-   environment:
-     MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
-   resource_class: m2pro.medium
+   resource_class: macos.m1.medium.gen1
    steps:
      - checkout
      - run:

@@ -155,7 +149,7 @@ jobs:
          command: |
            source env/bin/activate
            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-           CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+           CMAKE_ARGS="CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
            pip install -e . -v
      - run:
          name: Generate package stubs

@@ -212,30 +206,6 @@ jobs:
            METAL_DEBUG_ERROR_MODE=0 \
            python -m xmlrunner discover -v python/tests -o test-results/gpu_jit

- cuda_build_and_test:
-   machine:
-     image: linux-cuda-12:default
-   resource_class: gpu.nvidia.small.gen2
-   steps:
-     - checkout
-     - run:
-         name: Install Python package
-         command: |
-           sudo apt-get update
-           sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
-           sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
-           python -m venv env
-           source env/bin/activate
-           CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-           CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
-           pip install -e ".[dev]"
-     - run:
-         name: Run Python tests
-         command: |
-           source env/bin/activate
-           LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
-           LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
-
  build_release:
    parameters:
      python_version:

@@ -243,18 +213,13 @@ jobs:
        default: "3.9"
      xcode_version:
        type: string
-       default: "16.2.0"
+       default: "15.2.0"
      build_env:
        type: string
        default: ""
-     macosx_deployment_target:
-       type: string
-       default: ""
    macos:
      xcode: << parameters.xcode_version >>
-   resource_class: m2pro.medium
+   resource_class: macos.m1.medium.gen1
-   environment:
-     MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
    steps:
      - checkout
      - run:

@@ -275,7 +240,7 @@ jobs:
          name: Install Python package
          command: |
            source env/bin/activate
-           env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
+           DEV_RELEASE=1 \
            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
            pip install . -v
      - run:

@@ -370,9 +335,8 @@ workflows:
      - mac_build_and_test:
          matrix:
            parameters:
-             macosx_deployment_target: ["13.5", "14.0"]
+             xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
      - linux_build_and_test
-     - cuda_build_and_test
      - build_documentation

  build_pypi_release:

@@ -391,70 +355,8 @@ workflows:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-             macosx_deployment_target: ["13.5", "14.0", "15.0"]
+             xcode_version: ["15.0.0", "15.2.0"]
              build_env: ["PYPI_RELEASE=1"]
-             xcode_version: ["16.2.0", "15.0.0"]
-           exclude:
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.9"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.10"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.11"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.12"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.13"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.9"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.10"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.11"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.12"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.13"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.9"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.10"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.11"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.12"
-               build_env: "PYPI_RELEASE=1"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.13"
-               build_env: "PYPI_RELEASE=1"
      - build_documentation:
          filters:
            tags:

@@ -477,11 +379,9 @@ workflows:
          requires: [ hold ]
          matrix:
            parameters:
-             macosx_deployment_target: ["13.5", "14.0"]
+             xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
      - linux_build_and_test:
          requires: [ hold ]
-     - cuda_build_and_test:
-         requires: [ hold ]
  nightly_build:
    when:
      and:

@@ -492,54 +392,7 @@ workflows:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-             macosx_deployment_target: ["13.5", "14.0", "15.0"]
+             xcode_version: ["15.0.0", "15.2.0"]
-             xcode_version: ["16.2.0", "15.0.0"]
-           exclude:
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.9"
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.10"
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.11"
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.12"
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.13"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.9"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.10"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.11"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.12"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.13"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.9"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.10"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.11"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.12"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.13"
  weekly_build:
    when:
      and:

@@ -550,70 +403,8 @@ workflows:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-             macosx_deployment_target: ["13.5", "14.0", "15.0"]
+             xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
              build_env: ["DEV_RELEASE=1"]
-             xcode_version: ["16.2.0", "15.0.0"]
-           exclude:
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.9"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.10"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.11"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.12"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "13.5"
-               xcode_version: "16.2.0"
-               python_version: "3.13"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.9"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.10"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.11"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.12"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "14.0"
-               xcode_version: "15.0.0"
-               python_version: "3.13"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.9"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.10"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.11"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.12"
-               build_env: "DEV_RELEASE=1"
-             - macosx_deployment_target: "15.0"
-               xcode_version: "15.0.0"
-               python_version: "3.13"
-               build_env: "DEV_RELEASE=1"
  linux_test_release:
    when:
      and:
File: .gitignore (vendored)

@@ -36,7 +36,6 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
-uv.lock

 # vim
 *.swp
File: CMakeLists.txt

@@ -34,7 +34,6 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
 option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
 option(MLX_BUILD_METAL "Build metal backend" ON)
 option(MLX_BUILD_CPU "Build cpu backend" ON)
-option(MLX_BUILD_CUDA "Build cuda backend" OFF)
 option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)

@@ -84,10 +83,6 @@ if(MLX_BUILD_METAL)
   set(QUARTZ_LIB "-framework QuartzCore")
 endif()

-if(MLX_BUILD_CUDA)
-  enable_language(CUDA)
-endif()
-
 if(MLX_BUILD_METAL AND NOT METAL_LIB)
   message(STATUS "Metal not found. Unable to build GPU")
   set(MLX_BUILD_METAL OFF)

@@ -217,6 +212,24 @@ else()
   set(MLX_BUILD_ACCELERATE OFF)
 endif()

+find_package(MPI)
+if(MPI_FOUND)
+  execute_process(
+    COMMAND zsh "-c" "mpirun --version"
+    OUTPUT_VARIABLE MPI_VERSION
+    ERROR_QUIET)
+  if(${MPI_VERSION} MATCHES ".*Open MPI.*")
+    target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
+  elseif(MPI_VERSION STREQUAL "")
+    set(MPI_FOUND FALSE)
+    message(
+      WARNING "MPI found but mpirun is not available. Building without MPI.")
+  else()
+    set(MPI_FOUND FALSE)
+    message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
+  endif()
+endif()
+
 message(STATUS "Downloading json")
 FetchContent_Declare(
   json

@@ -231,9 +244,6 @@ target_include_directories(
   mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
              $<INSTALL_INTERFACE:include>)

-# Do not add mlx_EXPORTS define for shared library.
-set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
-
 FetchContent_Declare(
   fmt
   GIT_REPOSITORY https://github.com/fmtlib/fmt.git
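The MPI block added above accepts only Open MPI, keyed off the output of `mpirun --version`. As a quick way to see what that check would conclude on a given machine, here is a minimal Python sketch of the same probe; the helper name and the standalone script are illustrative, not part of MLX or its build:

```python
# Illustrative stand-in for the CMake check above: run `mpirun --version`
# and accept the installation only if it reports Open MPI.
import subprocess


def looks_like_open_mpi() -> bool:  # hypothetical helper, not MLX API
    try:
        result = subprocess.run(
            ["mpirun", "--version"], capture_output=True, text=True
        )
    except FileNotFoundError:
        # Mirrors the "MPI found but mpirun is not available" branch.
        return False
    return "Open MPI" in result.stdout


if __name__ == "__main__":
    print(looks_like_open_mpi())
```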
File: CONTRIBUTING.md

@@ -5,26 +5,26 @@ possible.

 ## Pull Requests

 1. Fork and submit pull requests to the repo.
 2. If you've added code that should be tested, add tests.
 3. If a change is likely to impact efficiency, run some of the benchmarks before
    and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
 4. If you've changed APIs, update the documentation.
 5. Every PR should have passing tests and at least one review.
 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
    This should install hooks for running `black` and `clang-format` to ensure
    consistent style for C++ and python code.

    You can also run the formatters manually as follows:

-    ```shell
+    ```
     clang-format -i file.cpp
     ```

-    ```shell
+    ```
     black file.py
     ```

    or run `pre-commit run --all-files` to check all files in the repo.

 ## Issues
File: MANIFEST.in

@@ -1,6 +1,4 @@
 include CMakeLists.txt
-include mlx.pc.in
 recursive-include mlx/ *
-include cmake/*
 include python/src/*
 include python/mlx/py.typed # support type hinting as in PEP-561
File: (C++ source; name not shown in the capture)

@@ -1,6 +1,5 @@
 // Copyright © 2023 Apple Inc.

-#include <cstring>
 #include <iostream>
 #include <sstream>

File: (deleted benchmark script: conv2d, MLX vs PyTorch; name not shown in the capture)

@@ -1,107 +0,0 @@
-import math
-import time
-
-import mlx.core as mx
-import numpy as np
-import torch
-
-N_warmup = 10
-N_iter_bench = 100
-N_iter_func = 5
-
-
-def bench(f, a, b):
-    for i in range(N_warmup):
-        f(a, b)
-    torch.mps.synchronize()
-
-    s = time.perf_counter_ns()
-    for i in range(N_iter_bench):
-        f(a, b)
-    e = time.perf_counter_ns()
-    return (e - s) * 1e-9
-
-
-def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
-    def mx_conv_2D(a, b):
-        ys = []
-        for i in range(N_iter_func):
-            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
-            ys.append(y)
-        mx.eval(ys)
-        return ys
-
-    return mx_conv_2D
-
-
-def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
-    @torch.no_grad()
-    def pt_conv_2D(a, b):
-        ys = []
-        for i in range(N_iter_func):
-            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
-            ys.append(y)
-        torch.mps.synchronize()
-        return ys
-
-    return pt_conv_2D
-
-
-def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
-    scale = 1.0 / math.sqrt(kH * kH * C)
-    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
-    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
-        np_dtype
-    )
-
-    a_mx = mx.array(a_np)
-    b_mx = mx.array(b_np)
-
-    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
-    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
-
-    torch.mps.synchronize()
-
-    f_mx = make_mx_conv_2D(strides, padding, groups)
-    f_pt = make_pt_conv_2D(strides, padding, groups)
-
-    time_torch = bench(f_pt, a_pt, b_pt)
-    time_mlx = bench(f_mx, a_mx, b_mx)
-
-    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
-    out_pt = torch.conv2d(
-        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
-    )
-    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
-    out_pt = out_pt.numpy(force=True)
-
-    atol = 2e-5 if np_dtype == np.float32 else 1e-4
-
-    if not np.allclose(out_pt, out_mx, atol=atol):
-        print(
-            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
-        )
-
-    return time_mlx, time_torch
-
-
-if __name__ == "__main__":
-    dtype = "float32"
-    shapes = (
-        (4, 32, 32, 21, 3, 3, 128),
-        (4, 32, 32, 21, 3, 3, 37),
-        (4, 32, 32, 370, 3, 3, 370),
-        (4, 32, 32, 370, 7, 7, 128),
-        (2, 320, 640, 21, 7, 7, 21),
-    )
-    for N, H, W, C, kh, kw, O in shapes:
-        time_mlx, time_torch = bench_shape(
-            N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
-        )
-        diff = time_torch / time_mlx - 1.0
-
-        print(
-            f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
-        )
-        if time_mlx >= 2.0 * time_torch:
-            print("ATTENTION ^^^^^^^")
File: (deleted benchmark script: gather_mm; name not shown in the capture)

@@ -1,74 +0,0 @@
-# Copyright © 2025 Apple Inc.
-
-import mlx.core as mx
-from time_utils import time_fn
-
-N = 1024
-D = 1024
-M = 1024
-E = 32
-I = 4
-
-
-def gather_sort(x, indices):
-    N, M = indices.shape
-    indices = indices.flatten()
-    order = mx.argsort(indices)
-    inv_order = mx.argsort(order)
-    return x.flatten(0, -3)[order // M], indices[order], inv_order
-
-
-def scatter_unsort(x, inv_order, shape=None):
-    x = x[inv_order]
-    if shape is not None:
-        x = mx.unflatten(x, 0, shape)
-    return x
-
-
-def gather_mm_simulate(x, w, indices):
-    x, idx, inv_order = gather_sort(x, indices)
-    for i in range(2):
-        y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
-        x = y[:, None]
-    x = scatter_unsort(x, inv_order, indices.shape)
-    return x
-
-
-def time_gather_mm():
-    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
-    w1 = mx.random.normal((E, M, D)) / 1024**0.5
-    w2 = mx.random.normal((E, D, M)) / 1024**0.5
-    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
-    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
-    mx.eval(x, w1, w2, indices, sorted_indices)
-
-    def gather_mm(x, w1, w2, indices, sort):
-        idx = indices
-        inv_order = None
-        if sort:
-            x, idx, inv_order = gather_sort(x, indices)
-        x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
-        x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
-        if sort:
-            x = scatter_unsort(x, inv_order, indices.shape)
-        return x
-
-    time_fn(gather_mm, x, w1, w2, indices, False)
-    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
-    time_fn(gather_mm, x, w1, w2, indices, True)
-
-    x = mx.random.normal((N * I, D)) / 1024**0.5
-    w1 = mx.random.normal((M, D)) / 1024**0.5
-    w2 = mx.random.normal((D, M)) / 1024**0.5
-    mx.eval(x, w1, w2)
-
-    def equivalent_matmul(x, w1, w2):
-        x = x @ w1.T
-        x = x @ w2.T
-        return x
-
-    time_fn(equivalent_matmul, x, w1, w2)
-
-
-if __name__ == "__main__":
-    time_gather_mm()
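The `gather_sort`/`scatter_unsort` pair above leans on a standard trick: for a permutation `order = argsort(indices)`, `argsort(order)` is the inverse permutation, so gathering with `order` and then with `inv_order` is a round trip. A minimal NumPy sketch of that invariant, as a stand-in for the MLX calls, to illustrate the deleted code's logic:

```python
# argsort of a permutation is its inverse: sorting and then "unsorting"
# restores the original order, which is what scatter_unsort relies on.
import numpy as np

indices = np.array([3, 1, 2, 1, 0])
order = np.argsort(indices)         # permutation that sorts `indices`
inv_order = np.argsort(order)       # inverse of that permutation

sorted_indices = indices[order]     # [0, 1, 1, 2, 3]
restored = sorted_indices[inv_order]
assert (restored == indices).all()  # round trip recovers the input
```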
File: (deleted benchmark script: gather_qmm; name not shown in the capture)

@@ -1,84 +0,0 @@
-# Copyright © 2025 Apple Inc.
-
-import mlx.core as mx
-from time_utils import time_fn
-
-N = 1024
-D = 1024
-M = 1024
-E = 32
-I = 4
-
-
-def gather_sort(x, indices):
-    N, M = indices.shape
-    indices = indices.flatten()
-    order = mx.argsort(indices)
-    inv_order = mx.argsort(order)
-    return x.flatten(0, -3)[order // M], indices[order], inv_order
-
-
-def scatter_unsort(x, inv_order, shape=None):
-    x = x[inv_order]
-    if shape is not None:
-        x = mx.unflatten(x, 0, shape)
-    return x
-
-
-def gather_mm_simulate(x, w, indices):
-    x, idx, inv_order = gather_sort(x, indices)
-    for i in range(2):
-        y = mx.concatenate(
-            [
-                mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
-                for i, j in enumerate(idx.tolist())
-            ],
-            axis=0,
-        )
-        x = y[:, None]
-    x = scatter_unsort(x, inv_order, indices.shape)
-    return x
-
-
-def time_gather_qmm():
-    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
-    w1 = mx.random.normal((E, M, D)) / 1024**0.5
-    w2 = mx.random.normal((E, D, M)) / 1024**0.5
-    w1 = mx.quantize(w1)
-    w2 = mx.quantize(w2)
-    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
-    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
-    mx.eval(x, w1, w2, indices, sorted_indices)
-
-    def gather_mm(x, w1, w2, indices, sort):
-        idx = indices
-        inv_order = None
-        if sort:
-            x, idx, inv_order = gather_sort(x, indices)
-        x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
-        x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
-        if sort:
-            x = scatter_unsort(x, inv_order, indices.shape)
-        return x
-
-    time_fn(gather_mm, x, w1, w2, indices, False)
-    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
-    time_fn(gather_mm, x, w1, w2, indices, True)
-
-    x = mx.random.normal((N * I, D)) / 1024**0.5
-    w1 = mx.random.normal((M, D)) / 1024**0.5
-    w2 = mx.random.normal((D, M)) / 1024**0.5
-    w1 = mx.quantize(w1)
-    w2 = mx.quantize(w2)
-    mx.eval(x, w1, w2)
-
-    def equivalent_matmul(x, w1, w2):
-        x = mx.quantized_matmul(x, *w1, transpose=True)
-        x = mx.quantized_matmul(x, *w2, transpose=True)
-        return x
-
-    time_fn(equivalent_matmul, x, w1, w2)
-
-
-if __name__ == "__main__":
-    time_gather_qmm()
File: (benchmark script: layer_norm; name not shown in the capture)

@@ -1,7 +1,5 @@
 # Copyright © 2023-2024 Apple Inc.

-from functools import partial
-
 import mlx.core as mx
 import mlx.nn as nn
 from time_utils import time_fn

@@ -20,63 +18,51 @@ def layer_norm(x, w, b, eps):
     return y


-def time_layer_norm(N, dt):
-    L = 1024
+def time_layer_norm():
     f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
     f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0, 1, 2))
     g2 = mx.grad(f2, argnums=(0, 1, 2))

-    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
-    w = mx.random.uniform(shape=(N,)).astype(dt)
-    b = mx.random.uniform(shape=(N,)).astype(dt)
-    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
+    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
+    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
     mx.eval(x, w, b, y)

-    def layer_norm_loop(f, x, w, b):
-        for _ in range(32):
-            x = f(x, w, b)
-        return x
-
-    time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
-    time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
-
-    def layer_norm_grad_loop(g, x, w, b):
+    def layer_norm_loop(g, x, w, b):
         gx, gw, gb = x, w, b
         for _ in range(32):
             gx, gw, gb = g(gx, gw, gb, y)
         return gx, gw, gb

-    time_fn(layer_norm_grad_loop, g1, x, w, b)
-    time_fn(layer_norm_grad_loop, g2, x, w, b)
-    time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
-    time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
+    time_fn(layer_norm_loop, g1, x, w, b)
+    time_fn(layer_norm_loop, g2, x, w, b)
+    time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
+    time_fn(layer_norm_loop, mx.compile(g2), x, w, b)

     f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
     f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0,))
     g2 = mx.grad(f2, argnums=(0,))

-    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
-    w = mx.random.uniform(shape=(N,)).astype(dt)
-    b = mx.random.uniform(shape=(N,)).astype(dt)
-    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
+    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
+    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
     mx.eval(x, w, b, y)

-    def layer_norm_grad_x_loop(g, x):
+    def layer_norm_loop(g, x):
         gx = x
         for _ in range(32):
             gx = g(gx, y)
         return gx

-    time_fn(layer_norm_grad_x_loop, g1, x)
-    time_fn(layer_norm_grad_x_loop, g2, x)
-    time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
-    time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
+    time_fn(layer_norm_loop, g1, x)
+    time_fn(layer_norm_loop, g2, x)
+    time_fn(layer_norm_loop, mx.compile(g1), x)
+    time_fn(layer_norm_loop, mx.compile(g2), x)


 if __name__ == "__main__":
-    for dt in [mx.float32, mx.float16, mx.bfloat16]:
-        for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
-            print(dt, n)
-            time_layer_norm(n, dt)
+    time_layer_norm()
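Both sides of this diff compare `mx.fast.layer_norm` against the script's hand-written `layer_norm(x, w, b, eps)` reference, which is defined earlier in the file and not shown in this hunk. For context, here is a minimal NumPy sketch of the usual layer-norm computation such a reference implements; the exact definition is an assumption, since it lies outside the capture:

```python
# Assumed reference math: normalize over the last axis, then scale and shift.
import numpy as np


def layer_norm(x, w, b, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    y = (x - mu) / np.sqrt(var + eps)
    return y * w + b


x = np.random.uniform(size=(8, 16, 32)).astype(np.float32)
w = np.ones(32, dtype=np.float32)
b = np.zeros(32, dtype=np.float32)
out = layer_norm(x, w, b)
print(out.shape)  # (8, 16, 32); each row now has mean ~0 and variance ~1
```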
File: (benchmark script: scaled dot-product attention; name not shown in the capture)

@@ -28,34 +28,11 @@ def bench(f, *args):
     return (e - s) * 1e-9


-def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
-    np_dtype = getattr(np, dtype)
-
-    shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
-    shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
-
-    scale = 1.0 / math.sqrt(D)
-
-    q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
-    k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
-    v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
-
-    q_mx = mx.array(q_np)
-    k_mx = mx.array(k_np)
-    v_mx = mx.array(v_np)
-
-    if mask is not None:
-        if mask == "additive":
-            mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
-            mask = mx.array(mask_np)
-        elif mask == "bool":
-            mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
-            mask = mx.array(mask_np)
-
-    return q_mx, k_mx, v_mx, scale, mask
-
-
-def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
+def mlx_sdpa_fused_inner(q, k, v, scale):
+    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
+
+
+def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
     q_dtype = q.dtype
     q = q * mx.array(scale, q_dtype)
     n_q_heads = q.shape[-3]

@@ -64,7 +41,6 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):

     B = q.shape[0]
     L = q.shape[2]
-    kL = k.shape[2]

     if n_repeats > 1:
         q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])

@@ -72,27 +48,10 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
         v = mx.expand_dims(v, 2)

     scores = q @ mx.swapaxes(k, -1, -2)
-    if mask is not None:
-        if mask == "causal":
-            q_offset = max(0, kL - L)
-            q_indices = mx.arange(q_offset, q_offset + L)
-            k_indices = mx.arange(kL)
-            mask = q_indices[:, None] >= k_indices[None]
-
-        if n_repeats > 1 and mask.ndim >= 3:
-            if mask.shape[-3] == 1:
-                mask = mx.expand_dims(mask, -3)
-            else:
-                mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
-
-        if mask.dtype == mx.bool_:
-            scores = mx.where(mask, scores, -np.float32(np.inf))
-        else:
-            scores += mask
-
-    scores = mx.softmax(scores, axis=-1, precise=True)
+    if f32softmax:
+        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
+    else:
+        scores = mx.softmax(scores, axis=-1)

     out = scores @ v
     if n_repeats > 1:

@@ -101,55 +60,74 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
     return out


-def mlx_fused_attn(q, k, v, scale, mask):
-    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
-
-
-def do_attention(f, q, k, v, scale, mask=None, transpose=False):
-    if transpose:
-        q_t = mx.transpose(q, (0, 2, 1, 3))
-        k_t = mx.transpose(k, (0, 2, 1, 3))
-        v_t = mx.transpose(v, (0, 2, 1, 3))
-        o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
-        return mx.transpose(o_t, (0, 2, 1, 3))
-    else:
-        return f(q, k, v, scale=scale, mask=mask)
-
-
-def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
+def mlx_spda_unfused(q, k, v, scale, transpose):
     q_out = q
+    if transpose:
+        k = mx.transpose(k, (0, 2, 1, 3))
+        v = mx.transpose(v, (0, 2, 1, 3))
+
     for i in range(N_iter_func):
-        q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
+        if transpose:
+            q_out = mx.transpose(q_out, (0, 2, 1, 3))
+        q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
+        if transpose:
+            q_out = mx.transpose(q_out, (0, 2, 1, 3))
+
     mx.eval(q_out)
     return q_out


-def bench_shape(
-    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
-):
-    q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
-        B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
+def mlx_spda_fused(q, k, v, scale, transpose):
+    q_out = q
+    if transpose:
+        k = mx.transpose(k, (0, 2, 1, 3))
+        v = mx.transpose(v, (0, 2, 1, 3))
+
+    for i in range(N_iter_func):
+        if transpose:
+            q_out = mx.transpose(q_out, (0, 2, 1, 3))
+        q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
+        if transpose:
+            q_out = mx.transpose(q_out, (0, 2, 1, 3))
+
+    mx.eval(q_out)
+    return q_out
+
+
+def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
+    shape_q = (
+        (B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
+    )
+    shape_kv = (
+        (B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
     )

-    time_mlx_unfused = bench(
-        do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
-    )
-    time_mlx_fused = bench(
-        do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
-    )
-
-    o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
-    o_mlx_unfused = do_attention(
-        mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
-    )
-
-    atol = 1e-5 if dtype == "float32" else 2e-4
-
-    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
+    q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
+    k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
+    v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
+
+    scale = math.sqrt(1.0 / head_dim)
+
+    q_mx = mx.array(q_np)
+    k_mx = mx.array(k_np)
+    v_mx = mx.array(v_np)
+
+    time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
+    time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
+
+    if transpose:
+        q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
+        k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
+        v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
+
+    o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
+    o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
+
+    atol = 1e-5 if np_dtype == np.float32 else 1e-4
+
+    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
         print(
-            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
+            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
         )

     return time_mlx_fused, time_mlx_unfused

@@ -173,51 +151,39 @@ if __name__ == "__main__":
         ( 1,  128,  128,    64,   32,   32),
         ( 1,  256,  256,    64,   32,   32),
         ( 1,  512,  512,    64,   32,   32),
-        ( 1, 1024, 1024,    64,   32,    8),
-        ( 1, 2048, 2048,    64,   32,    8),
-        ( 1, 4096, 4096,    64,   32,    8),
+        ( 1, 1024, 1024,    64,   32,   32),
+        ( 1, 2048, 2048,    64,   32,   32),
+        ( 1, 4096, 4096,    64,   32,   32),
     )

     shapes_80 = (
         # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
-        ( 1, 1024, 1024,    80,   32,    8),
-        ( 1, 2048, 2048,    80,   32,    8),
-        ( 1, 4096, 4096,    80,   32,    8),
+        ( 1, 1024, 1024,    80,   32,   32),
+        ( 1, 2048, 2048,    80,   32,   32),
+        ( 1, 4096, 4096,    80,   32,   32),
     )

     shapes_128 = (
         # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
-        ( 1, 1024, 1024,   128,   32,    8),
-        ( 1, 2048, 2048,   128,   32,    8),
-        ( 1, 4096, 4096,   128,   32,    8),
+        ( 1, 1024, 1024,   128,   32,   32),
+        ( 1, 2048, 2048,   128,   32,   32),
+        ( 1, 4096, 4096,   128,   32,   32),
     )
     # fmt: on

     shapes = shapes_64 + shapes_80 + shapes_128

-    masks = [None, "bool", "causal"]
-
-    print(
-        "  B,   qsl,   ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
-    )
+    print("  B,   qsl,   ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")

     for dtype in dtypes:
         for transpose in transposes:
             for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
-                for mask_in in masks:
-                    time_mlx_fused, time_mlx_unfused = bench_shape(
-                        B,
-                        qsl,
-                        ksl,
-                        head_dim,
-                        n_q_heads,
-                        n_kv_heads,
-                        dtype,
-                        transpose,
-                        mask_in,
-                    )
-                    diff = time_mlx_unfused / time_mlx_fused - 1.0
-                    t_str = 1 if transpose else 0
-                    print(
-                        f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
-                    )
+                np_dtype = getattr(np, dtype)
+                time_mlx_fused, time_mlx_unfused = bench_shape(
+                    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
+                )
+                diff = time_mlx_unfused / time_mlx_fused - 1.0
+                t_str = 1 if transpose else 0
+                print(
+                    f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
                 )
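The removed `mask == "causal"` branch above builds its mask from index comparisons, offsetting the query positions so that when the key sequence is longer than the query sequence the queries align with the last `L` key positions. A small NumPy sketch of the same construction, as an illustrative stand-in for the `mx.arange` calls:

```python
# Causal mask for L queries against kL keys: query i (at absolute position
# q_offset + i) may attend to keys at positions <= q_offset + i.
import numpy as np

L, kL = 3, 5
q_offset = max(0, kL - L)  # queries occupy the last L positions
q_indices = np.arange(q_offset, q_offset + L)
k_indices = np.arange(kL)
mask = q_indices[:, None] >= k_indices[None]
print(mask.astype(int))
# [[1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]
```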
File: (CMake module defining mlx_build_metallib; name not shown in the capture)

@@ -11,14 +11,13 @@ include(CMakeParseArguments)
 # Args: TARGET: Custom target to be added for the metal library TITLE: Name of
 # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
 # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
-# files (like headers) DEBUG: Boolean, if true, enables debug compile options
-# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
+# files (like headers)
 #
 # clang format on

 macro(mlx_build_metallib)
   # Parse args
-  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
+  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
   set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
   cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

@@ -27,10 +26,6 @@ macro(mlx_build_metallib)

   # Collect compile options
   set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
-  if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
-    set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
-                               -frecord-sources)
-  endif()

   # Prepare metallib build command
   add_custom_command(
File: (Doxygen config)

@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
 CREATE_SUBDIRS = NO
 FULL_PATH_NAMES = YES
 RECURSIVE = YES
-GENERATE_HTML = NO
+GENERATE_HTML = YES
 GENERATE_LATEX = NO
 GENERATE_XML = YES
 XML_PROGRAMLISTING = YES
File: (Sphinx conf.py for the docs)

@@ -10,7 +10,7 @@ import mlx.core as mx
 # -- Project information -----------------------------------------------------

 project = "MLX"
-copyright = "2023, Apple"
+copyright = "2023, MLX Contributors"
 author = "MLX Contributors"
 version = ".".join(mx.__version__.split(".")[:3])
 release = version
@@ -8,26 +8,23 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
|||||||
Simple Example
|
Simple Example
|
||||||
--------------
|
--------------
|
||||||
|
|
||||||
.. currentmodule:: mlx.core
|
|
||||||
|
|
||||||
Let's write a custom kernel that computes ``exp`` elementwise:
|
Let's write a custom kernel that computes ``exp`` elementwise:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
source = """
|
|
||||||
uint elem = thread_position_in_grid.x;
|
|
||||||
T tmp = inp[elem];
|
|
||||||
out[elem] = metal::exp(tmp);
|
|
||||||
"""
|
|
||||||
|
|
||||||
kernel = mx.fast.metal_kernel(
|
|
||||||
name="myexp",
|
|
||||||
input_names=["inp"],
|
|
||||||
output_names=["out"],
|
|
||||||
source=source,
|
|
||||||
)
|
|
||||||
|
|
||||||
def exp_elementwise(a: mx.array):
|
def exp_elementwise(a: mx.array):
|
||||||
|
source = """
|
||||||
|
uint elem = thread_position_in_grid.x;
|
||||||
|
T tmp = inp[elem];
|
||||||
|
out[elem] = metal::exp(tmp);
|
||||||
|
"""
|
||||||
|
|
||||||
|
kernel = mx.fast.metal_kernel(
|
||||||
|
name="myexp",
|
||||||
|
input_names=["inp"],
|
||||||
|
output_names=["out"],
|
||||||
|
source=source,
|
||||||
|
)
|
||||||
outputs = kernel(
|
outputs = kernel(
|
||||||
inputs=[a],
|
inputs=[a],
|
||||||
template=[("T", mx.float32)],
|
template=[("T", mx.float32)],
|
||||||
@@ -42,13 +39,8 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
|||||||
b = exp_elementwise(a)
|
b = exp_elementwise(a)
|
||||||
assert mx.allclose(b, mx.exp(a))
|
assert mx.allclose(b, mx.exp(a))
|
||||||
|
|
||||||
Every time you make a kernel, a new Metal library is created and possibly
|
|
||||||
JIT compiled. To reduce the overhead from that, build the kernel once with
|
|
||||||
:func:`fast.metal_kernel` and then use it many times.
|
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
Only pass the body of the Metal kernel in ``source``. The function
|
We are only required to pass the body of the Metal kernel in ``source``.
|
||||||
signature is generated automatically.
|
|
||||||
|
|
||||||
The full function signature will be generated using:
|
The full function signature will be generated using:
|
||||||
|
|
||||||
@@ -86,51 +78,44 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
|||||||
|
|
||||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||||
|
|
||||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
|
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
|
||||||
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
|
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
|
||||||
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
|
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
|
||||||
``threadgroup`` size threadgroups. For optimal performance, each thread group
|
|
||||||
dimension should be less than or equal to the corresponding grid dimension.
|
|
||||||
|
|
||||||
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
|
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
||||||
generated code for debugging purposes.
|
|
||||||
|
|
||||||
Using Shape/Strides
|
Using Shape/Strides
|
||||||
-------------------
|
-------------------
|
||||||
|
|
||||||
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
|
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
||||||
is ``True`` by default. This will copy the array inputs if needed
|
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
||||||
before the kernel is launched to ensure that the memory layout is row
|
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
||||||
contiguous. Generally this makes writing the kernel easier, since we don't
|
when indexing.
|
||||||
have to worry about gaps or the ordering of the dims when indexing.
|
|
||||||
|
|
||||||
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
|
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
||||||
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
|
input array ``a`` if any are present in ``source``.
|
||||||
present in ``source``. We can then use MLX's built in indexing utils to fetch
|
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
||||||
the right elements for each thread.
|
|
||||||
|
|
||||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without
|
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
||||||
relying on a copy from ``ensure_row_contiguous``:
|
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
source = """
|
|
||||||
uint elem = thread_position_in_grid.x;
|
|
||||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
|
||||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
|
||||||
T tmp = inp[loc];
|
|
||||||
// Output arrays are always row contiguous
|
|
||||||
out[elem] = metal::exp(tmp);
|
|
||||||
"""
|
|
||||||
|
|
||||||
kernel = mx.fast.metal_kernel(
|
|
||||||
name="myexp_strided",
|
|
||||||
input_names=["inp"],
|
|
||||||
output_names=["out"],
|
|
||||||
source=source
|
|
||||||
)
|
|
||||||
|
|
||||||
def exp_elementwise(a: mx.array):
|
def exp_elementwise(a: mx.array):
|
||||||
|
source = """
|
||||||
|
uint elem = thread_position_in_grid.x;
|
||||||
|
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||||
|
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||||
|
T tmp = inp[loc];
|
||||||
|
// Output arrays are always row contiguous
|
||||||
|
out[elem] = metal::exp(tmp);
|
||||||
|
"""
|
||||||
|
|
||||||
|
kernel = mx.fast.metal_kernel(
|
||||||
|
name="myexp_strided",
|
||||||
|
input_names=["inp"],
|
||||||
|
output_names=["out"],
|
||||||
|
source=source
|
||||||
|
)
|
||||||
outputs = kernel(
|
outputs = kernel(
|
||||||
inputs=[a],
|
inputs=[a],
|
||||||
template=[("T", mx.float32)],
|
template=[("T", mx.float32)],
|
||||||
@@ -157,139 +142,137 @@ We'll start with the following MLX implementation using standard ops:
 
 .. code-block:: python
 
   def grid_sample_ref(x, grid):
       N, H_in, W_in, _ = x.shape
       ix = ((grid[..., 0] + 1) * W_in - 1) / 2
       iy = ((grid[..., 1] + 1) * H_in - 1) / 2
 
       ix_nw = mx.floor(ix).astype(mx.int32)
       iy_nw = mx.floor(iy).astype(mx.int32)
 
       ix_ne = ix_nw + 1
       iy_ne = iy_nw
 
       ix_sw = ix_nw
       iy_sw = iy_nw + 1
 
       ix_se = ix_nw + 1
       iy_se = iy_nw + 1
 
       nw = (ix_se - ix) * (iy_se - iy)
       ne = (ix - ix_sw) * (iy_sw - iy)
       sw = (ix_ne - ix) * (iy - iy_ne)
       se = (ix - ix_nw) * (iy - iy_nw)
 
       I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
       I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
       I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
       I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
 
       mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
       mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
       mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
       mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
 
       I_nw *= mask_nw[..., None]
       I_ne *= mask_ne[..., None]
       I_sw *= mask_sw[..., None]
       I_se *= mask_se[..., None]
 
       output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
 
       return output
 
-Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
+Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
 to write a fast GPU kernel for both the forward and backward passes.
 
 First we'll implement the forward pass as a fused kernel:
 
 .. code-block:: python
 
-  source = """
-      uint elem = thread_position_in_grid.x;
-      int H = x_shape[1];
-      int W = x_shape[2];
-      int C = x_shape[3];
-      int gH = grid_shape[1];
-      int gW = grid_shape[2];
-
-      int w_stride = C;
-      int h_stride = W * w_stride;
-      int b_stride = H * h_stride;
-
-      uint grid_idx = elem / C * 2;
-      float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
-      float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
-
-      int ix_nw = floor(ix);
-      int iy_nw = floor(iy);
-
-      int ix_ne = ix_nw + 1;
-      int iy_ne = iy_nw;
-
-      int ix_sw = ix_nw;
-      int iy_sw = iy_nw + 1;
-
-      int ix_se = ix_nw + 1;
-      int iy_se = iy_nw + 1;
-
-      T nw = (ix_se - ix) * (iy_se - iy);
-      T ne = (ix - ix_sw) * (iy_sw - iy);
-      T sw = (ix_ne - ix) * (iy - iy_ne);
-      T se = (ix - ix_nw) * (iy - iy_nw);
-
-      int batch_idx = elem / C / gH / gW * b_stride;
-      int channel_idx = elem % C;
-      int base_idx = batch_idx + channel_idx;
-
-      T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
-      T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
-      T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
-      T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
-
-      I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
-      I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
-      I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
-      I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
-
-      out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
-  """
-
-  kernel = mx.fast.metal_kernel(
-      name="grid_sample",
-      input_names=["x", "grid"],
-      output_names=["out"],
-      source=source,
-  )
-
-  @mx.custom_function
-  def grid_sample(x, grid):
-
-      assert x.ndim == 4, "`x` must be 4D."
-      assert grid.ndim == 4, "`grid` must be 4D."
-
-      B, _, _, C = x.shape
-      _, gN, gM, D = grid.shape
-      out_shape = (B, gN, gM, C)
-
-      assert D == 2, "Last dim of `grid` must be size 2."
-
-      outputs = kernel(
-          inputs=[x, grid],
-          template=[("T", x.dtype)],
-          output_shapes=[out_shape],
-          output_dtypes=[x.dtype],
-          grid=(np.prod(out_shape), 1, 1),
-          threadgroup=(256, 1, 1),
-      )
-      return outputs[0]
+  @mx.custom_function
+  def grid_sample(x, grid):
+
+      assert x.ndim == 4, "`x` must be 4D."
+      assert grid.ndim == 4, "`grid` must be 4D."
+
+      B, _, _, C = x.shape
+      _, gN, gM, D = grid.shape
+      out_shape = (B, gN, gM, C)
+
+      assert D == 2, "Last dim of `grid` must be size 2."
+
+      source = """
+          uint elem = thread_position_in_grid.x;
+          int H = x_shape[1];
+          int W = x_shape[2];
+          int C = x_shape[3];
+          int gH = grid_shape[1];
+          int gW = grid_shape[2];
+
+          int w_stride = C;
+          int h_stride = W * w_stride;
+          int b_stride = H * h_stride;
+
+          uint grid_idx = elem / C * 2;
+          float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
+          float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
+
+          int ix_nw = floor(ix);
+          int iy_nw = floor(iy);
+
+          int ix_ne = ix_nw + 1;
+          int iy_ne = iy_nw;
+
+          int ix_sw = ix_nw;
+          int iy_sw = iy_nw + 1;
+
+          int ix_se = ix_nw + 1;
+          int iy_se = iy_nw + 1;
+
+          T nw = (ix_se - ix) * (iy_se - iy);
+          T ne = (ix - ix_sw) * (iy_sw - iy);
+          T sw = (ix_ne - ix) * (iy - iy_ne);
+          T se = (ix - ix_nw) * (iy - iy_nw);
+
+          int batch_idx = elem / C / gH / gW * b_stride;
+          int channel_idx = elem % C;
+          int base_idx = batch_idx + channel_idx;
+
+          T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
+          T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
+          T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
+          T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
+
+          I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
+          I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
+          I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
+          I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
+
+          out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
+      """
+      kernel = mx.fast.metal_kernel(
+          name="grid_sample",
+          input_names=["x", "grid"],
+          output_names=["out"],
+          source=source,
+      )
+      outputs = kernel(
+          inputs=[x, grid],
+          template=[("T", x.dtype)],
+          output_shapes=[out_shape],
+          output_dtypes=[x.dtype],
+          grid=(np.prod(out_shape), 1, 1),
+          threadgroup=(256, 1, 1),
+      )
+      return outputs[0]
 
 For a reasonably sized input such as:
 
 .. code-block:: python
 
   x.shape = (8, 1024, 1024, 64)
   grid.shape = (8, 256, 256, 2)
 
 On an M1 Max, we see a big performance improvement:
 
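Before benchmarking, the fused kernel should agree with the pure-ops reference. A sketch using the shapes from the example above:

.. code-block:: python

    x = mx.random.normal(shape=(8, 1024, 1024, 64))
    grid = mx.random.uniform(low=-1.0, high=1.0, shape=(8, 256, 256, 2))

    out = grid_sample(x, grid)
    ref = grid_sample_ref(x, grid)
    print(mx.allclose(out, ref))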
@@ -298,11 +281,11 @@ On an M1 Max, we see a big performance improvement:
 Grid Sample VJP
 ---------------
 
-Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
-define its custom vjp transform so MLX can differentiate it.
+Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
+its custom vjp transform so MLX can differentiate it.
 
 The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
-requires a few extra :func:`fast.metal_kernel` features:
+requires a few extra ``mx.fast.metal_kernel`` features:
 
 * ``init_value=0``
     Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
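The ``custom_function`` protocol used throughout this section is worth seeing in isolation. Below is an illustrative sketch, not part of the tutorial: the decorated function runs as usual, and the ``.vjp`` hook receives the primal inputs, the incoming cotangent, and the forward output, returning one gradient per primal, as ``grid_sample_vjp`` does in the hunk that follows.

.. code-block:: python

    import mlx.core as mx

    @mx.custom_function
    def myexp(x):
        return mx.exp(x)

    @myexp.vjp
    def myexp_vjp(primals, cotangent, output):
        (x,) = primals
        # d/dx exp(x) = exp(x)
        return cotangent * mx.exp(x)

    g = mx.grad(lambda x: myexp(x).sum())(mx.ones((3,)))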
@@ -316,129 +299,128 @@ We can then implement the backwards pass as follows:
 
 .. code-block:: python
 
-  source = """
-      uint elem = thread_position_in_grid.x;
-      int H = x_shape[1];
-      int W = x_shape[2];
-      int C = x_shape[3];
-      // Pad C to the nearest larger simdgroup size multiple
-      int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
-
-      int gH = grid_shape[1];
-      int gW = grid_shape[2];
-
-      int w_stride = C;
-      int h_stride = W * w_stride;
-      int b_stride = H * h_stride;
-
-      uint grid_idx = elem / C_padded * 2;
-      float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
-      float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
-
-      int ix_nw = floor(ix);
-      int iy_nw = floor(iy);
-
-      int ix_ne = ix_nw + 1;
-      int iy_ne = iy_nw;
-
-      int ix_sw = ix_nw;
-      int iy_sw = iy_nw + 1;
-
-      int ix_se = ix_nw + 1;
-      int iy_se = iy_nw + 1;
-
-      T nw = (ix_se - ix) * (iy_se - iy);
-      T ne = (ix - ix_sw) * (iy_sw - iy);
-      T sw = (ix_ne - ix) * (iy - iy_ne);
-      T se = (ix - ix_nw) * (iy - iy_nw);
-
-      int batch_idx = elem / C_padded / gH / gW * b_stride;
-      int channel_idx = elem % C_padded;
-      int base_idx = batch_idx + channel_idx;
-
-      T gix = T(0);
-      T giy = T(0);
-      if (channel_idx < C) {
-          int cot_index = elem / C_padded * C + channel_idx;
-          T cot = cotangent[cot_index];
-          if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
-              int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
-              atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
-
-              T I_nw = x[offset];
-              gix -= I_nw * (iy_se - iy) * cot;
-              giy -= I_nw * (ix_se - ix) * cot;
-          }
-          if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
-              int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
-              atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
-
-              T I_ne = x[offset];
-              gix += I_ne * (iy_sw - iy) * cot;
-              giy -= I_ne * (ix - ix_sw) * cot;
-          }
-          if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
-              int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
-              atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
-
-              T I_sw = x[offset];
-              gix -= I_sw * (iy - iy_ne) * cot;
-              giy += I_sw * (ix_ne - ix) * cot;
-          }
-          if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
-              int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
-              atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
-
-              T I_se = x[offset];
-              gix += I_se * (iy - iy_nw) * cot;
-              giy += I_se * (ix - ix_nw) * cot;
-          }
-      }
-
-      T gix_mult = W / 2;
-      T giy_mult = H / 2;
-
-      // Reduce across each simdgroup first.
-      // This is much faster than relying purely on atomics.
-      gix = simd_sum(gix);
-      giy = simd_sum(giy);
-
-      if (thread_index_in_simdgroup == 0) {
-          atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
-          atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
-      }
-  """
-  kernel = mx.fast.metal_kernel(
-      name="grid_sample_grad",
-      input_names=["x", "grid", "cotangent"],
-      output_names=["x_grad", "grid_grad"],
-      source=source,
-      atomic_outputs=True,
-  )
-
-  @grid_sample.vjp
-  def grid_sample_vjp(primals, cotangent, _):
-      x, grid = primals
-      B, _, _, C = x.shape
-      _, gN, gM, D = grid.shape
-
-      assert D == 2, "Last dim of `grid` must be size 2."
-
-      # pad the output channels to simd group size
-      # so that our `simd_sum`s don't overlap.
-      simdgroup_size = 32
-      C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
-      grid_size = B * gN * gM * C_padded
-      outputs = kernel(
-          inputs=[x, grid, cotangent],
-          template=[("T", x.dtype)],
-          output_shapes=[x.shape, grid.shape],
-          output_dtypes=[x.dtype, x.dtype],
-          grid=(grid_size, 1, 1),
-          threadgroup=(256, 1, 1),
-          init_value=0,
-      )
-      return outputs[0], outputs[1]
+  @grid_sample.vjp
+  def grid_sample_vjp(primals, cotangent, _):
+      x, grid = primals
+      B, _, _, C = x.shape
+      _, gN, gM, D = grid.shape
+
+      assert D == 2, "Last dim of `grid` must be size 2."
+
+      source = """
+          uint elem = thread_position_in_grid.x;
+          int H = x_shape[1];
+          int W = x_shape[2];
+          int C = x_shape[3];
+          // Pad C to the nearest larger simdgroup size multiple
+          int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
+
+          int gH = grid_shape[1];
+          int gW = grid_shape[2];
+
+          int w_stride = C;
+          int h_stride = W * w_stride;
+          int b_stride = H * h_stride;
+
+          uint grid_idx = elem / C_padded * 2;
+          float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
+          float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
+
+          int ix_nw = floor(ix);
+          int iy_nw = floor(iy);
+
+          int ix_ne = ix_nw + 1;
+          int iy_ne = iy_nw;
+
+          int ix_sw = ix_nw;
+          int iy_sw = iy_nw + 1;
+
+          int ix_se = ix_nw + 1;
+          int iy_se = iy_nw + 1;
+
+          T nw = (ix_se - ix) * (iy_se - iy);
+          T ne = (ix - ix_sw) * (iy_sw - iy);
+          T sw = (ix_ne - ix) * (iy - iy_ne);
+          T se = (ix - ix_nw) * (iy - iy_nw);
+
+          int batch_idx = elem / C_padded / gH / gW * b_stride;
+          int channel_idx = elem % C_padded;
+          int base_idx = batch_idx + channel_idx;
+
+          T gix = T(0);
+          T giy = T(0);
+          if (channel_idx < C) {
+              int cot_index = elem / C_padded * C + channel_idx;
+              T cot = cotangent[cot_index];
+              if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
+                  int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
+                  atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
+
+                  T I_nw = x[offset];
+                  gix -= I_nw * (iy_se - iy) * cot;
+                  giy -= I_nw * (ix_se - ix) * cot;
+              }
+              if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
+                  int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
+                  atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
+
+                  T I_ne = x[offset];
+                  gix += I_ne * (iy_sw - iy) * cot;
+                  giy -= I_ne * (ix - ix_sw) * cot;
+              }
+              if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
+                  int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
+                  atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
+
+                  T I_sw = x[offset];
+                  gix -= I_sw * (iy - iy_ne) * cot;
+                  giy += I_sw * (ix_ne - ix) * cot;
+              }
+              if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
+                  int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
+                  atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
+
+                  T I_se = x[offset];
+                  gix += I_se * (iy - iy_nw) * cot;
+                  giy += I_se * (ix - ix_nw) * cot;
+              }
+          }
+
+          T gix_mult = W / 2;
+          T giy_mult = H / 2;
+
+          // Reduce across each simdgroup first.
+          // This is much faster than relying purely on atomics.
+          gix = simd_sum(gix);
+          giy = simd_sum(giy);
+
+          if (thread_index_in_simdgroup == 0) {
+              atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
+              atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
+          }
+      """
+      kernel = mx.fast.metal_kernel(
+          name="grid_sample_grad",
+          input_names=["x", "grid", "cotangent"],
+          output_names=["x_grad", "grid_grad"],
+          source=source,
+          atomic_outputs=True,
+      )
+      # pad the output channels to simd group size
+      # so that our `simd_sum`s don't overlap.
+      simdgroup_size = 32
+      C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
+      grid_size = B * gN * gM * C_padded
+      outputs = kernel(
+          inputs=[x, grid, cotangent],
+          template=[("T", x.dtype)],
+          output_shapes=[x.shape, grid.shape],
+          output_dtypes=[x.dtype, x.dtype],
+          grid=(grid_size, 1, 1),
+          threadgroup=(256, 1, 1),
+          init_value=0,
+      )
+      return outputs[0], outputs[1]
 
 There's an even larger speed up for the vjp:
 
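With the vjp registered, gradients flow through ``grid_sample`` like any other op. A short sketch, reusing ``x`` and ``grid`` from the earlier check:

.. code-block:: python

    def loss(x, grid):
        return grid_sample(x, grid).sum()

    dx, dgrid = mx.grad(loss, argnums=(0, 1))(x, grid)
    print(dx.shape, dgrid.shape)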
@@ -93,9 +93,9 @@ Primitives
 ^^^^^^^^^^^
 
 A :class:`Primitive` is part of the computation graph of an :class:`array`. It
-defines how to create output arrays given input arrays. Further, a
+defines how to create outputs arrays given a input arrays. Further, a
 :class:`Primitive` has methods to run on the CPU or GPU and for function
-transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
+transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
 more concrete:
 
 .. code-block:: C++

@@ -128,7 +128,7 @@ more concrete:
     /** The vector-Jacobian product. */
     std::vector<array> vjp(
         const std::vector<array>& primals,
-        const std::vector<array>& cotangents,
+        const array& cotan,
         const std::vector<int>& argnums,
         const std::vector<array>& outputs) override;
 
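The signature change above mirrors the user-facing transform: the newer side passes a vector of cotangents, one per output, rather than a single ``cotan``. An illustrative sketch of the Python-level ``mx.vjp`` call, which takes lists of primals and cotangents:

.. code-block:: python

    import mlx.core as mx

    def f(x, y):
        return x * y + y

    primals = [mx.array(2.0), mx.array(3.0)]
    cotangents = [mx.array(1.0)]  # one cotangent per output of f
    outputs, vjps = mx.vjp(f, primals, cotangents)
    # vjps[0] = df/dx = y = 3, vjps[1] = df/dy = x + 1 = 3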
@@ -247,7 +249,9 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
       float alpha_,
       float beta_,
       mx::Stream stream) {
-    out.set_data(mx::allocator::malloc(out.nbytes()));
+    // Allocate the output with `malloc_or_wait` which synchronously allocates
+    // memory, potentially waiting if the system is under memory pressure
+    out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
 
     // Get the CPU command encoder and register input and output arrays
     auto& encoder = mx::cpu::get_command_encoder(stream);

@@ -391,17 +393,17 @@ below.
   auto& d = metal::device(s.device);
 
   // Allocate output memory
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
   // Resolve name of kernel
   std::ostringstream kname;
   kname << "axpby_" << "general_" << type_to_name(out);
 
-  // Load the metal library
-  auto lib = d.get_library("mlx_ext");
+  // Make sure the metal library is available
+  d.register_library("mlx_ext");
 
   // Make a kernel from this metal library
-  auto kernel = d.get_kernel(kname.str(), lib);
+  auto kernel = d.get_kernel(kname.str(), "mlx_ext");
 
   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);

@@ -469,7 +471,7 @@ one we just defined:
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
   // Forward mode diff that pushes along the tangents
-  // The jvp transform on the primitive can be built with ops
+  // The jvp transform on the primitive can built with ops
   // that are scheduled on the same stream as the primitive
 
   // If argnums = {0}, we only push along x in which case the

@@ -481,7 +483,7 @@ one we just defined:
     auto scale_arr = array(scale, tangents[0].dtype());
     return {multiply(scale_arr, tangents[0], stream())};
   }
-  // If argnums = {0, 1}, we take contributions from both
+  // If, argnums = {0, 1}, we take contributions from both
   // which gives us jvp = tangent_x * alpha + tangent_y * beta
   else {
     return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
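The jvp rule in the comment above can be checked from Python with ``mx.jvp``. A sketch using a pure-ops stand-in for ``axpby`` (the extension itself would behave identically):

.. code-block:: python

    import mlx.core as mx

    alpha, beta = 4.0, 2.0

    def axpby_like(x, y):
        return alpha * x + beta * y

    primals = [mx.ones((3, 4)), mx.ones((3, 4))]
    tangents = [mx.ones((3, 4)), mx.ones((3, 4))]
    out, jvp_out = mx.jvp(axpby_like, primals, tangents)
    # jvp = tangent_x * alpha + tangent_y * beta = 6 everywhere
    print(mx.all(jvp_out[0] == 6.0).item())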
@@ -735,7 +737,7 @@ Let's look at a simple script and its results:
 
   print(f"c shape: {c.shape}")
   print(f"c dtype: {c.dtype}")
-  print(f"c is correct: {mx.all(c == 6.0).item()}")
+  print(f"c correct: {mx.all(c == 6.0).item()}")
 
 Output:
 

@@ -743,7 +745,7 @@ Output:
 
   c shape: [3, 4]
   c dtype: float32
-  c is correct: True
+  c correctness: True
 
 Results
 ^^^^^^^
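For context, the full script around those prints looks roughly like the following sketch. The package name ``mlx_sample_extensions`` is the one used by this tutorial; with ``alpha = 4`` and ``beta = 2`` on all-ones inputs, every element of ``c`` is 6.

.. code-block:: python

    import mlx.core as mx
    from mlx_sample_extensions import axpby

    a = mx.ones((3, 4))
    b = mx.ones((3, 4))
    c = axpby(a, b, 4.0, 2.0)

    print(f"c shape: {c.shape}")
    print(f"c dtype: {c.dtype}")
    print(f"c is correct: {mx.all(c == 6.0).item()}")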
@@ -70,7 +70,6 @@ are the CPU and GPU.
    python/fft
    python/linalg
    python/metal
-   python/memory_management
    python/nn
    python/optimizers
    python/distributed

@@ -19,8 +19,6 @@ Array
    array.ndim
    array.shape
    array.size
-   array.real
-   array.imag
    array.abs
    array.all
    array.any

@@ -40,7 +38,6 @@ Array
    array.log10
    array.log1p
    array.log2
-   array.logcumsumexp
    array.logsumexp
    array.max
    array.mean

@@ -20,5 +20,3 @@ FFT
    irfft2
    rfftn
    irfftn
-   fftshift
-   ifftshift

@@ -16,12 +16,9 @@ Linear Algebra
    cross
    qr
    svd
-   eigvals
-   eig
    eigvalsh
    eigh
    lu
    lu_factor
-   pinv
    solve
    solve_triangular
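On the newer side of this diff, the general (non-symmetric) eigensolvers and the pseudo-inverse are part of the linalg API. A usage sketch, assuming the common restriction that these solvers run on the CPU stream:

.. code-block:: python

    import mlx.core as mx

    A = mx.array([[1.0, 2.0], [3.0, 4.0]])
    w = mx.linalg.eigvals(A, stream=mx.cpu)    # eigenvalues only
    w, V = mx.linalg.eig(A, stream=mx.cpu)     # eigenvalues and eigenvectors
    A_pinv = mx.linalg.pinv(A, stream=mx.cpu)  # Moore-Penrose pseudo-inverse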
@@ -1,16 +0,0 @@
-Memory Management
-=================
-
-.. currentmodule:: mlx.core
-
-.. autosummary::
-   :toctree: _autosummary
-
-   get_active_memory
-   get_peak_memory
-   reset_peak_memory
-   get_cache_memory
-   set_memory_limit
-   set_cache_limit
-   set_wired_limit
-   clear_cache

@@ -8,5 +8,13 @@ Metal
 
    is_available
    device_info
+   get_active_memory
+   get_peak_memory
+   reset_peak_memory
+   get_cache_memory
+   set_memory_limit
+   set_cache_limit
+   set_wired_limit
+   clear_cache
    start_capture
    stop_capture
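In other words, the memory-management helpers move between the ``mx.metal`` namespace (older side) and the top-level ``mlx.core`` namespace (newer side); the functions themselves behave the same. A usage sketch against the older ``mx.metal`` location:

.. code-block:: python

    import mlx.core as mx

    mx.metal.set_cache_limit(0)  # return freed buffers to the OS immediately
    a = mx.zeros((4096, 4096))
    mx.eval(a)
    print(mx.metal.get_active_memory())  # bytes currently in use
    print(mx.metal.get_peak_memory())    # high-water mark since last reset
    mx.metal.clear_cache()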
@@ -36,12 +36,10 @@ Operations
    bitwise_or
    bitwise_xor
    block_masked_mm
-   broadcast_arrays
    broadcast_to
    ceil
    clip
    concatenate
-   contiguous
    conj
    conjugate
    convolve

@@ -103,7 +101,6 @@ Operations
    log10
    log1p
    logaddexp
-   logcumsumexp
    logical_not
    logical_and
    logical_or

@@ -18,4 +18,3 @@ Common Optimizers
    AdamW
    Adamax
    Lion
-   MultiOptimizer

@@ -9,7 +9,6 @@ Transforms
    :toctree: _autosummary
 
    eval
-   async_eval
    compile
    custom_function
    disable_compile
@@ -72,7 +72,9 @@ void axpby_impl(
     float alpha_,
     float beta_,
     mx::Stream stream) {
-  out.set_data(mx::allocator::malloc(out.nbytes()));
+  // Allocate the output with `malloc_or_wait` which synchronously allocates
+  // memory, potentially waiting if the system is under memory pressure
+  out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
 
   // Get the CPU command encoder and register input and output arrays
   auto& encoder = mx::cpu::get_command_encoder(stream);

@@ -158,12 +160,12 @@ void Axpby::eval_gpu(
   // Allocate output memory with strides based on specialization
   if (contiguous_kernel) {
     out.set_data(
-        mx::allocator::malloc(x.data_size() * out.itemsize()),
+        mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
   } else {
-    out.set_data(mx::allocator::malloc(out.nbytes()));
+    out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
   }
 
   // Resolve name of kernel (corresponds to axpby.metal)

@@ -172,11 +174,11 @@ void Axpby::eval_gpu(
   kname << (contiguous_kernel ? "contiguous_" : "general_");
   kname << type_to_name(out);
 
-  // Load the metal library
-  auto lib = d.get_library("mlx_ext");
+  // Make sure the metal library is available
+  d.register_library("mlx_ext");
 
   // Make a kernel from this metal library
-  auto kernel = d.get_kernel(kname.str(), lib);
+  auto kernel = d.get_kernel(kname.str(), "mlx_ext");
 
   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);

@@ -5,7 +5,6 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp

@@ -21,7 +20,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
 
 # Define MLX_VERSION only in the version.cpp file.
-add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
+add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
 target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
 target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
 

@@ -49,19 +48,5 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
 if(MLX_BUILD_METAL)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
 else()
-  target_sources(mlx
-                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
-endif()
-
-if(MLX_BUILD_CUDA)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
-else()
-  target_sources(mlx
-                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
-endif()
-
-if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
-else()
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
 endif()
@@ -4,11 +4,12 @@
 #include <sstream>
 
 #include "mlx/allocator.h"
+#include "mlx/scheduler.h"
 
 namespace mlx::core::allocator {
 
 Buffer malloc(size_t size) {
-  auto buffer = allocator().malloc(size);
+  auto buffer = allocator().malloc(size, /* allow_swap */ true);
   if (size && !buffer.ptr()) {
     std::ostringstream msg;
     msg << "[malloc] Unable to allocate " << size << " bytes.";

@@ -21,4 +22,45 @@ void free(Buffer buffer) {
   allocator().free(buffer);
 }
 
+Buffer CommonAllocator::malloc(size_t size, bool) {
+  void* ptr = std::malloc(size + sizeof(size_t));
+  if (ptr != nullptr) {
+    *static_cast<size_t*>(ptr) = size;
+  }
+  return Buffer{ptr};
+}
+
+void CommonAllocator::free(Buffer buffer) {
+  std::free(buffer.ptr());
+}
+
+size_t CommonAllocator::size(Buffer buffer) const {
+  if (buffer.ptr() == nullptr) {
+    return 0;
+  }
+  return *static_cast<size_t*>(buffer.ptr());
+}
+
+Buffer malloc_or_wait(size_t size) {
+  auto buffer = allocator().malloc(size);
+
+  while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
+    scheduler::wait_for_one();
+    buffer = allocator().malloc(size);
+  }
+
+  // Try swapping if needed
+  if (size && !buffer.ptr()) {
+    buffer = allocator().malloc(size, /* allow_swap = */ true);
+  }
+
+  if (size && !buffer.ptr()) {
+    std::ostringstream msg;
+    msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
+    throw std::runtime_error(msg.str());
+  }
+
+  return buffer;
+}
+
 } // namespace mlx::core::allocator
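The control flow of ``malloc_or_wait`` is easy to lose in the diff noise, so here is an illustrative Python sketch of the same policy; the ``alloc``, ``n_active_tasks`` and ``wait_for_one`` callables stand in for the allocator and scheduler hooks above.

.. code-block:: python

    def malloc_or_wait(alloc, size, n_active_tasks, wait_for_one):
        buf = alloc(size)
        # While tasks are still running, one of them may release memory:
        # block until a task finishes, then retry the allocation.
        while size and buf is None and n_active_tasks() > 0:
            wait_for_one()
            buf = alloc(size)
        # Last resort: retry allowing the OS to swap.
        if size and buf is None:
            buf = alloc(size, allow_swap=True)
        if size and buf is None:
            raise MemoryError(f"[malloc_or_wait] Unable to allocate {size} bytes.")
        return buf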
@@ -32,10 +32,14 @@ Buffer malloc(size_t size);
 
 void free(Buffer buffer);
 
+// Wait for running tasks to finish and free up memory
+// if allocation fails
+Buffer malloc_or_wait(size_t size);
+
 class Allocator {
   /** Abstract base class for a memory allocator. */
  public:
-  virtual Buffer malloc(size_t size) = 0;
+  virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
   virtual void free(Buffer buffer) = 0;
   virtual size_t size(Buffer buffer) const = 0;
 

@@ -49,4 +53,16 @@ class Allocator {
 
 Allocator& allocator();
 
+class CommonAllocator : public Allocator {
+  /** A general CPU allocator. */
+ public:
+  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
+  virtual void free(Buffer buffer) override;
+  virtual size_t size(Buffer buffer) const override;
+
+ private:
+  CommonAllocator() = default;
+  friend Allocator& allocator();
+};
+
 } // namespace mlx::core::allocator
mlx/array.h (12 changes)

@@ -224,10 +224,6 @@ class array {
     // Not copyable
     Data(const Data& d) = delete;
     Data& operator=(const Data& d) = delete;
-    Data(Data&& o) : buffer(o.buffer), d(o.d) {
-      o.buffer = allocator::Buffer(nullptr);
-      o.d = [](allocator::Buffer) {};
-    }
     ~Data() {
       d(buffer);
     }

@@ -343,11 +339,11 @@ class array {
     return allocator::allocator().size(buffer());
   }
 
-  // Return the shared pointer to the array::Data struct
-  const std::shared_ptr<Data>& data_shared_ptr() const {
+  // Return a copy of the shared pointer
+  // to the array::Data struct
+  std::shared_ptr<Data> data_shared_ptr() const {
     return array_desc_->data;
   }
 
   // Return a raw pointer to the arrays data
   template <typename T>
   T* data() {

@@ -360,7 +356,7 @@ class array {
   }
 
   enum Status {
-    // The output of a computation which has not been scheduled.
+    // The ouptut of a computation which has not been scheduled.
     // For example, the status of `x` in `auto x = a + b`.
     unscheduled,
 
@@ -1,7 +1,6 @@
 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp

@@ -44,14 +44,14 @@ inline void set_binary_op_output_data(
   switch (bopt) {
     case BinaryOpType::ScalarScalar:
       out.set_data(
-          allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
+          allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
       break;
     case BinaryOpType::ScalarVector:
       if (b_donatable) {
         out.copy_shared_buffer(b);
       } else {
         out.set_data(
-            allocator::malloc(b.data_size() * out.itemsize()),
+            allocator::malloc_or_wait(b.data_size() * out.itemsize()),
             b.data_size(),
             b.strides(),
             b.flags());

@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
         out.copy_shared_buffer(a);
       } else {
         out.set_data(
-            allocator::malloc(a.data_size() * out.itemsize()),
+            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
             a.data_size(),
             a.strides(),
             a.flags());

@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
         out.copy_shared_buffer(b);
       } else {
         out.set_data(
-            allocator::malloc(a.data_size() * out.itemsize()),
+            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
             a.data_size(),
             a.strides(),
             a.flags());

@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
           b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
         out.copy_shared_buffer(b);
       } else {
-        out.set_data(allocator::malloc(out.nbytes()));
+        out.set_data(allocator::malloc_or_wait(out.nbytes()));
       }
       break;
   }
@@ -1,24 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-#include "mlx/backend/common/utils.h"
-
-namespace mlx::core {
-
-void broadcast(const array& in, array& out) {
-  if (out.size() == 0) {
-    out.set_data(nullptr);
-    return;
-  }
-  Strides strides(out.ndim(), 0);
-  int diff = out.ndim() - in.ndim();
-  for (int i = in.ndim() - 1; i >= 0; --i) {
-    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
-  }
-  auto flags = in.flags();
-  if (out.size() > in.size()) {
-    flags.row_contiguous = flags.col_contiguous = false;
-  }
-  out.copy_shared_buffer(in, strides, flags, in.data_size());
-}
-
-} // namespace mlx::core

@@ -1,11 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-#pragma once
-
-#include "mlx/array.h"
-
-namespace mlx::core {
-
-void broadcast(const array& in, array& out);
-
-} // namespace mlx::core
@@ -1,157 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#pragma once
-
-#include <cassert>
-#include <functional>
-#include <map>
-
-namespace mlx::core {
-
-template <typename T>
-class BufferCache {
- public:
-  BufferCache(
-      size_t page_size,
-      std::function<size_t(T*)> get_size,
-      std::function<void(T*)> free)
-      : page_size_(page_size),
-        get_size_(std::move(get_size)),
-        free_(std::move(free)) {}
-
-  ~BufferCache() {
-    clear();
-  }
-
-  BufferCache(const BufferCache&) = delete;
-  BufferCache& operator=(const BufferCache&) = delete;
-
-  T* reuse_from_cache(size_t size) {
-    // Find the closest buffer in pool.
-    auto it = buffer_pool_.lower_bound(size);
-    if (it == buffer_pool_.end() ||
-        it->first >= std::min(2 * size, size + 2 * page_size_)) {
-      return nullptr;
-    }
-
-    // Collect from the cache.
-    T* buf = it->second->buf;
-    pool_size_ -= it->first;
-
-    // Remove from record.
-    remove_from_list(it->second);
-    buffer_pool_.erase(it);
-    return buf;
-  }
-
-  void recycle_to_cache(T* buf) {
-    assert(buf);
-    // Add to cache.
-    BufferHolder* bh = new BufferHolder(buf);
-    add_at_head(bh);
-    size_t size = get_size_(buf);
-    pool_size_ += size;
-    buffer_pool_.emplace(size, bh);
-  }
-
-  int release_cached_buffers(size_t min_bytes_to_free) {
-    if (min_bytes_to_free >= 0.9 * pool_size_) {
-      return clear();
-    } else {
-      int n_release = 0;
-      size_t total_bytes_freed = 0;
-
-      while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
-        // Release buffer.
-        size_t size = get_size_(tail_->buf);
-        total_bytes_freed += size;
-        free_(tail_->buf);
-        n_release++;
-
-        // Remove from record.
-        auto its = buffer_pool_.equal_range(size);
-        auto it = std::find_if(its.first, its.second, [this](const auto& el) {
-          return el.second == tail_;
-        });
-        assert(it != buffer_pool_.end());
-        buffer_pool_.erase(it);
-        remove_from_list(tail_);
-      }
-
-      pool_size_ -= total_bytes_freed;
-      return n_release;
-    }
-  }
-
-  int clear() {
-    int n_release = 0;
-    for (auto& [size, holder] : buffer_pool_) {
-      free_(holder->buf);
-      n_release++;
-      delete holder;
-    }
-    buffer_pool_.clear();
-    pool_size_ = 0;
-    head_ = nullptr;
-    tail_ = nullptr;
-    return n_release;
-  }
-
-  size_t cache_size() const {
-    return pool_size_;
-  }
-
-  size_t page_size() const {
-    return page_size_;
-  }
-
- private:
-  struct BufferHolder {
-   public:
-    explicit BufferHolder(T* buf_) : buf(buf_) {}
-
-    BufferHolder* prev{nullptr};
-    BufferHolder* next{nullptr};
-    T* buf;
-  };
-
-  void add_at_head(BufferHolder* to_add) {
-    if (!head_) {
-      head_ = to_add;
-      tail_ = to_add;
-    } else {
-      head_->prev = to_add;
-      to_add->next = head_;
-      head_ = to_add;
-    }
-  }
-
-  void remove_from_list(BufferHolder* to_remove) {
-    if (to_remove->prev && to_remove->next) { // if middle
-      to_remove->prev->next = to_remove->next;
-      to_remove->next->prev = to_remove->prev;
-    } else if (to_remove->prev && to_remove == tail_) { // if tail
-      tail_ = to_remove->prev;
-      tail_->next = nullptr;
-    } else if (to_remove == head_ && to_remove->next) { // if head
-      head_ = to_remove->next;
-      head_->prev = nullptr;
-    } else if (to_remove == head_ && to_remove == tail_) { // if only element
-      head_ = nullptr;
-      tail_ = nullptr;
-    }
-
-    delete to_remove;
-  }
-
-  std::multimap<size_t, BufferHolder*> buffer_pool_;
-  BufferHolder* head_{nullptr};
-  BufferHolder* tail_{nullptr};
-  size_t pool_size_{0};
-
-  const size_t page_size_;
-  std::function<size_t(T*)> get_size_;
-  std::function<void(T*)> free_;
-};
-
-} // namespace mlx::core
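The deleted ``BufferCache`` pairs a size-ordered multimap (best-fit lookup) with an intrusive list (eviction order). A compact Python sketch of the reuse policy only, using ``bisect`` over a sorted size list in place of ``std::multimap`` and omitting the LRU eviction machinery:

.. code-block:: python

    import bisect

    class BufferCacheSketch:
        def __init__(self, page_size):
            self.page_size = page_size
            self.sizes = []   # sorted sizes, one entry per cached buffer
            self.bufs = {}    # size -> stack of buffers of that size
            self.pool_size = 0

        def reuse_from_cache(self, size):
            # Best fit: smallest cached buffer that is at least `size` ...
            i = bisect.bisect_left(self.sizes, size)
            if i == len(self.sizes):
                return None
            found = self.sizes[i]
            # ... but not wastefully larger than the request.
            if found >= min(2 * size, size + 2 * self.page_size):
                return None
            self.sizes.pop(i)
            self.pool_size -= found
            return self.bufs[found].pop()

        def recycle_to_cache(self, buf, size):
            bisect.insort(self.sizes, size)
            self.bufs.setdefault(size, []).append(buf)
            self.pool_size += size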
@@ -1,7 +1,6 @@
 // Copyright © 2024 Apple Inc.
 #include <cassert>
 
-#include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
 

@@ -43,6 +42,23 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
   return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
 }
 
+void broadcast(const array& in, array& out) {
+  if (out.size() == 0) {
+    out.set_data(nullptr);
+    return;
+  }
+  Strides strides(out.ndim(), 0);
+  int diff = out.ndim() - in.ndim();
+  for (int i = in.ndim() - 1; i >= 0; --i) {
+    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
+  }
+  auto flags = in.flags();
+  if (out.size() > in.size()) {
+    flags.row_contiguous = flags.col_contiguous = false;
+  }
+  out.copy_shared_buffer(in, strides, flags, in.data_size());
+}
+
 void Broadcast::eval(const std::vector<array>& inputs, array& out) {
   broadcast(inputs[0], out);
 }

@@ -87,7 +103,7 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
 
 void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
   double numel = 1;
   for (auto ax : axes_) {
@@ -1,7 +1,8 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include "mlx/backend/common/compiled.h"
-#include "mlx/backend/common/utils.h"
+#include "mlx/graph_utils.h"
+#include "mlx/primitives.h"
 #include "mlx/utils.h"
 
 namespace mlx::core {

@@ -78,6 +79,55 @@ std::string get_type_string(Dtype d) {
   }
 }
 
+std::string build_lib_name(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
+    const std::vector<array>& tape,
+    const std::unordered_set<uintptr_t>& constant_ids) {
+  NodeNamer namer;
+  std::ostringstream os;
+  std::ostringstream constant_hasher;
+
+  // Fill the input names. This is not really necessary, I just like having A,
+  // B, C, ... as the inputs.
+  for (auto& x : inputs) {
+    namer.get_name(x);
+  }
+
+  // The primitives describing the tape. For unary and binary primitives this
+  // must be enough to describe the full computation.
+  for (auto& a : tape) {
+    // name and type of output
+    os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
+    // computation performed
+    a.primitive().print(os);
+    // name of inputs to the function
+    for (auto& inp : a.inputs()) {
+      os << namer.get_name(inp);
+    }
+  }
+  os << "_";
+
+  for (auto& x : inputs) {
+    if (constant_ids.find(x.id()) != constant_ids.end()) {
+      os << "C";
+      print_constant(constant_hasher, x);
+    } else {
+      os << (is_scalar(x) ? "S" : "V");
+    }
+  }
+  os << "_";
+  for (auto& x : inputs) {
+    if (constant_ids.find(x.id()) != constant_ids.end()) {
+      continue;
+    }
+    os << kindof(x.dtype()) << x.itemsize();
+  }
+  os << "_" << std::hash<std::string>{}(constant_hasher.str());
+
+  return os.str();
+}
+
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
     const Shape& shape) {
@@ -109,7 +159,8 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::function<bool(size_t)>& is_constant,
+    const std::vector<array>& inputs_,
+    const std::unordered_set<uintptr_t>& constant_ids_,
     bool contiguous) {
   if (contiguous) {
     int o = 0;

@@ -124,7 +175,8 @@ void compiled_allocate_outputs(
       // - Donatable
       // - Not a constant
       if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
-          in.is_donatable() && is_constant(i)) {
+          in.is_donatable() &&
+          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
         outputs[o++].copy_shared_buffer(in);
       }
       // Get representative input flags to properly set non-donated outputs

@@ -136,7 +188,7 @@ void compiled_allocate_outputs(
     }
     for (; o < outputs.size(); ++o) {
       outputs[o].set_data(
-          allocator::malloc(data_size * outputs[o].itemsize()),
+          allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
          data_size,
          strides,
          flags);

@@ -152,86 +204,16 @@ void compiled_allocate_outputs(
       // - Not a constant
       if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
           in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
-          is_constant(i)) {
+          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
         outputs[o].copy_shared_buffer(
             in, outputs[o].strides(), in.flags(), in.data_size());
         o++;
       }
     }
     for (; o < outputs.size(); ++o) {
-      outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
+      outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
     }
   }
 }
 
-std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
-    const std::vector<array>& inputs,
-    const array& out,
-    const std::function<bool(size_t)>& is_constant) {
-  const Shape& shape = out.shape();
-  bool contiguous = compiled_check_contiguity(inputs, shape);
-  if (contiguous) {
-    return {true, shape, {}};
-  }
-
-  std::vector<Strides> strides_vec{out.strides()};
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    // Skip constants.
-    if (is_constant(i)) {
-      continue;
-    }
-
-    // Skip scalar inputs.
-    const auto& x = inputs[i];
-    if (is_scalar(x)) {
-      continue;
-    }
-
-    // Broadcast the inputs to the output shape.
-    Strides xstrides;
-    size_t j = 0;
-    for (; j < shape.size() - x.ndim(); ++j) {
-      if (shape[j] == 1) {
-        xstrides.push_back(out.strides()[j]);
-      } else {
-        xstrides.push_back(0);
-      }
-    }
-    for (size_t i = 0; i < x.ndim(); ++i, ++j) {
-      if (x.shape(i) == 1) {
-        if (shape[j] == 1) {
-          xstrides.push_back(out.strides()[j]);
-        } else {
-          xstrides.push_back(0);
-        }
-      } else {
-        xstrides.push_back(x.strides()[i]);
-      }
-    }
-    strides_vec.push_back(std::move(xstrides));
-  }
-
-  auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
-  return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
-}
-
-bool compiled_use_large_index(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    bool contiguous) {
-  if (contiguous) {
-    size_t max_size = 0;
-    for (const auto& in : inputs) {
-      max_size = std::max(max_size, in.data_size());
-    }
-    return max_size > UINT32_MAX;
-  } else {
-    size_t max_size = 0;
-    for (const auto& o : outputs) {
-      max_size = std::max(max_size, o.size());
-    }
-    return max_size > UINT32_MAX;
-  }
-}
-
 } // namespace mlx::core
@@ -1,8 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once

#include <functional>
#include <iomanip>
#include <sstream>
#include <unordered_set>

#include "mlx/array.h"
#include "mlx/primitives.h"
@@ -13,6 +14,12 @@ inline bool is_static_cast(const Primitive& p) {
  return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
}

std::string build_lib_name(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    const std::vector<array>& tape,
    const std::unordered_set<uintptr_t>& constant_ids);

std::string get_type_string(Dtype d);

template <typename T>
@@ -53,19 +60,8 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::function<bool(size_t)>& is_constant,
    bool contiguous);
    const std::vector<array>& inputs_,
    const std::unordered_set<uintptr_t>& constant_ids_,

// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
    const std::vector<array>& inputs,
    const array& out,
    const std::function<bool(size_t)>& is_constant);

// Return whether the kernel should use large index.
bool compiled_use_large_index(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    bool contiguous);

} // namespace mlx::core
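A recurring thread in these header hunks is the switch from passing a std::unordered_set of constant array ids to passing a std::function predicate over input positions. A rough sketch of how the old representation adapts to the new interface (names are illustrative, not from the diff):

#include <cstdint>
#include <functional>
#include <unordered_set>
#include <vector>

// Illustrative adapter: wrap an id set as a positional predicate.
// input_ids stands in for the ids of the compiled primitive's inputs.
std::function<bool(size_t)> make_is_constant(
    std::unordered_set<uintptr_t> constant_ids,
    std::vector<uintptr_t> input_ids) {
  // Captured by value so the predicate owns its data.
  return [constant_ids = std::move(constant_ids),
          input_ids = std::move(input_ids)](size_t i) {
    return constant_ids.count(input_ids[i]) > 0;
  };
}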
@@ -2,7 +2,7 @@

#pragma once

#include "mlx/backend/common/utils.h"
#include "mlx/array.h"

namespace mlx::core {

@@ -26,19 +26,19 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
  if (ctype == CopyType::Vector) {
    // If the input is donatable, we are doing a vector copy and the types
    // have the same size, then the input buffer can hold the output.
    if (is_donatable(in, out)) {
    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
      out.copy_shared_buffer(in);
      return true;
    } else {
      out.set_data(
          allocator::malloc(in.data_size() * out.itemsize()),
          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
          in.data_size(),
          in.strides(),
          in.flags());
      return false;
    }
  } else {
    out.set_data(allocator::malloc(out.nbytes()));
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    return false;
  }
}
@@ -99,11 +99,7 @@ inline std::pair<int, int> decompose_hadamard(int n) {
        "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
  }
}
if (n > (1 << 26)) {
  throw std::invalid_argument(
      "[hadamard] Only supports n = m*2^k where k <= 26");
}
  return {n, m};
}

} // namespace mlx::core
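The is_donatable(in, out) helper that replaces the inline check above centralizes the donation test: the old expression spelled it out as "the input is marked donatable and the element sizes match", in which case the output can simply alias the input buffer instead of allocating. A sketch mirroring the replaced expression (the real helper may check more conditions):

#include "mlx/array.h"

namespace mlx::core {
// Sketch: an input buffer may be handed to the output when the caller
// marked it donatable and the reinterpretation is size-compatible.
inline bool is_donatable_sketch(const array& in, const array& out) {
  return in.is_donatable() && in.itemsize() == out.itemsize();
}
} // namespace mlx::core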
@@ -28,7 +28,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
namespace mlx::core {

void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
  out.set_data(allocator::malloc(out.nbytes()));
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  auto read_task = [out_ptr = out.data<char>(),
                    size = out.size(),
                    itemsize = out.itemsize(),
@@ -1,78 +0,0 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"

#include <sstream>

namespace mlx::core {

inline std::tuple<Shape, Strides, Strides> collapse_batches(
    const array& a,
    const array& b) {
  // Get and check the shape for the batched dims
  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
  Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
  if (A_bshape != B_bshape) {
    std::ostringstream msg;
    msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
        << a.shape() << ", B " << b.shape() << ".";
    throw std::runtime_error(msg.str());
  }

  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};

  auto [batch_shape, batch_strides] =
      collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});

  auto a_batch_strides = batch_strides[0];
  auto b_batch_strides = batch_strides[1];

  if (batch_shape.empty()) {
    batch_shape.push_back(1);
    a_batch_strides.push_back(0);
    b_batch_strides.push_back(0);
  }

  return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
}

inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) {
  // Get and check the shape for the batched dims
  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
  Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
  Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
  if (A_bshape != B_bshape || A_bshape != C_bshape) {
    std::ostringstream msg;
    msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
        << a.shape() << ", B " << b.shape() << ", C " << c.shape() << ".";
    throw std::runtime_error(msg.str());
  }

  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
  Strides C_bstride{c.strides().begin(), c.strides().end() - 2};

  auto [batch_shape, batch_strides] = collapse_contiguous_dims(
      A_bshape, std::vector{A_bstride, B_bstride, C_bstride});

  auto A_batch_stride = batch_strides[0];
  auto B_batch_stride = batch_strides[1];
  auto C_batch_stride = batch_strides[2];

  if (batch_shape.empty()) {
    batch_shape.push_back(1);
    A_batch_stride.push_back(0);
    B_batch_stride.push_back(0);
    C_batch_stride.push_back(0);
  }

  return std::make_tuple(
      batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}

} // namespace mlx::core
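A concrete picture of what collapse_batches buys (the sizes below are arbitrary, for illustration only): for row-contiguous operands the whole batch prefix is contiguous, so it folds into a single dimension and the matmul driver runs one flat loop of GEMM calls instead of a nested one.

// For row-contiguous a: (2, 3, 4, M, K) and b: (2, 3, 4, K, N), the batch
// prefix (2, 3, 4) collapses to a single dim of size 24:
//   batch_shape     == {24}
//   a_batch_strides == {M * K}
//   b_batch_strides == {K * N}
// i.e. 24 GEMMs addressed by one stride per operand.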
@@ -48,12 +48,12 @@ inline void set_ternary_op_output_data(
  switch (topt) {
    case TernaryOpType::ScalarScalarScalar:
      out.set_data(
          allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
          allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
      break;
    case TernaryOpType::VectorVectorVector:
      if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
        out.set_data(
            allocator::malloc(out.itemsize() * b.data_size()),
            allocator::malloc_or_wait(out.itemsize() * b.data_size()),
            b.data_size(),
            b.strides(),
            b.flags());
@@ -64,7 +64,7 @@ inline void set_ternary_op_output_data(
      if (!((a.flags().row_contiguous && maybe_donate(a)) ||
            (b.flags().row_contiguous && maybe_donate(b)) ||
            (c.flags().row_contiguous && maybe_donate(c)))) {
        out.set_data(allocator::malloc(out.nbytes()));
        out.set_data(allocator::malloc_or_wait(out.nbytes()));
      }
      break;
  }
@@ -1,26 +0,0 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

inline void set_unary_output_data(const array& in, array& out) {
  if (in.flags().contiguous) {
    if (is_donatable(in, out)) {
      out.copy_shared_buffer(in);
    } else {
      out.set_data(
          allocator::malloc(in.data_size() * out.itemsize()),
          in.data_size(),
          in.strides(),
          in.flags());
    }
  } else {
    out.set_data(allocator::malloc(out.nbytes()));
  }
}

} // namespace mlx::core
@@ -1,16 +1,9 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

std::string get_primitive_string(Primitive* primitive) {
  std::ostringstream op_t;
  primitive->print(op_t);
  return op_t.str();
}

std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
    const Shape& shape,
    const std::vector<Strides>& strides,
@@ -108,115 +101,4 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
  return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
}

Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
  int pows[3] = {0, 0, 0};
  int sum = 0;
  while (true) {
    int presum = sum;
    // Check all the pows
    if (dim0 >= (1 << (pows[0] + 1))) {
      pows[0]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim1 >= (1 << (pows[1] + 1))) {
      pows[1]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim2 >= (1 << (pows[2] + 1))) {
      pows[2]++;
      sum++;
    }
    if (sum == presum || sum == pow2) {
      break;
    }
  }
  return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
}

Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
  // Dims with strides of 0 are ignored as they
  // correspond to broadcasted dimensions
  size_t grid_x = 1;
  size_t grid_y = 1;
  for (int i = 0; i < shape.size(); ++i) {
    if (strides[i] == 0) {
      continue;
    }
    if (grid_x * shape[i] < UINT32_MAX) {
      grid_x *= shape[i];
    } else {
      grid_y *= shape[i];
    }
  }
  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
    throw std::runtime_error("Unable to safely factor shape.");
  }
  if (grid_y > grid_x) {
    std::swap(grid_x, grid_y);
  }
  return std::make_tuple(
      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}

Dims get_2d_grid_dims_common(
    const Shape& shape,
    const Strides& strides,
    size_t divisor) {
  // Compute the 2d grid dimensions such that the total size of the grid is
  // divided by divisor.
  size_t grid_x = 1;
  size_t grid_y = 1;
  for (int i = 0; i < shape.size(); ++i) {
    if (strides[i] == 0) {
      continue;
    }

    // No need to add this shape we can just remove it from the divisor.
    if (divisor % shape[i] == 0) {
      divisor /= shape[i];
      continue;
    }

    if (grid_x * shape[i] < UINT32_MAX) {
      grid_x *= shape[i];
    } else {
      grid_y *= shape[i];
    }

    if (divisor > 1) {
      if (grid_x % divisor == 0) {
        grid_x /= divisor;
        divisor = 1;
      } else if (grid_y % divisor == 0) {
        grid_y /= divisor;
        divisor = 1;
      }
    }
  }
  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
    throw std::runtime_error("Unable to safely factor shape.");
  }
  if (grid_y > grid_x) {
    std::swap(grid_x, grid_y);
  }
  return std::make_tuple(
      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}

std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
  auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
  auto gx = (dim0 + bx - 1) / bx;
  auto gy = (dim1 + by - 1) / by;
  auto gz = (dim2 + bz - 1) / bz;

  return std::make_pair(
      std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
}

} // namespace mlx::core
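A worked example of what the removed helpers computed, for reading call sites elsewhere in the tree:

// get_block_dims_common(100, 4, 1, /* pow2 = */ 10):
//   dim0 = 100 -> largest power of two <= 100 is 64 (pows[0] = 6)
//   dim1 = 4   -> 4 (pows[1] = 2)
//   dim2 = 1   -> 1 (pows[2] = 0)
// sum = 8 <= 10, so the block is (64, 4, 1) = 256 threads.
// get_grid_and_block_common(100, 4, 1) then covers each dim by ceil
// division: grid = ((100 + 63) / 64, (4 + 3) / 4, (1 + 0) / 1) = (2, 1, 1).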
@@ -2,15 +2,12 @@

#pragma once

#include <tuple>
#include <vector>

#include "mlx/array.h"

namespace mlx::core {

std::string get_primitive_string(Primitive* primitive);

inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
  int64_t loc = 0;
@@ -73,31 +70,6 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
    const array& a,
    int64_t size_cap = std::numeric_limits<int32_t>::max());

// Compute the thread block dimensions which fit the given
// input dimensions.
// - The thread block dimensions will be powers of two
// - The thread block size will be less than 2^pow2
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);

// Computes a 2D grid where each element is < UINT_MAX
// Assumes:
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
// - shape and strides correspond to a contiguous (no holes) but
//   possibly broadcasted array
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);

// Same as above but we do an implicit division with divisor.
// Basically, equivalent to factorizing
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
Dims get_2d_grid_dims_common(
    const Shape& shape,
    const Strides& strides,
    size_t divisor);

// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);

struct ContiguousIterator {
  inline void step() {
    int dims = shape_.size();
@@ -193,11 +165,4 @@ void shared_buffer_reshape(
    const array& in,
    const Strides& out_strides,
    array& out);

template <typename T>
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
  vec.erase(std::next(vec.begin(), index));
  return vec;
}

} // namespace mlx::core
@@ -40,13 +40,11 @@ add_dependencies(mlx cpu_compiled_preamble)

target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
@@ -60,7 +58,6 @@ target_sources(
          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -76,8 +73,8 @@ target_sources(
if(MLX_BUILD_ACCELERATE)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
                             ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
                             ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
endif()

if(IOS)
@@ -14,8 +14,10 @@ template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
  auto axis_size = in.shape()[axis];
  auto axis_stride = in.strides()[axis];
  Strides strides = remove_index(in.strides(), axis);
  Shape shape = remove_index(in.shape(), axis);
  Strides strides = in.strides();
  Shape shape = in.shape();
  strides.erase(strides.begin() + axis);
  shape.erase(shape.begin() + axis);
  auto in_ptr = in.data<InT>();
  auto out_ptr = out.data<uint32_t>();

@@ -66,7 +68,7 @@ void arg_reduce_dispatch(
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  out.set_data(allocator::malloc(out.nbytes()));
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  auto& encoder = cpu::get_command_encoder(stream());
  encoder.set_input_array(in);
  encoder.set_output_array(out);
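The remove_index helper added to common/utils.h in this diff is what lets arg_reduce drop the explicit erase pair: it takes the vector by value, erases one position, and returns the rest. A self-contained equivalent with an example:

#include <iterator>
#include <vector>

// Equivalent of the erase pair replaced in arg_reduce: drop `index` from a
// dims vector by value, leaving the caller's vector untouched.
template <typename T>
std::vector<T> remove_index(std::vector<T> vec, size_t index) {
  vec.erase(std::next(vec.begin(), index));
  return vec;
}

// e.g. remove_index(std::vector<int>{4, 5, 6}, 1) yields {4, 6}.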
@@ -1,11 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cpu/available.h"

namespace mlx::core::cpu {

bool is_available() {
  return true;
}

} // namespace mlx::core::cpu
@@ -1,9 +0,0 @@
// Copyright © 2025 Apple Inc.

#pragma once

namespace mlx::core::cpu {

bool is_available();

} // namespace mlx::core::cpu
@@ -172,12 +172,9 @@ void binary_float(
    case bfloat16:
      binary_op<bfloat16_t, Op>(a, b, out, bopt);
      break;
    case complex64:
      binary_op<complex64_t, Op>(a, b, out, bopt);
      break;
    default:
      throw std::runtime_error(
          "[binary_float] Only supports floating point types.");
          "[binary_float] Only supports non-complex floating point types.");
  }
});
}
@@ -40,10 +40,7 @@ struct CompilerCache {
  std::shared_mutex mtx;
};

static CompilerCache& cache() {
  static CompilerCache cache_;
  return cache_;
};
static CompilerCache cache{};

// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
@@ -59,16 +56,14 @@ void* compile(
    const std::string& kernel_name,
    const std::function<std::string(void)>& source_builder) {
  {
    std::shared_lock lock(cache().mtx);
    std::shared_lock lock(cache.mtx);
    if (auto it = cache().kernels.find(kernel_name);
        it != cache().kernels.end()) {
    if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
      return it->second;
    }
  }

  std::unique_lock lock(cache().mtx);
  std::unique_lock lock(cache.mtx);
  if (auto it = cache().kernels.find(kernel_name);
      it != cache().kernels.end()) {
  if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
    return it->second;
  }
  std::string source_code = source_builder();
@@ -125,10 +120,10 @@ void* compile(
  }

  // load library
  cache().libs.emplace_back(shared_lib_path);
  cache.libs.emplace_back(shared_lib_path);

  // Load function
  void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
  void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
  if (!fun) {
    std::ostringstream msg;
    msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -136,7 +131,7 @@ void* compile(
        << dlerror();
    throw std::runtime_error(msg.str());
  }
  cache().kernels.insert({kernel_name, fun});
  cache.kernels.insert({kernel_name, fun});
  return fun;
}
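The cache() accessor introduced above (a function-local static, the Meyers singleton) combined with a shared_mutex gives the classic double-checked read/write split: readers take a shared lock, and a writer re-checks under the exclusive lock before inserting. A minimal sketch of the same pattern with illustrative names:

#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

struct Cache {
  std::unordered_map<std::string, void*> kernels;
  std::shared_mutex mtx;
};

Cache& cache() {
  static Cache c; // constructed on first use, thread-safe since C++11
  return c;
}

void* lookup_or_build(const std::string& name, void* (*build)()) {
  {
    std::shared_lock lock(cache().mtx); // many readers in parallel
    if (auto it = cache().kernels.find(name); it != cache().kernels.end()) {
      return it->second;
    }
  }
  std::unique_lock lock(cache().mtx); // exclusive for insertion
  if (auto it = cache().kernels.find(name); it != cache().kernels.end()) {
    return it->second; // another thread built it between the two locks
  }
  return cache().kernels[name] = build();
}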
@@ -146,9 +141,18 @@ inline void build_kernel(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    const std::vector<array>& tape,
    const std::function<bool(size_t)>& is_constant,
    const std::unordered_set<uintptr_t>& constant_ids,
    bool contiguous,
    int ndim) {
  // All outputs should have the exact same shape and will be row contiguous
  auto output_shape = outputs[0].shape();
  auto output_strides = outputs[0].strides();

  // Constants are scalars that are captured by value and cannot change
  auto is_constant = [&constant_ids](const array& x) {
    return constant_ids.find(x.id()) != constant_ids.end();
  };

  NodeNamer namer;

#ifdef _MSC_VER
@@ -161,15 +165,14 @@ inline void build_kernel(

  // Add the input arguments
  int cnt = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
  for (auto& x : inputs) {
    auto& xname = namer.get_name(x);

    // Skip constants from the input list
    if (is_constant(i)) {
    if (is_constant(x)) {
      continue;
    }

    const auto& x = inputs[i];
    auto& xname = namer.get_name(x);

    auto tstr = get_type_string(x.dtype());
    os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
       << "];" << std::endl;
@@ -203,11 +206,10 @@ inline void build_kernel(
  }

  // Read the inputs in tmps
  for (size_t i = 0; i < inputs.size(); ++i) {
  for (auto& x : inputs) {
    const auto& x = inputs[i];
    auto& xname = namer.get_name(x);

    if (is_constant(i)) {
    if (is_constant(x)) {
      os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
      print_constant(os, x);
      os << ";" << std::endl;
@@ -257,9 +259,8 @@ inline void build_kernel(
  } else {
    for (int d = ndim - 1; d >= 0; --d) {
      // Update pointers
      for (size_t i = 0; i < inputs.size(); ++i) {
      for (auto& x : inputs) {
        const auto& x = inputs[i];
        if (is_constant(i) || is_scalar(x)) {
        if (is_constant(x) || is_scalar(x)) {
          continue;
        }
        auto& xname = namer.get_name(x);
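For orientation, what build_kernel emits is ordinary C++ that is compiled at runtime and resolved with dlsym. A hand-written approximation of a contiguous kernel for one float32 add (not actual mlx output; the name and layout are illustrative):

#include <cstddef>

extern "C" {
// args layout mirrors eval_cpu below: input pointers first, then output
// pointers, then the flat size for the contiguous case.
void add_f32_contiguous(void** args) {
  const float* a = (const float*)args[0];
  const float* b = (const float*)args[1];
  float* out = (float*)args[2];
  size_t size = (size_t)args[3];
  for (size_t i = 0; i < size; ++i) {
    float tmp_a = a[i];
    float tmp_b = b[i];
    out[i] = tmp_a + tmp_b;
  }
}
}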
@@ -281,37 +282,65 @@ inline void build_kernel(
void Compiled::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  if (kernel_lib_.empty()) {
    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
  }

  // Figure out which kernel we are using
  auto& shape = outputs[0].shape();
  auto contiguous = compiled_check_contiguity(inputs, shape);
  auto& encoder = cpu::get_command_encoder(stream());

  // Collapse contiguous dims to route to a faster kernel if possible. Also
  // handle all broadcasting.
  auto [contiguous, shape, strides] =
      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
  // Handle all broadcasting and collect function input arguments

  // Collect function input arguments.
  std::vector<void*> args;
  int strides_index = 1;
  std::vector<std::vector<size_t>> strides;
  for (size_t i = 0; i < inputs.size(); ++i) {
  for (int i = 0; i < inputs.size(); i++) {
    if (is_constant_(i)) {
    // Skip constants.
    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
      continue;
    }
    const auto& x = inputs[i];
    auto& x = inputs[i];
    encoder.set_input_array(x);
    args.push_back((void*)x.data<void>());
    if (!contiguous && !is_scalar(x)) {
      args.push_back(strides[strides_index++].data());
    if (contiguous || is_scalar(x)) {
      continue;
    }

    // Broadcast the input to the output shape.
    std::vector<size_t> xstrides;
    int j = 0;
    for (; j < shape.size() - x.ndim(); j++) {
      if (shape[j] == 1) {
        xstrides.push_back(outputs[0].strides()[j]);
      } else {
        xstrides.push_back(0);
      }
    }
    for (int i = 0; i < x.ndim(); i++, j++) {
      if (x.shape(i) == 1) {
        if (shape[j] == 1) {
          xstrides.push_back(outputs[0].strides()[j]);
        } else {
          xstrides.push_back(0);
        }
      } else {
        xstrides.push_back(x.strides()[i]);
      }
    }
    strides.push_back(std::move(xstrides));
    args.push_back(strides.back().data());
  }

  // Get the kernel name from the lib
  int ndim = shape.size();
  auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
  if (!contiguous) {
    kernel_name += std::to_string(ndim);
    kernel_name += std::to_string(shape.size());
  }

  // Get the function
  auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
  auto fn_ptr = compile(kernel_name, [&]() {
    std::ostringstream kernel;
    kernel << get_kernel_preamble() << std::endl;
    kernel << "extern \"C\" {" << std::endl;
@@ -321,7 +350,7 @@ void Compiled::eval_cpu(
        inputs_,
        outputs_,
        tape_,
        is_constant_,
        constant_ids_,
        contiguous,
        ndim);
    // Close extern "C"
@@ -329,22 +358,26 @@ void Compiled::eval_cpu(
    return kernel.str();
  });

  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
  compiled_allocate_outputs(
      inputs, outputs, inputs_, constant_ids_, contiguous);

  for (auto& x : outputs) {
    args.push_back(x.data<void>());
    encoder.set_output_array(x);
  }
  Shape out_shape;
  if (!contiguous) {
    args.push_back((void*)shape.data());
    out_shape = outputs[0].shape();
    args.push_back((void*)out_shape.data());
  } else {
    args.push_back((void*)outputs[0].data_size());
  }
  auto fun = (void (*)(void**))fn_ptr;
  encoder.dispatch([fun,
                    args = std::move(args),
                    strides = std::move(strides),
                    shape = std::move(shape)]() mutable { fun(args.data()); });
  encoder.dispatch(
      [fun,
       args = std::move(args),
       strides = std::move(strides),
       out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
}

} // namespace mlx::core
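Worth noting in the dispatch at the end of eval_cpu: args holds raw pointers into strides and shape (or out_shape), and the queued task may run after eval_cpu returns. Moving the owning containers into the closure keeps those pointers valid, because moving a std::vector transfers its heap buffer without invalidating data(). A minimal model of the pattern (illustrative, not mlx API):

#include <cstdint>
#include <functional>
#include <vector>

// The queued task must own the buffers its argument pointers refer to,
// because it runs later. Moving the vectors into the closure transfers
// ownership; pointers taken via data() before the move stay valid.
std::function<void()> make_task(void (*fun)(void**)) {
  std::vector<int64_t> shape{2, 3};
  std::vector<void*> args{shape.data()};
  return [fun, shape = std::move(shape), args = std::move(args)]() mutable {
    fun(args.data()); // `shape`'s buffer is owned by this closure
  };
}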
@@ -22,8 +22,7 @@ void slow_conv_1D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -61,8 +60,7 @@ void slow_conv_1D(
                    out_stride_O = out.strides()[2],

                    flip,
                    padding_lo = padding_lo[0],
                    padding_hi = padding_hi[0],
                    padding = padding[0],
                    wt_stride = wt_strides[0],
                    wt_dilation = wt_dilation[0],
                    in_dilation = in_dilation[0]]() mutable {
@@ -79,7 +77,7 @@ void slow_conv_1D(
      const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;

      int wh_flip = flip ? (wH - wh - 1) : wh;
      int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;
      int ih = oh * wt_stride - padding + wh_flip * wt_dilation;

      auto ih_div = std::div(ih, in_dilation);
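The change threaded through all of these conv signatures replaces a single symmetric padding with a padding_lo/padding_hi pair. With asymmetric padding the output length generalizes in the obvious way; as a sketch (assumed formula, reducing to the symmetric case when padding_lo == padding_hi):

// oH = (iH + padding_lo + padding_hi - (wH - 1) * wt_dilation - 1)
//          / wt_stride + 1
int conv_out_size(
    int in_size, int pad_lo, int pad_hi, int w_size, int dilation, int stride) {
  return (in_size + pad_lo + pad_hi - (w_size - 1) * dilation - 1) / stride + 1;
}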
@@ -111,8 +109,7 @@ void slow_conv_2D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -123,235 +120,230 @@ void slow_conv_2D(
  encoder.set_input_array(wt);
  encoder.set_output_array(out);

  encoder.dispatch(
      [st_wt_ptr = wt.data<T>(),
       st_in_ptr = in.data<T>(),
       st_out_ptr = out.data<T>(),
  encoder.dispatch([st_wt_ptr = wt.data<T>(),
                    st_in_ptr = in.data<T>(),
                    st_out_ptr = out.data<T>(),

       N = in.shape(0), // Batch size, should be the same as out.shape(0)
       iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
       iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
       C = in.shape(3), // In channels
       oH = out.shape(1), // Output spatial dim
       oW = out.shape(2), // Output spatial dim
       O = wt.shape(0), // Out channels
       wH = wt.shape(1), // Weight spatial dim
       wW = wt.shape(2), // Weight spatial dim
                    N = in.shape(
                        0), // Batch size, should be the same as out.shape(0)
                    iH = 1 +
                        in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
                    iW = 1 +
                        in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
                    C = in.shape(3), // In channels
                    oH = out.shape(1), // Output spatial dim
                    oW = out.shape(2), // Output spatial dim
                    O = wt.shape(0), // Out channels
                    wH = wt.shape(1), // Weight spatial dim
                    wW = wt.shape(2), // Weight spatial dim

                    groups = in.shape(3) / wt.shape(3),
                    C_per_group = wt.shape(3),

                    in_stride_N = in.strides()[0],
                    in_stride_H = in.strides()[1],
                    in_stride_W = in.strides()[2],
                    in_stride_C = in.strides()[3],

                    wt_stride_O = wt.strides()[0],
                    wt_stride_H = wt.strides()[1],
                    wt_stride_W = wt.strides()[2],
                    wt_stride_C = wt.strides()[3],

                    out_stride_N = out.strides()[0],
                    out_stride_H = out.strides()[1],
                    out_stride_W = out.strides()[2],
                    out_stride_O = out.strides()[3],

       padding_lo,
       padding_hi,
       wt_strides,
       wt_dilation,
       in_dilation,
       flip]() mutable {
                    padding,
                    wt_strides,
                    wt_dilation,
                    in_dilation,
                    flip]() mutable {
    bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;

    const int O_per_group = O / groups;
    auto pt_conv_no_checks =
        [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
    auto pt_conv_no_checks = [&](const T* in_ptr,
                                 const T* wt_ptr,
                                 T* out_ptr,
                                 int oh,
                                 int ow) {
      out_ptr += oh * out_stride_H + ow * out_stride_W;
      int ih_base = oh * wt_strides[0] - padding_lo[0];
      int iw_base = ow * wt_strides[1] - padding_lo[1];
      int ih_base = oh * wt_strides[0] - padding[0];
      int iw_base = ow * wt_strides[1] - padding[1];

      for (int g = 0; g < groups; ++g) {
        for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
          float r = 0.;

          for (int wh = 0; wh < wH; ++wh) {
            for (int ww = 0; ww < wW; ++ww) {
              int wh_flip = flip ? wH - wh - 1 : wh;
              int ww_flip = flip ? wW - ww - 1 : ww;
              int ih = ih_base + wh_flip * wt_dilation[0];
              int iw = iw_base + ww_flip * wt_dilation[1];

              const T* wt_ptr_pt =
                  wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
              const T* in_ptr_pt =
                  in_ptr + ih * in_stride_H + iw * in_stride_W;
              const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
              const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;

              for (int c = g * C_per_group; c < (g + 1) * C_per_group;
                   ++c) {
              for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
                r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
                    static_cast<float>(
                        wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
              } // c
            } // ww
          } // wh

          out_ptr[0] = static_cast<T>(r);
          out_ptr += out_stride_O;
          wt_ptr += wt_stride_O;
        } // o
      } // g
    };

    int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
    int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];

    int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
    int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);

    int f_wgt_jump_h =
        std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
    int f_wgt_jump_w =
        std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];

    int f_out_jump_h =
        std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
    int f_out_jump_w =
        std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
    int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
    int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];

    std::vector<int> base_h(f_out_jump_h);
    std::vector<int> base_w(f_out_jump_w);

    for (int i = 0; i < f_out_jump_h; ++i) {
      int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;
      int ih_loop = i * wt_strides[0] - padding[0] + init_h;

      int wh_base = 0;
      while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
        wh_base++;
        ih_loop += jump_h;
      }

      base_h[i] = wh_base;
    }

    for (int j = 0; j < f_out_jump_w; ++j) {
      int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;
      int iw_loop = j * wt_strides[1] - padding[1] + init_w;

      int ww_base = 0;
      while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
        ww_base++;
        iw_loop += jump_w;
      }

      base_w[j] = ww_base;
    }

    auto pt_conv_all_checks =
        [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
          out_ptr += oh * out_stride_H + ow * out_stride_W;

          int ih_base = oh * wt_strides[0] - padding_lo[0];
          int iw_base = ow * wt_strides[1] - padding_lo[1];
          int ih_base = oh * wt_strides[0] - padding[0];
          int iw_base = ow * wt_strides[1] - padding[1];

          int wh_base = base_h[oh % f_out_jump_h];
          int ww_base = base_w[ow % f_out_jump_w];

          for (int g = 0; g < groups; ++g) {
            for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
              float r = 0.;

              for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
                for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
                  int wh_flip = flip ? wH - wh - 1 : wh;
                  int ww_flip = flip ? wW - ww - 1 : ww;
                  int ih = ih_base + wh_flip * wt_dilation[0];
                  int iw = iw_base + ww_flip * wt_dilation[1];

                  if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
                    const T* wt_ptr_pt =
                        wt_ptr + wh * wt_stride_H + ww * wt_stride_W;

                    int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
                    int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;

                    const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
                        iw_dil * in_stride_W;
                    const T* in_ptr_pt =
                        in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;

                    for (int c = g * C_per_group; c < (g + 1) * C_per_group;
                         ++c) {
                      r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
                          static_cast<float>(
                              wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
                    } // c

                  } // ih, iw check
                } // ww
              } // wh

              out_ptr[0] = static_cast<T>(r);
              out_ptr += out_stride_O;
              wt_ptr += wt_stride_O;
            } // o
          } // g
        };

    int oH_border_0 = 0;
    int oH_border_1 = is_idil_one
        ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
        : oH;
    int oH_border_1 =
        is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
    int oH_border_2 = std::max(
        oH_border_1,
        (iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
    int oH_border_2 = std::max(
        oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
    int oH_border_3 = oH;

    int oW_border_0 = 0;
    int oW_border_1 = is_idil_one
        ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
        : oW;
    int oW_border_1 =
        is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
    int oW_border_2 = std::max(
        oW_border_1,
        (iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
    int oW_border_2 = std::max(
        oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
    int oW_border_3 = oW;

    for (int n = 0; n < N; ++n) {
      // Case 1: oh might put us out of bounds
      for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
        for (int ow = 0; ow < oW; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
        } // ow
      } // oh

      // Case 2: oh in bounds
      for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
        // Case a: ow might put us out of bounds
        for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
        } // ow

        // Case b: ow in bounds
        for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
          pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
        } // ow

        // Case c: ow might put us out of bounds
        for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
        } // ow

      } // oh

      // Case 3: oh might put us out of bounds
      for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
        for (int ow = 0; ow < oW; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
        } // ow
      } // oh

      st_in_ptr += in_stride_N;
      st_out_ptr += out_stride_N;

    } // n
  });
}

template <typename T>
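The Case 1/2/3 structure above exists so the interior of the output, where every tap is guaranteed in-bounds, can skip the per-tap range checks. A 1D sketch of the same decomposition (simplified: stride 1, no dilation or flip; border formulas here are exact rather than the conservative ones used above):

#include <algorithm>
#include <vector>

void conv1d_bordered(
    const std::vector<float>& in,
    const std::vector<float>& w,
    std::vector<float>& out,
    int pad) {
  int iH = static_cast<int>(in.size());
  int wH = static_cast<int>(w.size());
  int oH = static_cast<int>(out.size());
  int lo = std::min(pad, oH); // first interior output index
  int hi = std::max(lo, std::min(oH, iH + pad - wH + 1)); // end of interior
  auto tap_checked = [&](int oh) {
    float r = 0.f;
    for (int wh = 0; wh < wH; ++wh) {
      int ih = oh - pad + wh;
      if (ih >= 0 && ih < iH) { // bounds check on every tap
        r += in[ih] * w[wh];
      }
    }
    out[oh] = r;
  };
  for (int oh = 0; oh < lo; ++oh) {
    tap_checked(oh); // Case 1: may run off the left edge
  }
  for (int oh = lo; oh < hi; ++oh) { // Case 2: interior, no checks
    float r = 0.f;
    for (int wh = 0; wh < wH; ++wh) {
      r += in[oh - pad + wh] * w[wh];
    }
    out[oh] = r;
  }
  for (int oh = hi; oh < oH; ++oh) {
    tap_checked(oh); // Case 3: may run off the right edge
  }
}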
@@ -359,8 +351,7 @@ void slow_conv_3D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -409,8 +400,7 @@ void slow_conv_3D(
                    out_stride_H = out.strides()[2],
                    out_stride_W = out.strides()[3],
                    out_stride_O = out.strides()[4],
                    padding_lo,
                    padding_hi,
                    padding,
                    wt_strides,
                    wt_dilation,
                    in_dilation,
@@ -425,9 +415,9 @@ void slow_conv_3D(
                             int oh,
                             int ow) {
      out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
      int id_base = od * wt_strides[0] - padding_lo[0];
      int ih_base = oh * wt_strides[1] - padding_lo[1];
      int iw_base = ow * wt_strides[2] - padding_lo[2];
      int id_base = od * wt_strides[0] - padding[0];
      int ih_base = oh * wt_strides[1] - padding[1];
      int iw_base = ow * wt_strides[2] - padding[2];

      for (int o = 0; o < O; ++o) {
        float r = 0.;
@@ -488,7 +478,7 @@ void slow_conv_3D(
    std::vector<int> base_w(f_out_jump_w);

    for (int i = 0; i < f_out_jump_d; ++i) {
      int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;
      int id_loop = i * wt_strides[0] - padding[0] + init_d;

      int wd_base = 0;
      while (wd_base < wD && id_loop % in_dilation[0] != 0) {
@@ -500,7 +490,7 @@ void slow_conv_3D(
    }

    for (int i = 0; i < f_out_jump_h; ++i) {
      int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;
      int ih_loop = i * wt_strides[1] - padding[1] + init_h;

      int wh_base = 0;
      while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
@@ -512,7 +502,7 @@ void slow_conv_3D(
    }

    for (int j = 0; j < f_out_jump_w; ++j) {
      int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;
      int iw_loop = j * wt_strides[2] - padding[2] + init_w;

      int ww_base = 0;
      while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
@@ -531,9 +521,9 @@ void slow_conv_3D(
                             int ow) {
      out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;

      int id_base = od * wt_strides[0] - padding_lo[0];
      int ih_base = oh * wt_strides[1] - padding_lo[1];
      int iw_base = ow * wt_strides[2] - padding_lo[2];
      int id_base = od * wt_strides[0] - padding[0];
      int ih_base = oh * wt_strides[1] - padding[1];
      int iw_base = ow * wt_strides[2] - padding[2];

      int wd_base = base_d[od % f_out_jump_d];
      int wh_base = base_h[oh % f_out_jump_h];
@@ -583,30 +573,24 @@ void slow_conv_3D(
    };

    int oD_border_0 = 0;
    int oD_border_1 = is_idil_one
        ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
        : oD;
    int oD_border_1 =
        is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
    int oD_border_2 = std::max(
        oD_border_1,
        (iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);
        oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
    int oD_border_3 = oD;

    int oH_border_0 = 0;
    int oH_border_1 = is_idil_one
        ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
        : oH;
    int oH_border_1 =
        is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
    int oH_border_2 = std::max(
        oH_border_1,
        (iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);
        oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
    int oH_border_3 = oH;

    int oW_border_0 = 0;
    int oW_border_1 = is_idil_one
        ? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])
        : oW;
    int oW_border_1 =
        is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
    int oW_border_2 = std::max(
        oW_border_1,
        (iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);
        oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
    int oW_border_3 = oW;

    for (int n = 0; n < N; ++n) {
@@ -674,8 +658,7 @@ void dispatch_slow_conv_1D(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding_lo,
|
const std::vector<int>& padding,
|
||||||
const std::vector<int>& padding_hi,
|
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@@ -686,8 +669,7 @@ void dispatch_slow_conv_1D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_lo,
|
padding,
|
||||||
padding_hi,
|
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@@ -698,8 +680,7 @@ void dispatch_slow_conv_1D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_lo,
|
padding,
|
||||||
padding_hi,
|
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@@ -710,8 +691,7 @@ void dispatch_slow_conv_1D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_lo,
|
padding,
|
||||||
padding_hi,
|
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@@ -727,8 +707,7 @@ void dispatch_slow_conv_2D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -739,8 +718,7 @@ void dispatch_slow_conv_2D(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -751,8 +729,7 @@ void dispatch_slow_conv_2D(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -763,8 +740,7 @@ void dispatch_slow_conv_2D(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -780,8 +756,7 @@ void dispatch_slow_conv_3D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -792,8 +767,7 @@ void dispatch_slow_conv_3D(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -804,8 +778,7 @@ void dispatch_slow_conv_3D(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -816,8 +789,7 @@ void dispatch_slow_conv_3D(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -857,8 +829,7 @@ void explicit_gemm_conv_1D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    Stream stream) {
@@ -877,7 +848,7 @@ void explicit_gemm_conv_1D_cpu(
  auto& encoder = cpu::get_command_encoder(stream);

  // Pad input
  Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
  Shape padded_shape = {N, iH + 2 * padding[0], C};
  array in_padded(padded_shape, conv_dtype, nullptr, {});

  // Fill with zeros
@@ -886,7 +857,7 @@ void explicit_gemm_conv_1D_cpu(
  copy(temps.back(), in_padded, CopyType::Scalar, stream);

  // Pick input slice from padded
  size_t data_offset = padding_lo[0] * in_padded.strides()[1];
  size_t data_offset = padding[0] * in_padded.strides()[1];
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
@@ -950,7 +921,7 @@ void explicit_gemm_conv_1D_cpu(

  if (out.dtype() != float32) {
    gemm_out = array(out.shape(), float32, nullptr, {});
    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
    gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
    temps.push_back(gemm_out);
  }
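
Note: both versions implement the same "explicit GEMM" idea: pad the input, unfold windows, and let one matrix multiply do the convolution. A minimal self-contained sketch of that idea for 1D (hypothetical, not the mlx code):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> in = {1, 2, 3, 4};  // iH = 4, C = 1
  std::vector<float> wt = {0.5f, 0.25f}; // wH = 2
  int pad_lo = 1, pad_hi = 0, stride = 1, wH = 2;
  std::vector<float> padded(pad_lo + in.size() + pad_hi, 0.f);
  std::copy(in.begin(), in.end(), padded.begin() + pad_lo);
  int oH = (int(padded.size()) - wH) / stride + 1;
  for (int i = 0; i < oH; ++i) {
    // Row i of the unfolded ("im2col") matrix is the window at i * stride;
    // its dot product with the flattened weights is one GEMM entry.
    float acc = 0.f;
    for (int k = 0; k < wH; ++k)
      acc += padded[i * stride + k] * wt[k];
    std::printf("out[%d] = %g\n", i, acc);
  }
}
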
@@ -1000,8 +971,7 @@ void explicit_gemm_conv_2D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    Stream stream) {
@@ -1019,11 +989,7 @@ void explicit_gemm_conv_2D_cpu(
  auto& encoder = cpu::get_command_encoder(stream);

  // Pad input
  Shape padded_shape = {
      N,
      iH + padding_lo[0] + padding_hi[0],
      iW + padding_lo[1] + padding_hi[1],
      C};
  Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
  array in_padded(padded_shape, conv_dtype, nullptr, {});

  // Fill with zeros
@@ -1032,8 +998,8 @@ void explicit_gemm_conv_2D_cpu(
  copy(temps.back(), in_padded, CopyType::Scalar, stream);

  // Pick input slice from padded
  size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
      padding_lo[1] * in_padded.strides()[2];
  size_t data_offset =
      padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
@@ -1082,7 +1048,7 @@ void explicit_gemm_conv_2D_cpu(

  if (out.dtype() != float32) {
    gemm_out = array(out.shape(), float32, nullptr, {});
    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
    gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
    temps.push_back(gemm_out);
  }
@@ -1125,8 +1091,7 @@ void explicit_gemm_conv_ND_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const bool flip,
@@ -1149,7 +1114,7 @@ void explicit_gemm_conv_ND_cpu(
  Shape padded_shape(in.shape().size());
  padded_shape.front() = N;
  for (size_t i = 0; i < iDim.size(); i++) {
    padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
    padded_shape[i + 1] = iDim[i] + 2 * padding[i];
  }
  padded_shape.back() = C;
  array in_padded(padded_shape, conv_dtype, nullptr, {});
@@ -1160,10 +1125,9 @@ void explicit_gemm_conv_ND_cpu(

  // Pick input slice from padded
  size_t data_offset = 0;
  for (size_t i = 0; i < padding_lo.size(); i++) {
  for (size_t i = 0; i < padding.size(); i++) {
    data_offset += padding_lo[i] * in_padded.strides()[i + 1];
    data_offset += padding[i] * in_padded.strides()[i + 1];
  }

  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
@@ -1250,7 +1214,7 @@ void explicit_gemm_conv_ND_cpu(

  if (out.dtype() != float32) {
    gemm_out = array(out.shape(), float32, nullptr, {});
    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
    gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
    temps.push_back(gemm_out);
  }
@@ -1297,8 +1261,7 @@ void conv_1D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -1307,40 +1270,22 @@ void conv_1D_cpu(
  const int groups = in.shape().back() / wt.shape().back();
  if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
    return explicit_gemm_conv_1D_cpu(
        in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
        in, wt, out, padding, wt_strides, wt_dilation, stream);
  }
  if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
    return explicit_gemm_conv_ND_cpu(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        flip,
        stream);
        in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
  }

  return dispatch_slow_conv_1D(
      in,
      wt,
      out,
      padding_lo,
      padding_hi,
      wt_strides,
      wt_dilation,
      in_dilation,
      flip,
      stream);
      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
}

void conv_2D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -1350,35 +1295,18 @@ void conv_2D_cpu(
  if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
      in_dilation[1] == 1 && groups == 1) {
    return explicit_gemm_conv_ND_cpu(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        flip,
        stream);
        in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
  }

  return dispatch_slow_conv_2D(
      in,
      wt,
      out,
      padding_lo,
      padding_hi,
      wt_strides,
      wt_dilation,
      in_dilation,
      flip,
      stream);
      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
}

void conv_3D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -1389,34 +1317,17 @@ void conv_3D_cpu(
      in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
      groups == 1) {
    return explicit_gemm_conv_ND_cpu(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        flip,
        stream);
        in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
  }

  return dispatch_slow_conv_3D(
      in,
      wt,
      out,
      padding_lo,
      padding_hi,
      wt_strides,
      wt_dilation,
      in_dilation,
      flip,
      stream);
      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
}

} // namespace

void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
  out.set_data(allocator::malloc(out.nbytes()));
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& in = inputs[0];
  auto& wt = inputs[1];
@@ -1427,8 +1338,7 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
        in,
        wt,
        out,
        padding_lo_,
        padding_hi_,
        padding_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
@@ -1441,8 +1351,7 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
        in,
        wt,
        out,
        padding_lo_,
        padding_hi_,
        padding_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
@@ -1455,8 +1364,7 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
        in,
        wt,
        out,
        padding_lo_,
        padding_hi_,
        padding_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
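
Note: the change above is one API generalizing the other: symmetric padding is `pad_lo == pad_hi`, so `2 * padding` is the special case of `padding_lo + padding_hi`. A minimal sketch of the shared output-size arithmetic (names are illustrative):

#include <cstdio>

int conv_out_size(int in, int w, int stride, int dil, int pad_lo, int pad_hi) {
  return (in + pad_lo + pad_hi - dil * (w - 1) - 1) / stride + 1;
}

int main() {
  std::printf("%d\n", conv_out_size(32, 3, 1, 1, 2, 2)); // symmetric: 34
  std::printf("%d\n", conv_out_size(32, 3, 1, 1, 2, 0)); // one-sided: 32
}
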
@@ -30,7 +30,7 @@ void AllReduce::eval_cpu(
    if (in.is_donatable()) {
      out.copy_shared_buffer(in);
    } else {
      out.set_data(allocator::malloc(out.nbytes()));
      out.set_data(allocator::malloc_or_wait(out.nbytes()));
    }
    return in;
  } else {
@@ -46,15 +46,8 @@ void AllReduce::eval_cpu(
    case Sum:
      distributed::detail::all_sum(group(), in, outputs[0], stream());
      break;
    case Max:
      distributed::detail::all_max(group(), in, outputs[0], stream());
      break;
    case Min:
      distributed::detail::all_min(group(), in, outputs[0], stream());
      break;
    default:
      throw std::runtime_error(
          "Only all reduce sum, min and max are supported for now");
      throw std::runtime_error("Only all reduce sum is supported for now");
  }
}

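
Note: the extra `Max`/`Min` cases just swap the reduction operator; the all-reduce contract is unchanged. A minimal local sketch (hypothetical, no MPI; ranks simulated as rows) of what the three ops produce:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<std::vector<float>> ranks = {{1, 5}, {4, 2}, {3, 3}};
  std::vector<float> sum(2, 0.f), mx(2, -1e30f), mn(2, 1e30f);
  for (auto& r : ranks)
    for (int i = 0; i < 2; ++i) {
      sum[i] += r[i];                 // AllReduce Sum
      mx[i] = std::max(mx[i], r[i]);  // AllReduce Max
      mn[i] = std::min(mn[i], r[i]);  // AllReduce Min
    }
  std::printf("sum: %g %g  max: %g %g  min: %g %g\n",
              sum[0], sum[1], mx[0], mx[1], mn[0], mn[1]);
}
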
@@ -65,7 +58,7 @@ void AllGather::eval_cpu(
  assert(outputs.size() == 1);

  auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
  distributed::detail::all_gather(group(), in, outputs[0], stream());
  if (copied) {
    auto& enc = cpu::get_command_encoder(stream());
@@ -94,7 +87,7 @@ void Recv::eval_cpu(
  assert(inputs.size() == 0);
  assert(outputs.size() == 1);

  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
  distributed::detail::recv(group(), outputs[0], src_, stream());
}

@@ -1,174 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <typename T>
void eig_impl(
    array& a,
    array& vectors,
    array& values,
    bool compute_eigenvectors,
    Stream stream) {
  using OT = std::complex<T>;
  auto a_ptr = a.data<T>();
  auto eig_ptr = values.data<OT>();

  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_output_array(values);
  OT* vec_ptr = nullptr;
  if (compute_eigenvectors) {
    encoder.set_output_array(vectors);
    vec_ptr = vectors.data<OT>();
  }
  encoder.dispatch([a_ptr,
                    vec_ptr,
                    eig_ptr,
                    compute_eigenvectors,
                    N = vectors.shape(-1),
                    size = vectors.size()]() mutable {
    // Work query
    char jobr = 'N';
    char jobl = compute_eigenvectors ? 'V' : 'N';
    int n_vecs_r = 1;
    int n_vecs_l = compute_eigenvectors ? N : 1;
    int lwork = -1;
    int info;
    {
      T work;
      int iwork;
      geev<T>(
          &jobl,
          &jobr,
          &N,
          nullptr,
          &N,
          nullptr,
          nullptr,
          nullptr,
          &n_vecs_l,
          nullptr,
          &n_vecs_r,
          &work,
          &lwork,
          &info);
      lwork = static_cast<int>(work);
    }

    auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
    auto vec_tmp_data =
        array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
    auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
    auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
    auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
    for (size_t i = 0; i < size / (N * N); ++i) {
      geev<T>(
          &jobl,
          &jobr,
          &N,
          a_ptr,
          &N,
          eig_tmp,
          eig_tmp + N,
          vec_tmp,
          &n_vecs_l,
          nullptr,
          &n_vecs_r,
          static_cast<T*>(work_buf.buffer.raw_ptr()),
          &lwork,
          &info);
      for (int i = 0; i < N; ++i) {
        eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
      }
      if (vec_ptr) {
        for (int i = 0; i < N; ++i) {
          if (eig_ptr[i].imag() != 0) {
            // This vector and the next are a pair
            for (int j = 0; j < N; ++j) {
              vec_ptr[i * N + j] = {
                  vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
              vec_ptr[(i + 1) * N + j] = {
                  vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
            }
            i += 1;
          } else {
            for (int j = 0; j < N; ++j) {
              vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
            }
          }
        }
        vec_ptr += N * N;
      }
      a_ptr += N * N;
      eig_ptr += N;
      if (info != 0) {
        std::stringstream msg;
        msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
            << info;
        throw std::runtime_error(msg.str());
      }
    }
  });
  encoder.add_temporary(a);
}

} // namespace

void Eig::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  const auto& a = inputs[0];
  auto& values = outputs[0];

  auto vectors = compute_eigenvectors_
      ? outputs[1]
      : array(a.shape(), complex64, nullptr, {});

  auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
  copy(
      a,
      a_copy,
      a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
      stream());

  values.set_data(allocator::malloc(values.nbytes()));

  if (compute_eigenvectors_) {
    // Set the strides and flags so the eigenvectors
    // are in the columns of the output
    auto flags = vectors.flags();
    auto strides = vectors.strides();
    auto ndim = a.ndim();
    std::swap(strides[ndim - 1], strides[ndim - 2]);

    if (a.size() > 1) {
      flags.row_contiguous = false;
      if (ndim > 2) {
        flags.col_contiguous = false;
      } else {
        flags.col_contiguous = true;
      }
    }
    vectors.set_data(
        allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
  }
  switch (a.dtype()) {
    case float32:
      eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
      break;
    default:
      throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
  }
}

} // namespace mlx::core

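
Note: the unpacking loop in the deleted file decodes the real-geev convention: eigenvalues come back as separate real and imaginary arrays, and a complex-conjugate pair of eigenvectors is packed into two consecutive real columns. A minimal sketch of the eigenvalue half (assuming that LAPACK convention; not mlx code):

#include <complex>
#include <vector>

std::vector<std::complex<float>> unpack_eigs(
    const float* wr, const float* wi, int N) {
  std::vector<std::complex<float>> w(N);
  for (int i = 0; i < N; ++i)
    w[i] = {wr[i], wi[i]}; // wi[i] != 0 marks one half of a conjugate pair
  return w;
}
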
@@ -12,133 +12,6 @@ namespace mlx::core {

namespace {

template <typename T, class Enable = void>
struct EighWork {};

template <typename T>
struct EighWork<
    T,
    typename std::enable_if<std::is_floating_point<T>::value>::type> {
  using R = T;

  char jobz;
  char uplo;
  int N;
  int lwork;
  int liwork;
  int info;
  std::vector<array::Data> buffers;

  EighWork(char jobz_, char uplo_, int N_)
      : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
    T work;
    int iwork;
    syevd<T>(
        &jobz,
        &uplo,
        &N,
        nullptr,
        &N,
        nullptr,
        &work,
        &lwork,
        &iwork,
        &liwork,
        &info);
    lwork = static_cast<int>(work);
    liwork = iwork;
    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
    buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
  }

  void run(T* vectors, T* values) {
    syevd<T>(
        &jobz,
        &uplo,
        &N,
        vectors,
        &N,
        values,
        static_cast<T*>(buffers[0].buffer.raw_ptr()),
        &lwork,
        static_cast<int*>(buffers[1].buffer.raw_ptr()),
        &liwork,
        &info);
  }
};

template <>
struct EighWork<std::complex<float>> {
  using T = std::complex<float>;
  using R = float;

  char jobz;
  char uplo;
  int N;
  int lwork;
  int lrwork;
  int liwork;
  int info;
  std::vector<array::Data> buffers;

  EighWork(char jobz_, char uplo_, int N_)
      : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
    T work;
    R rwork;
    int iwork;
    heevd<T>(
        &jobz,
        &uplo,
        &N,
        nullptr,
        &N,
        nullptr,
        &work,
        &lwork,
        &rwork,
        &lrwork,
        &iwork,
        &liwork,
        &info);
    lwork = static_cast<int>(work.real());
    lrwork = static_cast<int>(rwork);
    liwork = iwork;
    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
    buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
    buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
  }

  void run(T* vectors, R* values) {
    heevd<T>(
        &jobz,
        &uplo,
        &N,
        vectors,
        &N,
        values,
        static_cast<T*>(buffers[0].buffer.raw_ptr()),
        &lwork,
        static_cast<R*>(buffers[1].buffer.raw_ptr()),
        &lrwork,
        static_cast<int*>(buffers[2].buffer.raw_ptr()),
        &liwork,
        &info);
    if (jobz == 'V') {
      // We have pre-transposed the vectors but we also must conjugate them
      // when they are complex.
      //
      // We could vectorize this but it is so fast in comparison to heevd that
      // it doesn't really matter.
      for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++) {
          *vectors = std::conj(*vectors);
          vectors++;
        }
      }
    }
  }
};

template <typename T>
void eigh_impl(
    array& vectors,
@@ -146,10 +19,8 @@ void eigh_impl(
    const std::string& uplo,
    bool compute_eigenvectors,
    Stream stream) {
  using R = typename EighWork<T>::R;

  auto vec_ptr = vectors.data<T>();
  auto eig_ptr = values.data<R>();
  auto eig_ptr = values.data<T>();
  char jobz = compute_eigenvectors ? 'V' : 'N';

  auto& encoder = cpu::get_command_encoder(stream);
@@ -162,17 +33,50 @@ void eigh_impl(
                    N = vectors.shape(-1),
                    size = vectors.size()]() mutable {
    // Work query
    EighWork<T> work(jobz, uplo, N);
    int lwork = -1;
    int liwork = -1;
    int info;
    {
      T work;
      int iwork;
      syevd<T>(
          &jobz,
          &uplo,
          &N,
          nullptr,
          &N,
          nullptr,
          &work,
          &lwork,
          &iwork,
          &liwork,
          &info);
      lwork = static_cast<int>(work);
      liwork = iwork;
    }

    // Work loop
    auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
    auto iwork_buf =
        array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
    for (size_t i = 0; i < size / (N * N); ++i) {
      work.run(vec_ptr, eig_ptr);
      syevd<T>(
          &jobz,
          &uplo,
          &N,
          vec_ptr,
          &N,
          eig_ptr,
          static_cast<T*>(work_buf.buffer.raw_ptr()),
          &lwork,
          static_cast<int*>(iwork_buf.buffer.raw_ptr()),
          &liwork,
          &info);
      vec_ptr += N * N;
      eig_ptr += N;
      if (work.info != 0) {
      if (info != 0) {
        std::stringstream msg;
        msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
            << work.info;
            << info;
        throw std::runtime_error(msg.str());
      }
    }
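
Note: both versions lean on the LAPACK workspace-query idiom: calling the routine with lwork = -1 computes nothing and writes the optimal workspace size into the first element of `work`. A minimal standalone sketch against the standard ssyevd symbol (assumes a LAPACK library is linked; the mlx `syevd<T>` wrapper dispatches to it by type):

#include <cstdio>

extern "C" void ssyevd_(
    const char* jobz, const char* uplo, const int* n, float* a,
    const int* lda, float* w, float* work, const int* lwork,
    int* iwork, const int* liwork, int* info);

int query_lwork(char jobz, char uplo, int N) {
  float work;
  int iwork, lwork = -1, liwork = -1, info;
  ssyevd_(&jobz, &uplo, &N, nullptr, &N, nullptr,
          &work, &lwork, &iwork, &liwork, &info);
  return static_cast<int>(work); // optimal float workspace length
}
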
@@ -194,7 +98,7 @@ void Eigh::eval_cpu(
      ? outputs[1]
      : array(a.shape(), a.dtype(), nullptr, {});

  values.set_data(allocator::malloc(values.nbytes()));
  values.set_data(allocator::malloc_or_wait(values.nbytes()));

  copy(
      a,
@@ -228,10 +132,6 @@ void Eigh::eval_cpu(
      eigh_impl<double>(
          vectors, values, uplo_, compute_eigenvectors_, stream());
      break;
    case complex64:
      eigh_impl<std::complex<float>>(
          vectors, values, uplo_, compute_eigenvectors_, stream());
      break;
    default:
      throw std::runtime_error(
          "[Eigh::eval_cpu] only supports float32 or float64.");

@@ -9,9 +9,6 @@

namespace mlx::core::cpu {

// Number of dispatches per scheduler task
constexpr int DISPATCHES_PER_TASK = 10;

struct CommandEncoder {
  CommandEncoder(Stream stream) : stream_(stream) {}

@@ -42,24 +39,13 @@ struct CommandEncoder {

  template <class F, class... Args>
  void dispatch(F&& f, Args&&... args) {
    num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
    auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
    if (num_ops_ == 0) {
      scheduler::notify_new_task(stream_);
      auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
        task();
        scheduler::notify_task_completion(s);
      };
      scheduler::enqueue(stream_, std::move(task_wrap));
    } else {
      scheduler::enqueue(stream_, std::move(task));
    }
    scheduler::enqueue(stream_, std::move(task));
  }

 private:
  Stream stream_;
  std::vector<array> temporaries_;
  int num_ops_{0};
};

CommandEncoder& get_command_encoder(Stream stream);

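
Note: the batched variant only notifies the scheduler once every DISPATCHES_PER_TASK dispatches, so synchronization waits on one task per batch rather than one per op. A minimal model of that bookkeeping (a sketch, not the mlx scheduler):

#include <functional>
#include <queue>

struct MiniEncoder {
  static constexpr int kPerTask = 10;
  int num_ops = 0;
  int open_tasks = 0; // what an eval/synchronize would wait on
  std::queue<std::function<void()>> q;

  void dispatch(std::function<void()> f) {
    num_ops = (num_ops + 1) % kPerTask;
    if (num_ops == 0) {
      ++open_tasks; // "notify_new_task": this one is tracked
      q.push([this, f = std::move(f)] {
        f();
        --open_tasks; // "notify_task_completion"
      });
    } else {
      q.push(std::move(f)); // cheap path: no synchronization bookkeeping
    }
  }
};
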
@@ -33,8 +33,12 @@ void eval(array& arr) {
    buffers.erase(it);
  }
  auto& encoder = cpu::get_command_encoder(s);
  encoder.dispatch([buffers = std::move(buffers),
                    temps = std::move(encoder.temporaries())]() {});
  scheduler::notify_new_task(s);
  encoder.dispatch([s,
                    buffers = std::move(buffers),
                    temps = std::move(encoder.temporaries())]() {
    scheduler::notify_task_completion(s);
  });
}

} // namespace mlx::core::cpu

@@ -22,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
    s *= out.itemsize();
  }

  out.set_data(allocator::malloc(out.nbytes()));
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  std::vector<size_t> shape;
  if (out.dtype() == float32) {

mlx/backend/cpu/gemms/no_bf16.cpp (new file, 27 lines)
@@ -0,0 +1,27 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cpu/gemm.h"

namespace mlx::core {

template <>
void matmul<bfloat16_t>(
    const bfloat16_t*,
    const bfloat16_t*,
    bfloat16_t*,
    bool,
    bool,
    size_t,
    size_t,
    size_t,
    float,
    float,
    size_t,
    const Shape&,
    const Strides&,
    const Shape&,
    const Strides&) {
  throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}

} // namespace mlx::core

mlx/backend/cpu/gemms/no_fp16.cpp (new file, 27 lines)
@@ -0,0 +1,27 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cpu/gemm.h"

namespace mlx::core {

template <>
void matmul<float16_t>(
    const float16_t*,
    const float16_t*,
    float16_t*,
    bool,
    bool,
    size_t,
    size_t,
    size_t,
    float,
    float,
    size_t,
    const Shape&,
    const Strides&,
    const Shape&,
    const Strides&) {
  throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}

} // namespace mlx::core

@@ -1,45 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"

namespace mlx::core {

template <>
void matmul<bfloat16_t>(
    const bfloat16_t* a,
    const bfloat16_t* b,
    bfloat16_t* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    float alpha,
    float beta,
    size_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides) {
  auto ndim = a_shape.size();
  size_t M = a_shape[ndim - 2];
  size_t N = b_shape[ndim - 1];
  size_t K = a_shape[ndim - 1];
  for (int i = 0; i < batch_size; ++i) {
    simd_gemm<bfloat16_t, float>(
        a + elem_to_loc(M * K * i, a_shape, a_strides),
        b + elem_to_loc(K * N * i, b_shape, b_strides),
        out + M * N * i,
        a_transposed,
        b_transposed,
        M,
        N,
        K,
        alpha,
        beta);
  }
}

} // namespace mlx::core

@@ -1,45 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"

namespace mlx::core {

template <>
void matmul<float16_t>(
    const float16_t* a,
    const float16_t* b,
    float16_t* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    float alpha,
    float beta,
    size_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides) {
  auto ndim = a_shape.size();
  size_t M = a_shape[ndim - 2];
  size_t N = b_shape[ndim - 1];
  size_t K = a_shape[ndim - 1];
  for (int i = 0; i < batch_size; ++i) {
    simd_gemm<float16_t, float>(
        a + elem_to_loc(M * K * i, a_shape, a_strides),
        b + elem_to_loc(K * N * i, b_shape, b_strides),
        out + M * N * i,
        a_transposed,
        b_transposed,
        M,
        N,
        K,
        alpha,
        beta);
  }
}

} // namespace mlx::core

@@ -1,139 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once

#include "mlx/backend/cpu/simd/simd.h"

namespace mlx::core {

inline int ceildiv(int a, int b) {
  return (a + b - 1) / b;
}

template <int block_size, typename T, typename AccT>
void load_block(
    const T* in,
    AccT* out,
    int M,
    int N,
    int i,
    int j,
    bool transpose) {
  if (transpose) {
    for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
      for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
        out[jj * block_size + ii] =
            in[(i * block_size + ii) * N + j * block_size + jj];
      }
    }
  } else {
    for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
      for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
        out[ii * block_size + jj] =
            in[(i * block_size + ii) * N + j * block_size + jj];
      }
    }
  }
}

template <typename T, typename AccT>
void simd_gemm(
    const T* a,
    const T* b,
    T* c,
    bool a_trans,
    bool b_trans,
    int M,
    int N,
    int K,
    float alpha,
    float beta) {
  constexpr int block_size = 16;
  constexpr int simd_size = simd::max_size<AccT>;
  static_assert(
      (block_size % simd_size) == 0,
      "Block size must be divisible by SIMD size");

  int last_k_block_size = K - block_size * (K / block_size);
  int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
  for (int i = 0; i < ceildiv(M, block_size); i++) {
    for (int j = 0; j < ceildiv(N, block_size); j++) {
      AccT c_block[block_size * block_size] = {0.0};
      AccT a_block[block_size * block_size];
      AccT b_block[block_size * block_size];

      int k = 0;
      for (; k < K / block_size; k++) {
        // Load a and b blocks
        if (a_trans) {
          load_block<block_size>(a, a_block, K, M, k, i, true);
        } else {
          load_block<block_size>(a, a_block, M, K, i, k, false);
        }
        if (b_trans) {
          load_block<block_size>(b, b_block, N, K, j, k, false);
        } else {
          load_block<block_size>(b, b_block, K, N, k, j, true);
        }

        // Multiply and accumulate
        for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
          for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
            for (int kk = 0; kk < block_size; kk += simd_size) {
              auto av =
                  simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
              auto bv =
                  simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
              c_block[ii * block_size + jj] += simd::sum(av * bv);
            }
          }
        }
      }
      if (last_k_block_size) {
        // Load a and b blocks
        if (a_trans) {
          load_block<block_size>(a, a_block, K, M, k, i, true);
        } else {
          load_block<block_size>(a, a_block, M, K, i, k, false);
        }
        if (b_trans) {
          load_block<block_size>(b, b_block, N, K, j, k, false);
        } else {
          load_block<block_size>(b, b_block, K, N, k, j, true);
        }

        // Multiply and accumulate
        for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
          for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
            int kk = 0;
            for (; kk < last_k_simd_block; kk += simd_size) {
              auto av =
                  simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
              auto bv =
                  simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
              c_block[ii * block_size + jj] += simd::sum(av * bv);
            }
            for (; kk < last_k_block_size; ++kk) {
              c_block[ii * block_size + jj] +=
                  a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
            }
          }
        }
      }

      // Store
      for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
        for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
          auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
          if (beta != 0) {
            c[c_idx] = static_cast<T>(
                alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
          } else {
            c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
          }
        }
      }
    }
  }
}

} // namespace mlx::core

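
Note: for reference, a minimal scalar GEMM (a sketch, not part of mlx) computing the same C = alpha * A @ B + beta * C contract as the deleted simd_gemm, without blocking or SIMD. It is handy as a correctness oracle when testing a blocked kernel:

void naive_gemm(
    const float* a, const float* b, float* c,
    int M, int N, int K, float alpha, float beta) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k)
        acc += a[i * K + k] * b[k * N + j];
      // Skip the beta term when beta == 0 so uninitialized C is never read.
      c[i * N + j] = alpha * acc + (beta != 0 ? beta * c[i * N + j] : 0.f);
    }
  }
}
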
@@ -197,7 +197,7 @@ void dispatch_gather(
}

void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
  out.set_data(allocator::malloc(out.nbytes()));
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& src = inputs[0];
  std::vector<array> inds;
@@ -257,11 +257,15 @@ void gather_axis(
    const array& ind,
    array& out,
    const int axis) {
  auto shape = remove_index(ind.shape(), axis);
  ContiguousIterator ind_it(
      shape, remove_index(ind.strides(), axis), src.ndim() - 1);
  ContiguousIterator src_it(
      shape, remove_index(src.strides(), axis), src.ndim() - 1);
  auto strides = ind.strides();
  strides.erase(strides.begin() + axis);
  auto shape = ind.shape();
  shape.erase(shape.begin() + axis);
  ContiguousIterator ind_it(shape, strides, src.ndim() - 1);

  strides = src.strides();
  strides.erase(strides.begin() + axis);
  ContiguousIterator src_it(shape, strides, src.ndim() - 1);

  auto ind_ptr = ind.data<IdxT>();
  auto src_ptr = src.data<T>();
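
Note: the left-hand version assumes a small helper that the right-hand version spells out with erase(). A sketch of what such a remove_index would look like (the helper name comes from the diff; this body is an assumption):

#include <vector>

template <typename Vec>
Vec remove_index(Vec v, int axis) {
  v.erase(v.begin() + axis); // drop the indexed axis from a shape or strides
  return v;
}
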
@@ -350,7 +354,7 @@ void dispatch_gather_axis(
}

void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
  out.set_data(allocator::malloc(out.nbytes()));
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& src = inputs[0];
  auto& inds = inputs[1];
@@ -581,11 +585,15 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {

template <typename T, typename IdxT, typename OpT>
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
  auto shape = remove_index(idx.shape(), axis);
  ContiguousIterator idx_it(
      shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
  ContiguousIterator upd_it(
      shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
  auto strides = idx.strides();
  strides.erase(strides.begin() + axis);
  auto shape = idx.shape();
  shape.erase(shape.begin() + axis);
  ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);

  strides = upd.strides();
  strides.erase(strides.begin() + axis);
  ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);

  auto idx_ptr = idx.data<IdxT>();
  auto upd_ptr = upd.data<T>();

@@ -11,7 +11,7 @@ namespace mlx::core {
template <typename T>
void general_inv(T* inv, int N) {
  int info;
  auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
  auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
  // Compute LU factorization.
  getrf<T>(
      /* m = */ &N,
@@ -49,7 +49,7 @@ void general_inv(T* inv, int N) {
  }

  const int lwork = workspace_size;
  auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};

  // Compute inverse.
  getri<T>(

@@ -2,14 +2,14 @@

#pragma once

// Required for Visual Studio.
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
#ifdef _MSC_VER
#include <complex>
#define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_float std::complex<float>
#define lapack_complex_double std::complex<double>
#define lapack_complex_float_real(z) ((z).real())
#endif
#define lapack_complex_float_imag(z) ((z).imag())
#define lapack_complex_double_real(z) ((z).real())
#define lapack_complex_double_imag(z) ((z).imag())

#ifdef MLX_USE_ACCELERATE
#include <Accelerate/Accelerate.h>
@@ -32,7 +32,7 @@

#endif

#define INSTANTIATE_LAPACK_REAL(FUNC) \
#define INSTANTIATE_LAPACK_TYPES(FUNC) \
  template <typename T, typename... Args> \
  void FUNC(Args... args) { \
    if constexpr (std::is_same_v<T, float>) { \
@@ -42,24 +42,11 @@
    } \
  }

INSTANTIATE_LAPACK_REAL(geqrf)
INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesvdx)
INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri)

#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
  template <typename T, typename... Args> \
  void FUNC(Args... args) { \
    if constexpr (std::is_same_v<T, std::complex<float>>) { \
      MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
    } else if constexpr (std::is_same_v<T, std::complex<double>>) { \
      MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
    } \
  }

INSTANTIATE_LAPACK_COMPLEX(heevd)

INSTANTIATE_LAPACK_TYPES(geqrf)
INSTANTIATE_LAPACK_TYPES(orgqr)
INSTANTIATE_LAPACK_TYPES(syevd)
INSTANTIATE_LAPACK_TYPES(potrf)
INSTANTIATE_LAPACK_TYPES(gesvdx)
INSTANTIATE_LAPACK_TYPES(getrf)
INSTANTIATE_LAPACK_TYPES(getri)
INSTANTIATE_LAPACK_TYPES(trtri)

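
Note: the INSTANTIATE macros pick the LAPACK prefix ("s"/"d" for real, "c"/"z" for complex) at compile time via token pasting. A minimal standalone sketch of the same trick against BLAS nrm2 (assumes the reference Fortran calling convention where snrm2_ returns float; some BLAS builds differ):

#include <type_traits>

extern "C" float snrm2_(const int* n, const float* x, const int* incx);
extern "C" double dnrm2_(const int* n, const double* x, const int* incx);

#define INSTANTIATE_NRM2(FUNC)                 \
  template <typename T>                        \
  T FUNC(int n, const T* x, int incx) {        \
    if constexpr (std::is_same_v<T, float>) {  \
      return s##FUNC##_(&n, x, &incx);         \
    } else {                                   \
      return d##FUNC##_(&n, x, &incx);         \
    }                                          \
  }

INSTANTIATE_NRM2(nrm2) // nrm2<float> calls snrm2_, nrm2<double> calls dnrm2_
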
@@ -1,140 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <cmath>

#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"

namespace mlx::core {

namespace {

using namespace mlx::core::simd;

template <typename T, typename AccT>
void logsumexp(const array& in, array& out, Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  const T* in_ptr = in.data<T>();
  T* out_ptr = out.data<T>();

  int M = in.shape().back();
  int L = in.data_size() / M;

  encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
    constexpr int N = std::min(max_size<AccT>, max_size<T>);

    const T* current_in_ptr;

    for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
      // Find the maximum
      current_in_ptr = in_ptr;
      Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
      size_t s = M;
      while (s >= N) {
        Simd<AccT, N> vals = load<T, N>(current_in_ptr);
        vmaximum = maximum(vals, vmaximum);
        current_in_ptr += N;
        s -= N;
      }

      AccT maximum = max(vmaximum);
      while (s-- > 0) {
        maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
        current_in_ptr++;
      }

      // Compute the normalizer and the exponentials
      Simd<AccT, N> vnormalizer(0.0);
      current_in_ptr = in_ptr;
      s = M;
      while (s >= N) {
        Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
        vexp = exp(vexp - maximum);
        vnormalizer = vnormalizer + vexp;
        current_in_ptr += N;
        s -= N;
      }
      AccT normalizer = sum(vnormalizer);
      while (s-- > 0) {
        AccT _exp = std::exp(*current_in_ptr - maximum);
        normalizer += _exp;
        current_in_ptr++;
      }
      // Normalize
      *out_ptr = std::isinf(maximum)
          ? static_cast<T>(maximum)
          : static_cast<T>(std::log(normalizer) + maximum);
    }
  });
}

} // namespace

void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  // Make sure that the last dimension is contiguous
  auto s = stream();
  auto& encoder = cpu::get_command_encoder(s);
  auto ensure_contiguous = [&s, &encoder](const array& x) {
    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
      return x;
    } else {
      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
      copy(x, x_copy, CopyType::General, s);
      encoder.add_temporary(x_copy);
      return x_copy;
    }
  };

  auto in = ensure_contiguous(inputs[0]);
  if (in.flags().row_contiguous) {
    out.set_data(allocator::malloc(out.nbytes()));
  } else {
    auto n = in.shape(-1);
    auto flags = in.flags();
    auto strides = in.strides();
    for (auto& s : strides) {
      s /= n;
    }
    bool col_contig = strides[0] == 1;
    for (int i = 1; col_contig && i < strides.size(); ++i) {
      col_contig &=
          (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
    }
    flags.col_contiguous = col_contig;
    out.set_data(
        allocator::malloc(in.nbytes() / n),
        in.data_size() / n,
        std::move(strides),
        flags);
  }

  switch (in.dtype()) {
    case float32:
      logsumexp<float, float>(in, out, stream());
      break;
    case float16:
      logsumexp<float16_t, float>(in, out, stream());
      break;
    case bfloat16:
      logsumexp<bfloat16_t, float>(in, out, stream());
      break;
    case float64:
      logsumexp<double, double>(in, out, stream());
      break;
    default:
      throw std::runtime_error(
          "[logsumexp] only supports floating point types");
      break;
  }
}

} // namespace mlx::core

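
Note: the deleted kernel is a SIMD version of the standard max-shifted logsumexp: subtracting the row maximum before exponentiating avoids overflow, and the maximum is added back at the end. A minimal scalar sketch of the same computation:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

float logsumexp_row(const std::vector<float>& x) {
  float m = *std::max_element(x.begin(), x.end());
  if (std::isinf(m)) return m; // all -inf, or a +inf dominates
  float norm = 0.f;
  for (float v : x) norm += std::exp(v - m); // every term is <= 1, no overflow
  return std::log(norm) + m;
}

int main() {
  std::printf("%f\n", logsumexp_row({1000.f, 1000.f})); // ~1000.693
}
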
@@ -30,7 +30,8 @@ void luf_impl(
  auto strides = lu.strides();
  strides[ndim - 1] = M;
  strides[ndim - 2] = 1;
  lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
  lu.set_data(
      allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
  copy_inplace(
      a,
      lu,
@@ -43,8 +44,8 @@ void luf_impl(
      stream);

  auto a_ptr = lu.data<T>();
  pivots.set_data(allocator::malloc(pivots.nbytes()));
  pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
  row_indices.set_data(allocator::malloc(row_indices.nbytes()));
  row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
  auto pivots_ptr = pivots.data<uint32_t>();
  auto row_indices_ptr = row_indices.data<uint32_t>();
  size_t num_matrices = a.size() / (M * N);

@@ -59,7 +59,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
    throw std::runtime_error(
        "[BlockMaskedMM::eval] Currently only supports float32.");
  }
  out.set_data(allocator::malloc(out.nbytes()));
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& a_pre = inputs[0];
  auto& b_pre = inputs[1];
@@ -318,7 +318,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
    throw std::runtime_error(
        "[GatherMM::eval] Currently only supports float32.");
  }
  out.set_data(allocator::malloc(out.nbytes()));
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& a_pre = inputs[0];
  auto& b_pre = inputs[1];

@@ -115,7 +115,7 @@ void matmul_general(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
if (inputs[0].shape(-1) == 0) {
|
if (inputs[0].shape(-1) == 0) {
|
||||||
auto& encoder = cpu::get_command_encoder(stream());
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
@@ -132,10 +132,6 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[AddMM::eval_cpu] Currently only supports float32.");
|
"[AddMM::eval_cpu] Currently only supports float32.");
|
||||||
}
|
}
|
||||||
if (out.size() == 0) {
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fill output with C
|
// Fill output with C
|
||||||
auto& c = inputs[2];
|
auto& c = inputs[2];
|
||||||
@@ -143,9 +139,7 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
? CopyType::Scalar
|
? CopyType::Scalar
|
||||||
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||||
copy(c, out, ctype, stream());
|
copy(c, out, ctype, stream());
|
||||||
if (inputs[0].shape(-1) == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
|
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ namespace mlx::core {
 void reshape(const array& in, array& out) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
   if (copy_necessary) {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
     copy_inplace(in, out, CopyType::General, out.primitive().stream());
   } else {
     shared_buffer_reshape(in, out_strides, out);
@@ -39,7 +39,7 @@ static std::pair<array, bool> compute_dynamic_offset(
   if (donate) {
     offset.copy_shared_buffer(indices);
   } else {
-    offset.set_data(allocator::malloc(offset.itemsize()));
+    offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
   }

   auto& encoder = cpu::get_command_encoder(stream);
@@ -124,7 +124,7 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {

 void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 0);
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
   switch (out.dtype()) {
     case bool_:
       throw std::runtime_error("Bool type unsupported for arange.");
@@ -186,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
   std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());

-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));

   auto strides = out.strides();
   auto flags = out.flags();
@@ -205,10 +205,8 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
 void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  constexpr size_t extra_bytes = 16384;
-  if (in.buffer_size() <= out.nbytes() + extra_bytes &&
-      (in.flags().row_contiguous ||
-       (allow_col_major_ && in.flags().col_contiguous))) {
+  if (in.flags().row_contiguous ||
+      (allow_col_major_ && in.flags().col_contiguous)) {
     out.copy_shared_buffer(in);
   } else {
     copy(in, out, CopyType::General, stream());
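One plausible reading of the extra guard on the left-hand side of the Contiguous hunk: only alias the input buffer when its backing allocation is not much larger than the output, so a small contiguous view cannot keep a big allocation alive. A hedged sketch of just that predicate (can_donate and the 16 KiB slack are read off the diff, not a documented API):

#include <cstddef>

// Sketch: reuse the input buffer only when it is at most ~16 KiB larger
// than the bytes the output actually needs; otherwise copy instead.
bool can_donate(size_t in_buffer_size, size_t out_nbytes) {
  constexpr size_t extra_bytes = 16384;
  return in_buffer_size <= out_nbytes + extra_bytes;
}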
@@ -278,7 +276,7 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {

   size_t elems_per_key = out.size() / num_keys;
   size_t bytes_per_key = out.itemsize() * elems_per_key;
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));

   auto kptr = inputs[0].data<uint32_t>();
   auto cptr = out.data<char>();
@@ -337,7 +335,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
     return;
   }
   auto& in = inputs[0];
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
   auto [in_offset, donated] =
       compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
   copy_inplace(
@@ -452,7 +450,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
   } else {
     auto tmp = array(
         in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
-    tmp.set_data(allocator::malloc(tmp.nbytes()));
+    tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
     if (in.dtype() == bool_) {
       auto in_tmp = array(in.shape(), uint8, nullptr, {});
       in_tmp.copy_shared_buffer(in);
@@ -25,11 +25,12 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   auto strides = in.strides();
   strides[in.ndim() - 2] = 1;
   strides[in.ndim() - 1] = M;
-  in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
+  in.set_data(
+      allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
   copy_inplace(a, in, CopyType::GeneralGeneral, stream);
   auto& encoder = cpu::get_command_encoder(stream);
-  q.set_data(allocator::malloc(q.nbytes()));
-  r.set_data(allocator::malloc(r.nbytes()));
+  q.set_data(allocator::malloc_or_wait(q.nbytes()));
+  r.set_data(allocator::malloc_or_wait(r.nbytes()));

   auto in_ptr = in.data<T>();
   auto r_ptr = r.data<T>();
@@ -40,7 +41,8 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   encoder.set_output_array(r);
   encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
     int num_reflectors = std::min(M, N);
-    auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);
+    auto tau =
+        allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);

     T optimal_work;
     int lwork = -1;
@@ -51,7 +53,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {

     // Update workspace size
     lwork = optimal_work;
-    auto work = allocator::malloc(sizeof(T) * lwork);
+    auto work = allocator::malloc_or_wait(sizeof(T) * lwork);

     // Loop over matrices
     for (int i = 0; i < num_matrices; ++i) {
@@ -94,7 +96,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
         &lwork,
         &info);
     lwork = optimal_work;
-    work = allocator::malloc(sizeof(T) * lwork);
+    work = allocator::malloc_or_wait(sizeof(T) * lwork);

     // Loop over matrices
     for (int i = 0; i < num_matrices; ++i) {
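The lwork = -1 calls in qrf_impl are the standard LAPACK workspace query: the first call only reports the optimal scratch size, the second call does the factorization. A self-contained sketch of the idiom (qr_factor is a hypothetical wrapper; it assumes a Fortran LAPACK sgeqrf_ symbol is available at link time):

#include <vector>

extern "C" void sgeqrf_(
    const int* m, const int* n, float* a, const int* lda,
    float* tau, float* work, const int* lwork, int* info);

void qr_factor(float* a, int M, int N, int lda, float* tau) {
  int info = 0;
  int lwork = -1; // -1 asks the routine for the optimal workspace size
  float optimal_work = 0.0f;
  sgeqrf_(&M, &N, a, &lda, tau, &optimal_work, &lwork, &info);
  lwork = static_cast<int>(optimal_work); // size reported via work[0]
  std::vector<float> work(lwork);         // allocate once, reuse per matrix
  sgeqrf_(&M, &N, a, &lda, tau, work.data(), &lwork, &info);
}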
@@ -13,18 +13,9 @@ namespace mlx::core {

 namespace {

-inline constexpr short get_pack_factor(int bits, int wsize = 8) {
-  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
-}
-
-inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
-  auto power_of_2_bits = (bits & (bits - 1)) == 0;
-  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
-}
-
 template <typename T, int bits>
 void extract_bits(const uint8_t* w_in, T* w_out) {
-  static_assert(bits == 3 || bits == 5 || bits == 6);
+  assert(bits == 3 || bits == 6);
   if (bits == 3) {
     w_out[0] = static_cast<T>(w_in[0] & 0x7);
     w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
@@ -34,16 +25,6 @@ void extract_bits(const uint8_t* w_in, T* w_out) {
     w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
     w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
     w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
-  } else if (bits == 5) {
-    w_out[0] = static_cast<T>(w_in[0] & 0x1f);
-    w_out[1] = static_cast<T>(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3));
-    w_out[2] = static_cast<T>((w_in[1] & 0x7c) >> 2);
-    w_out[3] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1));
-    w_out[4] = static_cast<T>(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4));
-    w_out[5] = static_cast<T>((w_in[3] & 0x3e) >> 1);
-    w_out[6] = static_cast<T>(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2));
-    w_out[7] = static_cast<T>((w_in[4] & 0xf8) >> 3);
-
   } else if (bits == 6) {
     w_out[0] = static_cast<T>(w_in[0] & 0x3f);
     w_out[1] =
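As a quick sanity check on the helpers that appear on the left-hand side above, the pack factors and bytes per pack work out as follows; these static_asserts are a sketch that would sit next to the helper definitions (they are not part of the diff):

// 3- and 5-bit codes pack 8 per 3- or 5-byte group, 6-bit codes pack 4 per
// 3-byte group, and power-of-two widths pack wsize/bits per machine word.
static_assert(get_pack_factor(3) == 8 && get_bytes_per_pack(3) == 3);
static_assert(get_pack_factor(5) == 8 && get_bytes_per_pack(5) == 5);
static_assert(get_pack_factor(6) == 4 && get_bytes_per_pack(6) == 3);
static_assert(get_pack_factor(4) == 2 && get_bytes_per_pack(4) == 1);
static_assert(get_pack_factor(8) == 1 && get_bytes_per_pack(8) == 1);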
@@ -65,8 +46,8 @@ void _qmm(
     int N,
     int K) {
   constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = get_pack_factor(bits, 8);
-  constexpr int bytes_per_pack = get_bytes_per_pack(bits);
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
   constexpr int packs_in_group = group_size / pack_factor;

   for (int m = 0; m < M; m++) {
@@ -84,7 +65,7 @@ void _qmm(
       T scale = *scales_local++;
       T bias = *biases_local++;
       for (int ng = 0; ng < packs_in_group; ng++) {
-        if constexpr (bits == 3 || bits == 5 || bits == 6) {
+        if (bits == 3 || bits == 6) {
           T wl[pack_factor];
           extract_bits<T, bits>(w_local, wl);
 #pragma clang loop unroll(full)
@@ -123,9 +104,8 @@ void _qmm_t(
     int N,
     int K) {
   constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = get_pack_factor(bits, 8);
-  constexpr int bytes_per_pack = get_bytes_per_pack(bits);
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
   constexpr int packs_in_group = group_size / pack_factor;

   for (int m = 0; m < M; m++) {
@@ -141,7 +121,7 @@ void _qmm_t(
       T bias = *biases_local++;

       for (int kw = 0; kw < packs_in_group; kw++) {
-        if constexpr (bits == 3 || bits == 5 || bits == 6) {
+        if (bits == 3 || bits == 6) {
           T wl[pack_factor];
           extract_bits<T, bits>(w_local, wl);
 #pragma clang loop unroll(full)
@@ -324,10 +304,6 @@ void _qmm_dispatch_typed(
       _qmm_dispatch_group<T, 4>(
           result, x, w, scales, biases, M, N, K, group_size, transposed_w);
       break;
-    case 5:
-      _qmm_dispatch_group<T, 5>(
-          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
-      break;
     case 6:
       _qmm_dispatch_group<T, 6>(
           result, x, w, scales, biases, M, N, K, group_size, transposed_w);
@@ -539,7 +515,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto scales = ensure_row_contiguous(scales_pre);
   auto biases = ensure_row_contiguous(biases_pre);

-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));

   auto& encoder = cpu::get_command_encoder(stream());
   encoder.add_temporaries(std::move(temps));
@@ -589,7 +565,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto scales = ensure_row_contiguous_last_dims(scales_pre);
   auto biases = ensure_row_contiguous_last_dims(biases_pre);

-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));

   auto& encoder = cpu::get_command_encoder(stream());
   encoder.add_temporaries(std::move(temps));
@@ -637,8 +613,9 @@ void quantize(
   float eps = 1e-7;

   bool power_of_2_bits = is_power_of_2(bits);
-  int el_per_int = get_pack_factor(bits, 32);
-  int bytes_per_pack = get_bytes_per_pack(bits);
+  int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
+  int bytes_per_pack = power_of_2_bits ? 1 : 3;
   int int_per_group = group_size * bytes_per_pack / el_per_int;
   size_t n_groups = w_size / group_size;

@@ -663,21 +640,15 @@ void quantize(
     }
     size_t out_idx = i * int_per_group;
     for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
-      uint64_t out_el = 0;
+      uint32_t out_el = 0;
       for (int k = 0; k < el_per_int; ++k) {
         float w_el = w[w_idx + j * el_per_int + k];
         w_el = std::rint((w_el - bias) / scale);
         w_el = std::min(std::max(w_el, 0.0f), n_bins);
-        out_el |= static_cast<uint64_t>(w_el) << (k * bits);
+        out_el |= static_cast<uint32_t>(w_el) << (k * bits);
       }
       if (power_of_2_bits) {
         out[out_idx + j] = out_el;
-      } else if (bits == 5) {
-        out[out_idx + bytes_per_pack * j] = out_el & 0xff;
-        out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
-        out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
-        out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24;
-        out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32;
       } else {
         out[out_idx + bytes_per_pack * j] = out_el & 0xff;
         out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
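The uint64_t accumulator on the left-hand side is what makes 5-bit groups possible: eight 5-bit codes span 40 bits, which no longer fits in the old uint32_t. A standalone sketch of that packing (pack5 is a hypothetical name):

#include <cstdint>

// Pack eight 5-bit codes into the low 40 bits of a 64-bit accumulator;
// the caller then writes the result out as 5 bytes, as in the diff above.
uint64_t pack5(const uint8_t (&q)[8]) {
  uint64_t out_el = 0;
  for (int k = 0; k < 8; ++k) {
    out_el |= static_cast<uint64_t>(q[k] & 0x1f) << (k * 5);
  }
  return out_el;
}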
@@ -720,12 +691,12 @@ void fast::AffineQuantize::eval_cpu(

   auto [w, copied] = ensure_row_contiguous(inputs[0]);
   auto& out = outputs[0];
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));

   auto& scales = outputs[1];
   auto& biases = outputs[2];
-  scales.set_data(allocator::malloc(scales.nbytes()));
-  biases.set_data(allocator::malloc(biases.nbytes()));
+  scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
+  biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   if (copied) {
     encoder.add_temporary(w);
@@ -433,7 +433,7 @@ void reduce_dispatch_min_max(
 void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
   encoder.set_output_array(out);
@@ -3,7 +3,6 @@
 #include <cassert>

 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/cpu/binary_ops.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
@@ -227,16 +226,6 @@ void scan_dispatch(
       scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
       break;
     }
-    case Scan::LogAddExp: {
-      auto op = [](U a, T b) {
-        return detail::LogAddExp{}(a, static_cast<U>(b));
-      };
-      auto init = (issubdtype(in.dtype(), floating))
-          ? static_cast<U>(-std::numeric_limits<float>::infinity())
-          : std::numeric_limits<U>::min();
-      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
-      break;
-    }
   }
 }

@@ -255,7 +244,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
     in = arr_copy;
     encoder.add_temporary(arr_copy);
   }
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));

   encoder.set_input_array(in);
   encoder.set_output_array(out);
@@ -330,8 +319,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
           reduce_type_, in, out, axis_, reverse_, inclusive_);
       break;
     case complex64:
-      scan_dispatch<complex64_t, complex64_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_);
+      throw std::runtime_error("Scan ops do not support complex types yet");
       break;
   }
 });
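The LogAddExp scan on the left-hand side folds elements with the stable log-add-exp combiner, whose identity element is -infinity. A scalar sketch of the combiner (not the simd::Simd version the kernel dispatches to):

#include <algorithm>
#include <cmath>

// log(exp(a) + exp(b)) without overflow: factor out the larger argument.
float logaddexp(float a, float b) {
  if (std::isinf(a) && a < 0) {
    return b; // -inf is the identity element of the scan
  }
  float m = std::max(a, b);
  return m + std::log1p(std::exp(std::min(a, b) - m));
}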
@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
 #endif

 template <>
-inline constexpr int max_size<float16_t> = N;
+static constexpr int max_size<float16_t> = N;

 #define SIMD_FP16_DEFAULT_UNARY(op) \
   template <>                       \
@@ -83,25 +83,25 @@ struct Simd {
 // Values chosen based on benchmarks on M3 Max
 // TODO: consider choosing these more optimally
 template <>
-inline constexpr int max_size<int8_t> = 16;
+static constexpr int max_size<int8_t> = 16;
 template <>
-inline constexpr int max_size<int16_t> = 16;
+static constexpr int max_size<int16_t> = 16;
 template <>
-inline constexpr int max_size<int> = 8;
+static constexpr int max_size<int> = 8;
 template <>
-inline constexpr int max_size<int64_t> = 4;
+static constexpr int max_size<int64_t> = 4;
 template <>
-inline constexpr int max_size<uint8_t> = 16;
+static constexpr int max_size<uint8_t> = 16;
 template <>
-inline constexpr int max_size<uint16_t> = 16;
+static constexpr int max_size<uint16_t> = 16;
 template <>
-inline constexpr int max_size<uint32_t> = 8;
+static constexpr int max_size<uint32_t> = 8;
 template <>
-inline constexpr int max_size<uint64_t> = 4;
+static constexpr int max_size<uint64_t> = 4;
 template <>
-inline constexpr int max_size<float> = 8;
+static constexpr int max_size<float> = 8;
 template <>
-inline constexpr int max_size<double> = 4;
+static constexpr int max_size<double> = 4;

 #define SIMD_DEFAULT_UNARY(name, op) \
   template <typename T, int N>       \
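The inline constexpr form on the left-hand side is the C++17 way to give each specialization of a variable template a single definition across every translation unit that includes the header. A minimal illustration of the idiom (toy names, unrelated to the real values above):

// A generic default plus an explicit specialization; 'inline' avoids ODR
// violations when this lives in a widely included header.
template <typename T>
inline constexpr int vec_width = 1;

template <>
inline constexpr int vec_width<float> = 8;

static_assert(vec_width<double> == 1); // falls back to the default
static_assert(vec_width<float> == 8);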
@@ -87,45 +87,14 @@ DEFAULT_UNARY(cosh, std::cosh)
 DEFAULT_UNARY(expm1, std::expm1)
 DEFAULT_UNARY(floor, std::floor)
 DEFAULT_UNARY(log, std::log)
+DEFAULT_UNARY(log2, std::log2)
 DEFAULT_UNARY(log10, std::log10)
+DEFAULT_UNARY(log1p, std::log1p)
 DEFAULT_UNARY(sinh, std::sinh)
 DEFAULT_UNARY(sqrt, std::sqrt)
 DEFAULT_UNARY(tan, std::tan)
 DEFAULT_UNARY(tanh, std::tanh)

-template <typename T>
-Simd<T, 1> log1p(Simd<T, 1> in) {
-  if constexpr (is_complex<T>) {
-    auto x = in.value.real();
-    auto y = in.value.imag();
-    auto zabs = std::abs(in.value);
-    auto theta = std::atan2(y, x + 1);
-    if (zabs < 0.5) {
-      auto r = x * (2 + x) + y * y;
-      if (r == 0) { // handle underflow
-        return Simd<T, 1>{T{x, theta}};
-      }
-      return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
-    } else {
-      auto z0 = std::hypot(x + 1, y);
-      return Simd<T, 1>{T{std::log(z0), theta}};
-    }
-  } else {
-    return Simd<T, 1>{std::log1p(in.value)};
-  }
-}
-
-template <typename T>
-Simd<T, 1> log2(Simd<T, 1> in) {
-  if constexpr (is_complex<T>) {
-    auto out = std::log(in.value);
-    auto scale = decltype(out.real())(M_LN2);
-    return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
-  } else {
-    return Simd<T, 1>{std::log2(in.value)};
-  }
-}
-
 template <typename T>
 Simd<T, 1> operator~(Simd<T, 1> in) {
   return ~in.value;
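The small-|z| branch of the complex log1p above relies on |1 + z|^2 = 1 + (x(2 + x) + y^2) for z = x + iy, so log|1 + z| = 0.5 * log1p(x(2 + x) + y^2), which preserves precision when z is tiny. The identity can be checked numerically with a standalone sketch:

#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  double x = 1e-9, y = 2e-9;
  double naive = std::log(std::abs(std::complex<double>(1 + x, y)));
  double stable = 0.5 * std::log1p(x * (2 + x) + y * y);
  // 'naive' loses most significant digits; 'stable' keeps them.
  std::printf("%.17g\n%.17g\n", naive, stable);
  return 0;
}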
@@ -119,12 +119,17 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {

   // Make sure that the last dimension is contiguous
   auto set_output = [s = stream(), &out](const array& x) {
-    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
+    bool no_copy = x.strides()[x.ndim() - 1] == 1;
+    if (x.ndim() > 1) {
+      auto s = x.strides()[x.ndim() - 2];
+      no_copy &= (s == 0 || s == x.shape().back());
+    }
+    if (no_copy) {
       if (x.is_donatable()) {
         out.copy_shared_buffer(x);
       } else {
         out.set_data(
-            allocator::malloc(x.data_size() * x.itemsize()),
+            allocator::malloc_or_wait(x.data_size() * x.itemsize()),
             x.data_size(),
             x.strides(),
             x.flags());
@@ -141,6 +146,18 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto in = set_output(inputs[0]);

   switch (in.dtype()) {
+    case bool_:
+    case uint8:
+    case uint16:
+    case uint32:
+    case uint64:
+    case int8:
+    case int16:
+    case int32:
+    case int64:
+      throw std::runtime_error(
+          "Softmax is defined only for floating point types");
+      break;
     case float32:
       softmax<float, float>(in, out, stream());
       break;
@@ -161,9 +178,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
     case float64:
       softmax<double, double>(in, out, stream());
       break;
-    default:
-      throw std::runtime_error(
-          "[softmax] Only defined for floating point types.");
+    case complex64:
+      throw std::invalid_argument(
+          "[Softmax] Not yet implemented for complex64");
       break;
   }
 }
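The no_copy predicate on the right-hand side accepts any input whose rows are unit-stride and laid out back to back (or broadcast with stride 0), even when the array's contiguous flag is unset. A sketch of the same check over plain vectors (softmax_no_copy is a hypothetical name):

#include <cstdint>
#include <vector>

// Last dim must be unit-stride; for ndim > 1 the row stride must be 0
// (broadcast) or exactly one row, so rows sit back to back in memory.
bool softmax_no_copy(
    const std::vector<int>& shape,
    const std::vector<int64_t>& strides) {
  auto n = shape.size();
  bool no_copy = strides[n - 1] == 1;
  if (n > 1) {
    auto s = strides[n - 2];
    no_copy &= (s == 0 || s == shape.back());
  }
  return no_copy;
}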
@@ -288,7 +288,7 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];

   // Allocate output
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));

   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
@@ -379,7 +379,7 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];

   // Allocate output
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));

   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
@@ -50,9 +50,9 @@ void svd_impl(
   array& s = outputs[1];
   array& vt = outputs[2];

-  u.set_data(allocator::malloc(u.nbytes()));
-  s.set_data(allocator::malloc(s.nbytes()));
-  vt.set_data(allocator::malloc(vt.nbytes()));
+  u.set_data(allocator::malloc_or_wait(u.nbytes()));
+  s.set_data(allocator::malloc_or_wait(s.nbytes()));
+  vt.set_data(allocator::malloc_or_wait(vt.nbytes()));

   encoder.set_output_array(u);
   encoder.set_output_array(s);
@@ -64,7 +64,7 @@ void svd_impl(
   } else {
     array& s = outputs[0];

-    s.set_data(allocator::malloc(s.nbytes()));
+    s.set_data(allocator::malloc_or_wait(s.nbytes()));

     encoder.set_output_array(s);

@@ -91,7 +91,7 @@ void svd_impl(

   // Will contain the indices of eigenvectors that failed to converge (not
   // used here but required by lapack).
-  auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
+  auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};

   static const int lwork_query = -1;

@@ -132,7 +132,7 @@ void svd_impl(
   }

   const int lwork = workspace_dimension;
-  auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
+  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};

   // Loop over matrices.
   for (int i = 0; i < num_matrices; i++) {
@@ -1,8 +1,5 @@
 // Copyright © 2024 Apple Inc.

-// Required for using M_LN2 in MSVC.
-#define _USE_MATH_DEFINES
-
 #include <cassert>

 #include "mlx/backend/cpu/unary.h"
@@ -2,13 +2,32 @@

 #pragma once

-#include "mlx/backend/common/unary.h"
+#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
 #include "mlx/utils.h"

 namespace mlx::core {

+void set_unary_output_data(const array& in, array& out) {
+  if (in.flags().contiguous) {
+    if (is_donatable(in, out)) {
+      out.copy_shared_buffer(in);
+    } else {
+      auto size = in.data_size();
+      out.set_data(
+          allocator::malloc_or_wait(size * out.itemsize()),
+          size,
+          in.strides(),
+          in.flags());
+    }
+  } else {
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  }
+}
+
 template <typename T, typename U = T, typename Op>
 void unary_op(const T* a, U* out, size_t shape, size_t stride) {
   for (size_t i = 0; i < shape; i += 1) {
@@ -86,14 +86,13 @@ struct Sign {
   template <int N, typename T>
   Simd<T, N> operator()(Simd<T, N> x) {
     auto z = Simd<T, N>{0};
-    auto o = Simd<T, N>{1};
-    auto m = Simd<T, N>{-1};
     if constexpr (std::is_unsigned_v<T>) {
-      return simd::select(x == z, z, o);
+      return x != z;
     } else if constexpr (std::is_same_v<T, complex64_t>) {
       return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
     } else {
-      return simd::select(x < z, m, simd::select(x > z, o, z));
+      return simd::select(
+          x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
     }
   }
   SINGLE()
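Both sides compute the usual three-way sign with branchless selects; a scalar equivalent, as a sketch (sign_of is a hypothetical name):

#include <type_traits>

// -1, 0, or 1 for signed types; 0 or 1 for unsigned types, where -1 is
// not representable, matching the unsigned branch above.
template <typename T>
T sign_of(T x) {
  if constexpr (std::is_unsigned_v<T>) {
    return x != T{0};
  } else {
    return static_cast<T>((x > T{0}) - (x < T{0}));
  }
}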
@@ -1,119 +0,0 @@
-# Filename rules in cuda backend:
-#
-# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
-# * Device-only code should be put in device/ subdir.
-# * Files in device/ subdir should not include files outside.
-target_sources(
-  mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
-
-target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
-
-# Embed kernel sources in binary for JIT compilation.
-file(
-  GLOB MLX_JIT_SOURCES
-  RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
-  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
-  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
-string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
-add_custom_command(
-  OUTPUT gen/cuda_jit_sources.h
-  COMMAND
-    ${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
-    -DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
-    "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
-  DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
-add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h)
-add_dependencies(mlx cuda_jit_sources)
-target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
-
-# Enable defining device lambda functions.
-target_compile_options(mlx
-                       PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
-
-# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
-# Explicitly pass this flag to suppress the warning, it is safe to set it to
-# true but the warning wouldn't be suppressed.
-if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
-  target_compile_options(
-    mlx
-    PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--static-global-template-stub=false>")
-endif()
-
-# Suppress warning when building for compute capability 7 used by V100.
-target_compile_options(
-  mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
-
-# Compute capability 7 is required for synchronization between CPU/GPU with
-# managed memory. TODO: Add more architectures for potential performance gain.
-set(MLX_CUDA_ARCHITECTURES
-    "70;80"
-    CACHE STRING "CUDA architectures")
-message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
-set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
-                                     "${MLX_CUDA_ARCHITECTURES}")
-
-# Use fixed version of CCCL.
-FetchContent_Declare(
-  cccl
-  URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
-FetchContent_MakeAvailable(cccl)
-target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
-
-# Use fixed version of NVTX.
-FetchContent_Declare(
-  nvtx3
-  GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
-  GIT_TAG v3.1.1
-  GIT_SHALLOW TRUE
-  SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
-FetchContent_MakeAvailable(nvtx3)
-target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
-
-# Make cuda runtime APIs available in non-cuda files.
-find_package(CUDAToolkit REQUIRED)
-target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
-
-# Use cublasLt.
-target_link_libraries(mlx PRIVATE CUDA::cublasLt)
-
-# Use NVRTC and driver APIs.
-target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
-
-# Suppress nvcc warnings on MLX headers.
-target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
-                                   --diag_suppress=997>)
@@ -1,206 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/cuda/allocator.h"
-#include "mlx/backend/cuda/utils.h"
-#include "mlx/backend/cuda/worker.h"
-
-#include <cuda_runtime.h>
-#include <fmt/format.h>
-#include <unistd.h>
-
-#include <cassert>
-
-namespace mlx::core {
-
-namespace cu {
-
-CudaAllocator::CudaAllocator()
-    : buffer_cache_(
-          getpagesize(),
-          [](CudaBuffer* buf) { return buf->size; },
-          [this](CudaBuffer* buf) {
-            cuda_free(buf->data);
-            delete buf;
-          }) {
-  // TODO: Set memory limit for multi-device.
-  size_t free, total;
-  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
-  memory_limit_ = total * 0.8;
-  max_pool_size_ = memory_limit_;
-}
-
-Buffer CudaAllocator::malloc(size_t size) {
-  // Find available buffer from cache.
-  std::unique_lock lock(mutex_);
-  CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
-  if (!buf) {
-    // If we have a lot of memory pressure or are over the maximum cache size,
-    // try to reclaim memory from the cache.
-    size_t mem_required = get_active_memory() + get_cache_memory() + size;
-    if (mem_required >= memory_limit_) {
-      buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
-    }
-
-    lock.unlock();
-    buf = new CudaBuffer{nullptr, size};
-    cudaError_t err = cudaMallocManaged(&buf->data, size);
-    if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
-      throw std::runtime_error(fmt::format(
-          "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
-    }
-    lock.lock();
-  }
-  active_memory_ += size;
-  peak_memory_ = std::max(active_memory_, peak_memory_);
-
-  // Maintain the cache below the requested limit.
-  if (get_cache_memory() > max_pool_size_) {
-    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
-  }
-
-  return Buffer{buf};
-}
-
-void CudaAllocator::free(Buffer buffer) {
-  auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
-  if (!buf) {
-    return;
-  }
-
-  std::unique_lock lock(mutex_);
-  active_memory_ -= buf->size;
-  if (get_cache_memory() < max_pool_size_) {
-    buffer_cache_.recycle_to_cache(buf);
-  } else {
-    lock.unlock();
-    cuda_free(buf->data);
-    delete buf;
-  }
-}
-
-size_t CudaAllocator::size(Buffer buffer) const {
-  auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
-  if (!buf) {
-    return 0;
-  }
-  return buf->size;
-}
-
-void CudaAllocator::register_this_thread() {
-  std::lock_guard lock(worker_mutex_);
-  allowed_threads_.insert(std::this_thread::get_id());
-}
-
-void CudaAllocator::cuda_free(void* buf) {
-  // If cuda_free() is called from an unregistered thread, reschedule the call
-  // to the worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([this, buf]() { this->cuda_free(buf); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
-  }
-
-  cudaFree(buf);
-}
-
-size_t CudaAllocator::get_active_memory() const {
-  return active_memory_;
-}
-
-size_t CudaAllocator::get_peak_memory() const {
-  return peak_memory_;
-}
-
-void CudaAllocator::reset_peak_memory() {
-  std::lock_guard lock(mutex_);
-  peak_memory_ = 0;
-}
-
-size_t CudaAllocator::get_memory_limit() {
-  return memory_limit_;
-}
-
-size_t CudaAllocator::set_memory_limit(size_t limit) {
-  std::lock_guard lock(mutex_);
-  std::swap(limit, memory_limit_);
-  return limit;
-}
-
-size_t CudaAllocator::get_cache_memory() const {
-  return buffer_cache_.cache_size();
-}
-
-size_t CudaAllocator::set_cache_limit(size_t limit) {
-  std::lock_guard lk(mutex_);
-  std::swap(limit, max_pool_size_);
-  return limit;
-}
-
-void CudaAllocator::clear_cache() {
-  std::lock_guard lk(mutex_);
-  buffer_cache_.clear();
-}
-
-CudaAllocator& allocator() {
-  // By creating the |allocator_| on heap, the destructor of CudaAllocator
-  // will not be called on exit and buffers in the cache will be leaked. This
-  // can save some time at program exit.
-  static CudaAllocator* allocator_ = new CudaAllocator;
-  return *allocator_;
-}
-
-} // namespace cu
-
-namespace allocator {
-
-Allocator& allocator() {
-  return cu::allocator();
-}
-
-void* Buffer::raw_ptr() {
-  if (!ptr_) {
-    return nullptr;
-  }
-  return static_cast<cu::CudaBuffer*>(ptr_)->data;
-}
-
-} // namespace allocator
-
-size_t get_active_memory() {
-  return cu::allocator().get_active_memory();
-}
-size_t get_peak_memory() {
-  return cu::allocator().get_peak_memory();
-}
-void reset_peak_memory() {
-  return cu::allocator().reset_peak_memory();
-}
-size_t set_memory_limit(size_t limit) {
-  return cu::allocator().set_memory_limit(limit);
-}
-size_t get_memory_limit() {
-  return cu::allocator().get_memory_limit();
-}
-size_t get_cache_memory() {
-  return cu::allocator().get_cache_memory();
-}
-size_t set_cache_limit(size_t limit) {
-  return cu::allocator().set_cache_limit(limit);
-}
-void clear_cache() {
-  cu::allocator().clear_cache();
-}
-
-// Not supported in CUDA.
-size_t set_wired_limit(size_t) {
-  return 0;
-}
-
-} // namespace mlx::core
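The allocator above layers a size-keyed cache over cudaMallocManaged: reuse a cached buffer when one is big enough, otherwise allocate a fresh one, and recycle on free while the cache stays under its limit. A toy sketch of just that reuse-then-allocate policy (ToyCache is illustrative, not the MLX BufferCache, and it omits locking and the memory limits):

#include <cstdlib>
#include <map>

struct ToyCache {
  std::multimap<std::size_t, void*> free_list; // size -> cached buffer

  void* malloc(std::size_t size) {
    auto it = free_list.lower_bound(size);
    if (it != free_list.end()) { // reuse the smallest fitting buffer
      void* p = it->second;
      free_list.erase(it);
      return p;
    }
    return std::malloc(size); // fall back to a fresh allocation
  }

  void free(std::size_t size, void* p) {
    free_list.emplace(size, p); // recycle instead of releasing
  }
};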
@@ -1,67 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#pragma once
-
-#include "mlx/allocator.h"
-#include "mlx/backend/common/buffer_cache.h"
-
-#include <mutex>
-#include <set>
-#include <thread>
-#include <utility>
-
-namespace mlx::core::cu {
-
-class Worker;
-
-using allocator::Buffer;
-
-// Stores cuda-managed unified memory.
-struct CudaBuffer {
-  void* data;
-  size_t size;
-};
-
-class CudaAllocator : public allocator::Allocator {
- public:
-  Buffer malloc(size_t size) override;
-  void free(Buffer buffer) override;
-  size_t size(Buffer buffer) const override;
-
-  // Register current thread as safe to free buffers.
-  // In cuda freeing a buffer implicitly synchronizes the stream, and for
-  // threads that may be waited on by the gpu stream (for example cpu stream
-  // threads), freeing buffers there would result in deadlock.
-  void register_this_thread();
-
-  // Call cudaFree in the safe thread.
-  void cuda_free(void* buf);
-
-  size_t get_active_memory() const;
-  size_t get_peak_memory() const;
-  void reset_peak_memory();
-  size_t get_memory_limit();
-  size_t set_memory_limit(size_t limit);
-  size_t get_cache_memory() const;
-  size_t set_cache_limit(size_t limit);
-  void clear_cache();
-
- private:
-  CudaAllocator();
-  friend CudaAllocator& allocator();
-
-  std::mutex worker_mutex_;
-  std::unique_ptr<Worker> worker_;
-  std::set<std::thread::id> allowed_threads_;
-
-  std::mutex mutex_;
-  size_t memory_limit_;
-  size_t max_pool_size_;
-  BufferCache<CudaBuffer> buffer_cache_;
-  size_t active_memory_{0};
-  size_t peak_memory_{0};
-};
-
-CudaAllocator& allocator();
-
-} // namespace mlx::core::cu
@@ -1,188 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
#include "mlx/backend/common/utils.h"
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
|
||||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
|
||||||
#include "mlx/dtype_utils.h"
|
|
||||||
#include "mlx/primitives.h"
|
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
|
||||||
#include <nvtx3/nvtx3.hpp>
|
|
||||||
#include <cub/block/block_load.cuh>
|
|
||||||
#include <cub/block/block_reduce.cuh>
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
namespace cu {
|
|
||||||
|
|
||||||
namespace cg = cooperative_groups;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct IndexValPair {
|
|
||||||
uint32_t index;
|
|
||||||
T val;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct ArgMin {
|
|
||||||
constexpr __device__ T init() {
|
|
||||||
return Limits<T>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ IndexValPair<T> operator()(
|
|
||||||
const IndexValPair<T>& best,
|
|
||||||
const IndexValPair<T>& current) {
|
|
||||||
if (best.val > current.val ||
|
|
||||||
(best.val == current.val && best.index > current.index)) {
|
|
||||||
return current;
|
|
||||||
} else {
|
|
||||||
return best;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int N>
|
|
||||||
__device__ IndexValPair<T>
|
|
||||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
|
||||||
for (int i = 0; i < N; i++) {
|
|
||||||
if (vals[i] < best.val) {
|
|
||||||
best.val = vals[i];
|
|
||||||
best.index = offset + i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return best;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct ArgMax {
|
|
||||||
constexpr __device__ T init() {
|
|
||||||
return Limits<T>::min();
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ IndexValPair<T> operator()(
|
|
||||||
const IndexValPair<T>& best,
|
|
||||||
const IndexValPair<T>& current) {
|
|
||||||
if (best.val < current.val ||
|
|
||||||
(best.val == current.val && best.index > current.index)) {
|
|
||||||
return current;
|
|
||||||
} else {
|
|
||||||
return best;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int N>
|
|
||||||
__device__ IndexValPair<T>
|
|
||||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
|
||||||
for (int i = 0; i < N; i++) {
|
|
||||||
if (vals[i] > best.val) {
|
|
||||||
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
__global__ void arg_reduce_general(
    const T* in,
    uint32_t* out,
    size_t size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides in_strides,
    const __grid_constant__ Strides out_strides,
    int32_t ndim,
    int64_t axis_stride,
    int32_t axis_size) {
  auto block = cg::this_thread_block();

  int64_t index = cg::this_grid().block_rank();
  if (index >= size) {
    return;
  }

  int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
  int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);

  Op op;
  T init = op.init();
  IndexValPair<T> best{0, init};

  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    T vals[N_READS];
    auto tid = r * BLOCK_DIM + block.thread_index().x;
    cub::LoadDirectBlocked(
        tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
    best = op.reduce_many(best, vals, tid * N_READS);
  }

  typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
  __shared__ typename BlockReduceT::TempStorage temp;

  best = BlockReduceT(temp).Reduce(best, op);

  if (block.thread_rank() == 0) {
    out[out_idx] = best.index;
  }
}

} // namespace cu

void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("ArgReduce::eval_gpu");
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  out.set_data(allocator::malloc(out.nbytes()));
  auto& s = stream();

  // Prepare the shapes, strides and axis arguments.
  Shape shape = remove_index(in.shape(), axis_);
  Strides in_strides = remove_index(in.strides(), axis_);
  Strides out_strides = out.ndim() == in.ndim()
      ? remove_index(out.strides(), axis_)
      : out.strides();
  int64_t axis_stride = in.strides()[axis_];
  int32_t axis_size = in.shape()[axis_];
  int32_t ndim = shape.size();

  // ArgReduce.
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
      using InType = cuda_type_t<CTYPE>;
      constexpr uint32_t N_READS = 4;
      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
        dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
        dim3 block_dims{BLOCK_DIM, 1, 1};
        auto kernel = &cu::arg_reduce_general<
            InType,
            cu::ArgMax<InType>,
            BLOCK_DIM,
            N_READS>;
        if (reduce_type_ == ArgReduce::ArgMin) {
          kernel = &cu::arg_reduce_general<
              InType,
              cu::ArgMin<InType>,
              BLOCK_DIM,
              N_READS>;
        }
        kernel<<<num_blocks, block_dims, 0, stream>>>(
            in.data<InType>(),
            out.data<uint32_t>(),
            out.size(),
            const_param(shape),
            const_param(in_strides),
            const_param(out_strides),
            ndim,
            axis_stride,
            axis_size);
      });
    });
  });
}

} // namespace mlx::core
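// The kernel above combines a per-thread strided scan with a CUB block-wide
// reduction over (index, value) pairs. A minimal standalone sketch of that
// pattern, assuming only CUB and the CUDA runtime are available (ArgMaxPair,
// PairMax, and block_argmax are illustrative names, not MLX APIs):

#include <cub/block/block_reduce.cuh>
#include <cmath>
#include <cstdint>

struct ArgMaxPair {
  uint32_t index;
  float val;
};

struct PairMax {
  __device__ ArgMaxPair operator()(ArgMaxPair a, ArgMaxPair b) const {
    // Keep whichever pair carries the larger value.
    return (b.val > a.val) ? b : a;
  }
};

template <int BLOCK_DIM>
__global__ void block_argmax(const float* in, uint32_t* out, int n) {
  typedef cub::BlockReduce<ArgMaxPair, BLOCK_DIM> BlockReduceT;
  __shared__ typename BlockReduceT::TempStorage temp;

  int tid = threadIdx.x;
  // Out-of-range threads contribute -inf so they never win the reduction.
  ArgMaxPair best{0, -INFINITY};
  if (tid < n) {
    best = {static_cast<uint32_t>(tid), in[tid]};
  }
  best = BlockReduceT(temp).Reduce(best, PairMax{});
  if (tid == 0) {
    *out = best.index; // only thread 0 holds the valid aggregate
  }
}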
@@ -1,150 +0,0 @@
# Based on: https://github.com/sivachandran/cmake-bin2h
#
# Copyright 2020 Sivachandran Paramasivam
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

include(CMakeParseArguments)

# Function to wrap a given string into multiple lines at the given column
# position.
#
# Parameters:
#
# * VARIABLE - The name of the CMake variable holding the string.
# * AT_COLUMN - The column position at which the string will be wrapped.
function(WRAP_STRING)
  set(oneValueArgs VARIABLE AT_COLUMN)
  cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})

  string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength)
  math(EXPR offset "0")

  while(stringLength GREATER 0)
    if(stringLength GREATER ${WRAP_STRING_AT_COLUMN})
      math(EXPR length "${WRAP_STRING_AT_COLUMN}")
    else()
      math(EXPR length "${stringLength}")
    endif()

    string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} line)
    set(lines "${lines}\n  ${line}")

    math(EXPR stringLength "${stringLength} - ${length}")
    math(EXPR offset "${offset} + ${length}")
  endwhile()

  set(${WRAP_STRING_VARIABLE}
      "${lines}"
      PARENT_SCOPE)
endfunction()

# Function to embed the contents of files as byte arrays in a C/C++ header
# file (.h). The header file will contain one constexpr char array per source
# file.
#
# Parameters:
#
# * SOURCE_FILES - The paths of source files whose contents will be embedded in
#   the header file.
# * VARIABLE_NAME - The base name of the variable for each byte array; the
#   sanitized source file name is appended to it.
# * HEADER_FILE - The path of the header file.
# * APPEND - If specified, appends to the header file instead of overwriting it.
# * HEADER_NAMESPACE - The namespace in which the arrays should be located.
# * NULL_TERMINATE - If specified, a null byte (zero) will be appended to each
#   byte array.
#
# Usage:
#
# bin2h(SOURCE_FILES "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO_PNG")
function(BIN2H)
  set(options APPEND NULL_TERMINATE)
  set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)
  set(multiValueArgs SOURCE_FILES)
  cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}"
                        "${multiValueArgs}" ${ARGN})

  set(arrayDefinition "")
  foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)
    # get filename without extension
    get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE)
    # convert the filename to a valid C identifier
    string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME)

    # reads source file contents as hex string
    file(READ ${SOURCE_FILE} hexString HEX)

    # append null
    if(BIN2H_NULL_TERMINATE)
      string(APPEND hexString "00")
    endif()

    # wraps the hex string into multiple lines
    wrap_string(VARIABLE hexString AT_COLUMN 24)

    # strip the © (UTF-8 bytes c2 a9) from the source, replacing it with two
    # spaces
    string(REGEX REPLACE "c2a9" "2020" arrayValues ${hexString})

    string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues
                         ${arrayValues})

    # make a full variable name for the array
    set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}")

    # declares byte array and the length variables
    string(APPEND arrayDefinition
           "constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n")
  endforeach()

  # add namespace wrapper if defined
  if(DEFINED BIN2H_HEADER_NAMESPACE)
    set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {")
    set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}")
    set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n")
  endif()

  set(arrayIncludes "#pragma once")
  string(PREPEND declarations "${arrayIncludes}\n\n")

  if(BIN2H_APPEND)
    file(APPEND ${BIN2H_HEADER_FILE} "${declarations}")
  else()
    file(WRITE ${BIN2H_HEADER_FILE} "${declarations}")
  endif()
endfunction()

# ----------------------------- CLI args -----------------------------

string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
foreach(source ${MLX_JIT_SOURCES_LIST})
  list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}")
endforeach()

bin2h(
  SOURCE_FILES
  ${MLX_JIT_SOURCES_ABS}
  NULL_TERMINATE
  VARIABLE_NAME
  "jit_source"
  HEADER_NAMESPACE
  "mlx::core"
  HEADER_FILE
  "${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h")
@@ -1,305 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    out[index] = Op{}(a[0], b[0]);
  }
}

template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    out[index] = Op{}(a[0], b[index]);
  }
}

template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    out[index] = Op{}(a[index], b[0]);
  }
}

template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    out[index] = Op{}(a[index], b[index]);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
    const In* a,
    const In* b,
    Out* out,
    IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
        index, shape.data(), a_strides.data(), b_strides.data());
    out[index] = Op{}(a[a_idx], b[b_idx]);
  }
}

template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
    const In* a,
    const In* b,
    Out* out,
    IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides a_strides,
    const __grid_constant__ Strides b_strides,
    int ndim) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [a_idx, b_idx] = elem_to_loc_4d(
        index, shape.data(), a_strides.data(), b_strides.data(), ndim);
    out[index] = Op{}(a[a_idx], b[b_idx]);
  }
}

template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
  if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
      std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
      std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
      std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
    return std::is_same_v<In, Out>;
  }
  if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
      std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
      std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
    return std::is_same_v<Out, bool>;
  }
  if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
    return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
  }
  if (std::is_same_v<Op, NaNEqual>) {
    return std::is_same_v<Out, bool> &&
        (is_floating_v<In> || std::is_same_v<In, complex64_t>);
  }
  if (std::is_same_v<Op, LogAddExp> || std::is_same_v<Op, ArcTan2>) {
    return std::is_same_v<In, Out> && is_floating_v<In>;
  }
  if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
      std::is_same_v<Op, BitwiseXor>) {
    return std::is_same_v<In, Out> && std::is_integral_v<In>;
  }
  if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
    return std::is_same_v<In, Out> && std::is_integral_v<In> &&
        !std::is_same_v<In, bool>;
  }
  return false;
}

} // namespace cu

template <typename Op>
void binary_op_gpu_inplace(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    std::string_view op,
    const Stream& s) {
  assert(inputs.size() > 1);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  auto& out = outputs[0];
  if (out.size() == 0) {
    return;
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
        if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
          using InType = cuda_type_t<CTYPE_IN>;
          using OutType = cuda_type_t<CTYPE_OUT>;

          auto bopt = get_binary_op_type(a, b);
          if (bopt == BinaryOpType::General) {
            auto [shape, strides] = collapse_contiguous_dims(a, b, out);
            auto& a_strides = strides[0];
            auto& b_strides = strides[1];
            bool large = a.data_size() > UINT32_MAX ||
                b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
            MLX_SWITCH_BOOL(large, LARGE, {
              using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
              int ndim = shape.size();
              if (ndim <= 3) {
                MLX_SWITCH_1_2_3(ndim, NDIM, {
                  auto kernel =
                      &cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
                  auto [num_blocks, block_dims] =
                      get_launch_args(kernel, out, large);
                  kernel<<<num_blocks, block_dims, 0, stream>>>(
                      a.data<InType>(),
                      b.data<InType>(),
                      out.data<OutType>(),
                      out.data_size(),
                      const_param<NDIM>(shape),
                      const_param<NDIM>(a_strides),
                      const_param<NDIM>(b_strides));
                });
              } else {
                auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
                auto [num_blocks, block_dims] =
                    get_launch_args(kernel, out, large);
                kernel<<<num_blocks, block_dims, 0, stream>>>(
                    a.data<InType>(),
                    b.data<InType>(),
                    out.data<OutType>(),
                    out.data_size(),
                    const_param(shape),
                    const_param(a_strides),
                    const_param(b_strides),
                    ndim);
              }
            });
          } else {
            MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
              using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
              auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
              if (bopt == BinaryOpType::ScalarVector) {
                kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
              } else if (bopt == BinaryOpType::VectorScalar) {
                kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
              } else if (bopt == BinaryOpType::VectorVector) {
                kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
              }
              auto [num_blocks, block_dims] =
                  get_launch_args(kernel, out, LARGE);
              kernel<<<num_blocks, block_dims, 0, stream>>>(
                  a.data<InType>(),
                  b.data<InType>(),
                  out.data<OutType>(),
                  out.data_size());
            });
          }
        } else {
          throw std::runtime_error(fmt::format(
              "Can not do binary op {} on inputs of {} with result of {}.",
              op,
              dtype_to_string(a.dtype()),
              dtype_to_string(out.dtype())));
        }
      });
    });
  });
}

template <typename Op>
void binary_op_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    std::string_view op,
    const Stream& s) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, outputs[0], bopt);
  set_binary_op_output_data(a, b, outputs[1], bopt);
  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}

template <typename Op>
void binary_op_gpu(
    const std::vector<array>& inputs,
    array& out,
    std::string_view op,
    const Stream& s) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
  std::vector<array> outputs{out};
  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}

#define BINARY_GPU(func)                                                    \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) {      \
    nvtx3::scoped_range r(#func "::eval_gpu");                             \
    auto& s = out.primitive().stream();                                    \
    binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s);   \
  }

#define BINARY_GPU_MULTI(func)                                               \
  void func::eval_gpu(                                                       \
      const std::vector<array>& inputs, std::vector<array>& outputs) {       \
    nvtx3::scoped_range r(#func "::eval_gpu");                               \
    auto& s = outputs[0].primitive().stream();                               \
    binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
  }

BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Equal)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)

void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
  auto& s = out.primitive().stream();
  auto op = get_primitive_string(this);
  switch (op_) {
    case BitwiseBinary::And:
      binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
      break;
    case BitwiseBinary::Or:
      binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
      break;
    case BitwiseBinary::Xor:
      binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
      break;
    case BitwiseBinary::LeftShift:
      binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
      break;
    case BitwiseBinary::RightShift:
      binary_op_gpu<cu::RightShift>(inputs, out, op, s);
      break;
  }
}

} // namespace mlx::core
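// For reference, BINARY_GPU(Add) above expands (modulo whitespace) to:
//
//   void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
//     nvtx3::scoped_range r("Add::eval_gpu");
//     auto& s = out.primitive().stream();
//     binary_op_gpu<cu::Add>(inputs, out, get_primitive_string(this), s);
//   }
//
// i.e. every primitive shares one dispatch path and differs only in the
// device-side functor passed as the template parameter.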
@@ -1,228 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"

#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

struct FusedKernelBuilder {
  std::string os;
  const std::string& kernel_name;
  const std::vector<array>& inputs;
  const std::vector<array>& outputs;
  const std::vector<array>& tape;
  const std::function<bool(size_t)>& is_constant;

  void build(const char* name, bool contiguous) {
    NodeNamer namer;

    // Function parameters.
    std::vector<std::string> params;
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (is_constant(i)) {
        continue;
      }
      const auto& x = inputs[i];
      const std::string& xname = namer.get_name(x);
      params.push_back(
          fmt::format("const {}* {}", dtype_to_cuda_type(x.dtype()), xname));
      if (!is_scalar(x) && !contiguous) {
        params.push_back(fmt::format(
            "const __grid_constant__ cuda::std::array<int64_t, NDIM> {}_strides",
            xname));
      }
    }
    for (const auto& x : outputs) {
      params.push_back(fmt::format(
          "{}* {}", dtype_to_cuda_type(x.dtype()), namer.get_name(x)));
    }
    if (!contiguous) {
      params.push_back(
          "const __grid_constant__ cuda::std::array<int32_t, NDIM> shape");
    }
    params.push_back("IdxT size");

    // Build function signature.
    if (contiguous) {
      os += "template <typename IdxT = uint32_t>\n";
    } else {
      os += "template <int NDIM, typename IdxT = uint32_t>\n";
    }
    os += fmt::format("__global__ void {}(\n", kernel_name + name);
    for (size_t i = 0; i < params.size(); ++i) {
      os += "    ";
      os += params[i];
      if (i != params.size() - 1) {
        os += ",\n";
      }
    }
    os += ") {\n";

    // Index.
    os +=
        "  IdxT index = cg::this_grid().thread_rank();\n"
        "  if (index >= size) {\n"
        "    return;\n"
        "  }\n";

    // Read inputs.
    for (size_t i = 0; i < inputs.size(); ++i) {
      const auto& x = inputs[i];
      const std::string& xname = namer.get_name(x);
      std::string type = dtype_to_cuda_type(x.dtype());
      std::string value;
      if (is_constant(i)) {
        std::ostringstream ss;
        print_constant(ss, x);
        value = fmt::format("static_cast<{}>({})", type, ss.str());
      } else if (is_scalar(x)) {
        value = fmt::format("{}[0]", xname);
      } else if (contiguous) {
        value = fmt::format("{}[index]", xname);
      } else {
        std::string index = fmt::format(
            "elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
            xname);
        value = fmt::format("{}[{}]", xname, index);
      }
      os += fmt::format("  {} tmp_{} = {};\n", type, xname, value);
    }

    // Write tape.
    for (const auto& x : tape) {
      const std::string& xname = namer.get_name(x);
      std::string type = dtype_to_cuda_type(x.dtype());
      std::string value;
      if (is_static_cast(x.primitive())) {
        value = fmt::format(
            "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
      } else {
        std::ostringstream ss;
        x.primitive().print(ss);
        value = ss.str();
        value += "{}(";
        for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
          value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
        }
        value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
      }
      os += fmt::format("  {} tmp_{} = {};\n", type, xname, value);
    }

    // Write output.
    for (const auto& x : outputs) {
      os += fmt::format("  {0}[index] = tmp_{0};\n", namer.get_name(x));
    }

    os += "}\n";
  }
};

} // namespace cu

constexpr const char* g_jit_includes = R"(
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cooperative_groups.h>

)";

void Compiled::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("Compiled::eval_gpu");
  auto& s = stream();

  cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
    // Build source code.
    cu::FusedKernelBuilder builder{
        g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_};
    builder.os +=
        "namespace mlx::core::cu {\n\n"
        "namespace cg = cooperative_groups;\n\n";
    builder.build("_contiguous", true);
    builder.os += "\n";
    builder.build("_strided", false);
    builder.os += "\n} // namespace mlx::core::cu\n";
    // Build kernel names.
    std::vector<std::string> kernel_names = {
        fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
        fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
    };
    for (int i = 1; i <= MAX_NDIM; ++i) {
      kernel_names.push_back(fmt::format(
          "mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
      kernel_names.push_back(
          fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
    }
    return std::make_pair(std::move(builder.os), std::move(kernel_names));
  });

  // Collapse contiguous dims to route to a faster kernel if possible. Also
  // handle all broadcasting.
  auto [contiguous, shape, strides_vec] =
      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);

  // Whether to use large index.
  bool large = compiled_use_large_index(inputs, outputs, contiguous);

  // Put inputs.
  int strides_index = 1;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (is_constant_(i)) {
      continue;
    }
    const auto& x = inputs[i];
    mod.append_arg(x);
    if (!contiguous && !is_scalar(x)) {
      mod.append_arg(strides_vec[strides_index++]);
    }
  }

  // Put outputs.
  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
  for (auto& x : outputs) {
    mod.append_arg(x);
  }

  // Put shape and size.
  if (!contiguous) {
    mod.append_arg(shape);
  }
  if (large) {
    mod.append_arg<int64_t>(outputs[0].data_size());
  } else {
    mod.append_arg<uint32_t>(outputs[0].data_size());
  }

  // Launch kernel.
  const char* index_type = large ? "int64_t" : "uint32_t";
  std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
  if (contiguous) {
    kernel_name += fmt::format("_contiguous<{}>", index_type);
  } else {
    kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
  }
  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  for (const auto& out : outputs) {
    encoder.set_output_array(out);
  }
  encoder.launch_kernel([&](cudaStream_t stream) {
    mod.launch_kernel(stream, kernel_name, outputs[0], large);
  });
}

} // namespace mlx::core
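// As a sketch of what FusedKernelBuilder emits, a hypothetical single-op tape
// computing exp(a) over contiguous float inputs would generate roughly the
// following (kernel and variable names here are illustrative; the real names
// come from lib_name() and NodeNamer):
//
//   namespace mlx::core::cu {
//
//   namespace cg = cooperative_groups;
//
//   template <typename IdxT = uint32_t>
//   __global__ void fused_exp_contiguous(
//       const float* in_0,
//       float* out_0,
//       IdxT size) {
//     IdxT index = cg::this_grid().thread_rank();
//     if (index >= size) {
//       return;
//     }
//     float tmp_in_0 = in_0[index];
//     float tmp_out_0 = Exp{}(tmp_in_0);
//     out_0[index] = tmp_out_0;
//   }
//
//   } // namespace mlx::core::cu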
@@ -1,89 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/copy/copy.cuh"

namespace mlx::core {

void copy_gpu_inplace(
    const array& in_,
    array& out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out,
    int64_t offset_in,
    int64_t offset_out,
    CopyType ctype,
    const Stream& s,
    const std::optional<array>& dynamic_offset_in,
    const std::optional<array>& dynamic_offset_out) {
  if (out.size() == 0) {
    return;
  }
  const array& in = in_.data_shared_ptr() ? in_ : out;

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
    copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
    return;
  }

  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
    auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(
        shape, std::vector{strides_in, strides_out}, INT32_MAX);
    if (ctype == CopyType::General) {
      copy_general_input(
          encoder,
          ctype,
          in,
          out,
          offset_in,
          offset_out,
          shape_collapsed,
          strides_vec[0]);
    } else {
      if (dynamic_offset_in || dynamic_offset_out) {
        copy_general_dynamic(
            encoder,
            ctype,
            in,
            out,
            offset_in,
            offset_out,
            shape_collapsed,
            strides_vec[0],
            strides_vec[1],
            dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
            dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
      } else {
        copy_general(
            encoder,
            ctype,
            in,
            out,
            offset_in,
            offset_out,
            shape_collapsed,
            strides_vec[0],
            strides_vec[1]);
      }
    }
    return;
  }
}

void fill_gpu(const array& in, array& out, const Stream& s) {
  if (out.size() == 0) {
    return;
  }
  out.set_data(allocator::malloc(out.nbytes()));
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
}

} // namespace mlx::core
@@ -1,71 +0,0 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"

namespace mlx::core {

#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...)     \
  MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {                   \
    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {               \
      using InType = cuda_type_t<CTYPE_IN>;                      \
      using OutType = cuda_type_t<CTYPE_OUT>;                    \
      if constexpr (cu::CastOp<InType, OutType>::is_castable) {  \
        __VA_ARGS__;                                             \
      } else {                                                   \
        throw std::runtime_error(fmt::format(                    \
            "Can not copy data from dtype {} to {}.",            \
            dtype_to_string(in.dtype()),                         \
            dtype_to_string(out.dtype())));                      \
      }                                                          \
    });                                                          \
  })

void copy_contiguous(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out);

void copy_general(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out);

void copy_general_dynamic(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out,
    const array& dynamic_offset_in,
    const array& dynamic_offset_out);

void copy_general_input(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in);

} // namespace mlx::core
@@ -1,56 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename In, typename Out, typename IdxT>
__global__ void copy_s(const In* in, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    out[index] = CastOp<In, Out>{}(in[0]);
  }
}

template <typename In, typename Out, typename IdxT>
__global__ void copy_v(const In* in, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    out[index] = CastOp<In, Out>{}(in[index]);
  }
}

} // namespace cu

void copy_contiguous(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t in_offset,
    int64_t out_offset) {
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
      MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
        auto kernel = cu::copy_s<InType, OutType, IdxT>;
        if (ctype == CopyType::Vector) {
          kernel = cu::copy_v<InType, OutType, IdxT>;
        }
        auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
        kernel<<<num_blocks, block_dims, 0, stream>>>(
            in.data<InType>() + in_offset,
            out.data<OutType>() + out_offset,
            out.data_size());
      });
    });
  });
}

} // namespace mlx::core
@@ -1,95 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_gg_nd(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
        index, shape.data(), strides_in.data(), strides_out.data());
    out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
  }
}

template <typename In, typename Out, typename IdxT>
__global__ void copy_gg(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides_in,
    const __grid_constant__ Strides strides_out,
    int ndim) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc_4d(
        index, shape.data(), strides_in.data(), strides_out.data(), ndim);
    out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
  }
}

} // namespace cu

void copy_general(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out) {
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
      const InType* in_ptr = in.data<InType>() + offset_in;
      OutType* out_ptr = out.data<OutType>() + offset_out;
      bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
      MLX_SWITCH_BOOL(large, LARGE, {
        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
        int ndim = shape.size();
        if (ndim <= 3) {
          MLX_SWITCH_1_2_3(ndim, NDIM, {
            auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
            auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
            kernel<<<num_blocks, block_dims, 0, stream>>>(
                in_ptr,
                out_ptr,
                out.data_size(),
                const_param<NDIM>(shape),
                const_param<NDIM>(strides_in),
                const_param<NDIM>(strides_out));
          });
        } else { // ndim >= 4
          auto kernel = cu::copy_gg<InType, OutType, IdxT>;
          auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
          kernel<<<num_blocks, block_dims, 0, stream>>>(
              in_ptr,
              out_ptr,
              out.data_size(),
              const_param(shape),
              const_param(strides_in),
              const_param(strides_out),
              ndim);
        }
      });
    });
  });
}

} // namespace mlx::core
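// The elem_to_loc_nd/elem_to_loc_4d helpers used above map a flat row-major
// index to a strided memory offset. A minimal single-array sketch of that
// arithmetic (assumption: this mirrors the helpers in
// mlx/backend/cuda/device/utils.cuh; elem_to_loc_ref is an illustrative name,
// not an MLX API, and the pair-returning variants apply the same coordinates
// to two stride arrays at once):

#include <cstdint>

template <int NDIM>
__host__ __device__ inline int64_t elem_to_loc_ref(
    int64_t index, const int32_t* shape, const int64_t* strides) {
  int64_t loc = 0;
  // Peel off one coordinate per dimension, innermost first, and scale it
  // by that dimension's stride.
  for (int i = NDIM - 1; i >= 0; --i) {
    loc += (index % shape[i]) * strides[i];
    index /= shape[i];
  }
  return loc;
}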
@@ -1,105 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_gg_dynamic_nd(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out,
    const int64_t* offset_in,
    const int64_t* offset_out) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
        index, shape.data(), strides_in.data(), strides_out.data());
    out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
  }
}

template <typename In, typename Out, typename IdxT>
__global__ void copy_gg_dynamic(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides_in,
    const __grid_constant__ Strides strides_out,
    int ndim,
    const int64_t* offset_in,
    const int64_t* offset_out) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc_4d(
        index, shape.data(), strides_in.data(), strides_out.data(), ndim);
    out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
  }
}

} // namespace cu

void copy_general_dynamic(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out,
    const array& dynamic_offset_in,
    const array& dynamic_offset_out) {
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
      const InType* in_ptr = in.data<InType>() + offset_in;
      OutType* out_ptr = out.data<OutType>() + offset_out;
      bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
      MLX_SWITCH_BOOL(large, LARGE, {
        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
        int ndim = shape.size();
        if (ndim <= 3) {
          MLX_SWITCH_1_2_3(ndim, NDIM, {
            auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>;
            auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
            kernel<<<num_blocks, block_dims, 0, stream>>>(
                in_ptr,
                out_ptr,
                out.data_size(),
                const_param<NDIM>(shape),
                const_param<NDIM>(strides_in),
                const_param<NDIM>(strides_out),
                dynamic_offset_in.data<int64_t>(),
                dynamic_offset_out.data<int64_t>());
          });
        } else { // ndim >= 4
          auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
          auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
          kernel<<<num_blocks, block_dims, 0, stream>>>(
              in_ptr,
              out_ptr,
              out.data_size(),
              const_param(shape),
              const_param(strides_in),
              const_param(strides_out),
              ndim,
              dynamic_offset_in.data<int64_t>(),
              dynamic_offset_out.data<int64_t>());
        }
      });
    });
  });
}

} // namespace mlx::core
@@ -1,88 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_g_nd(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
    out[index] = CastOp<In, Out>{}(in[idx_in]);
  }
}

template <typename In, typename Out, typename IdxT>
__global__ void copy_g(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides_in,
    int ndim) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
    out[index] = CastOp<In, Out>{}(in[idx_in]);
  }
}

} // namespace cu

void copy_general_input(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in) {
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
      const InType* in_ptr = in.data<InType>() + offset_in;
      OutType* out_ptr = out.data<OutType>() + offset_out;
      bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
      MLX_SWITCH_BOOL(large, LARGE, {
        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
        int ndim = shape.size();
        if (ndim <= 3) {
          MLX_SWITCH_1_2_3(ndim, NDIM, {
            auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>;
            auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
            kernel<<<num_blocks, block_dims, 0, stream>>>(
                in_ptr,
                out_ptr,
                out.data_size(),
                const_param<NDIM>(shape),
                const_param<NDIM>(strides_in));
          });
        } else { // ndim >= 4
          auto kernel = cu::copy_g<InType, OutType, IdxT>;
          auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
          kernel<<<num_blocks, block_dims, 0, stream>>>(
              in_ptr,
              out_ptr,
              out.data_size(),
              const_param(shape),
              const_param(strides_in),
              ndim);
        }
      });
    });
  });
}

} // namespace mlx::core
@@ -1,11 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/cuda.h"

namespace mlx::core::cu {

bool is_available() {
  return true;
}

} // namespace mlx::core::cu
@@ -1,10 +0,0 @@
// Copyright © 2025 Apple Inc.

#pragma once

namespace mlx::core::cu {

/* Check if the CUDA backend is available. */
bool is_available();

} // namespace mlx::core::cu