mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
1 Commits
33bf1a244b
...
cuda-sdpa-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
870208eff5 |
@@ -16,9 +16,6 @@ parameters:
|
|||||||
linux_release:
|
linux_release:
|
||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
cuda_release:
|
|
||||||
type: boolean
|
|
||||||
default: false
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build_documentation:
|
build_documentation:
|
||||||
@@ -107,7 +104,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
echo "stubs"
|
echo "stubs"
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
@@ -165,7 +162,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
@@ -226,6 +223,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
|
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||||
python -m venv env
|
python -m venv env
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
@@ -285,7 +283,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
- run:
|
- run:
|
||||||
name: Build Python package
|
name: Build Python package
|
||||||
command: |
|
command: |
|
||||||
@@ -344,7 +342,7 @@ jobs:
|
|||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
pip install . -v
|
pip install . -v
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
<< parameters.extra_env >> \
|
<< parameters.extra_env >> \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
python -m build --wheel
|
python -m build --wheel
|
||||||
@@ -358,48 +356,6 @@ jobs:
|
|||||||
- store_artifacts:
|
- store_artifacts:
|
||||||
path: wheelhouse/
|
path: wheelhouse/
|
||||||
|
|
||||||
build_cuda_release:
|
|
||||||
parameters:
|
|
||||||
python_version:
|
|
||||||
type: string
|
|
||||||
default: "3.9"
|
|
||||||
extra_env:
|
|
||||||
type: string
|
|
||||||
default: "DEV_RELEASE=1"
|
|
||||||
machine:
|
|
||||||
image: linux-cuda-12:default
|
|
||||||
resource_class: gpu.nvidia.small.gen2
|
|
||||||
steps:
|
|
||||||
- checkout
|
|
||||||
- run:
|
|
||||||
name: Build wheel
|
|
||||||
command: |
|
|
||||||
sudo apt-get update
|
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
|
||||||
python -m venv env
|
|
||||||
source env/bin/activate
|
|
||||||
pip install auditwheel
|
|
||||||
pip install patchelf
|
|
||||||
pip install build
|
|
||||||
pip install twine
|
|
||||||
<< parameters.extra_env >> \
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
|
||||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
|
||||||
pip install ".[dev]" -v
|
|
||||||
python setup.py generate_stubs
|
|
||||||
<< parameters.extra_env >> \
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
|
||||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
|
||||||
python -m build --wheel
|
|
||||||
bash python/scripts/repair_cuda.sh
|
|
||||||
- run:
|
|
||||||
name: Upload package
|
|
||||||
command: |
|
|
||||||
source env/bin/activate
|
|
||||||
twine upload wheelhouse/*.whl
|
|
||||||
- store_artifacts:
|
|
||||||
path: wheelhouse/
|
|
||||||
|
|
||||||
workflows:
|
workflows:
|
||||||
build_and_test:
|
build_and_test:
|
||||||
when:
|
when:
|
||||||
@@ -669,14 +625,3 @@ workflows:
|
|||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
extra_env: ["PYPI_RELEASE=1"]
|
extra_env: ["PYPI_RELEASE=1"]
|
||||||
cuda_test_release:
|
|
||||||
when:
|
|
||||||
and:
|
|
||||||
- equal: [ main, << pipeline.git.branch >> ]
|
|
||||||
- << pipeline.parameters.cuda_release >>
|
|
||||||
jobs:
|
|
||||||
- build_cuda_release:
|
|
||||||
matrix:
|
|
||||||
parameters:
|
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
|
||||||
extra_env: ["PYPI_RELEASE=1"]
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import os
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
|
||||||
import torch.mps
|
import torch.mps
|
||||||
|
|
||||||
|
|
||||||
@@ -45,10 +44,8 @@ def bench(f, *args):
|
|||||||
|
|
||||||
|
|
||||||
def sync_if_needed(x):
|
def sync_if_needed(x):
|
||||||
if x.device == torch.device("mps"):
|
if x.device != torch.device("cpu"):
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
elif x.device == torch.device("cuda"):
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -102,14 +99,6 @@ def reduction(op, axis, x):
|
|||||||
sync_if_needed(x)
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def sum_and_add(axis, x, y):
|
|
||||||
z = x.sum(axis=axis, keepdims=True)
|
|
||||||
for i in range(50):
|
|
||||||
z = (z + y).sum(axis=axis, keepdims=True)
|
|
||||||
sync_if_needed(x)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def softmax(axis, x):
|
def softmax(axis, x):
|
||||||
ys = []
|
ys = []
|
||||||
@@ -351,11 +340,7 @@ if __name__ == "__main__":
|
|||||||
args.axis.pop(0)
|
args.axis.pop(0)
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
device = "mps"
|
device = "cpu" if args.cpu else "mps"
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = "cuda"
|
|
||||||
if args.cpu:
|
|
||||||
device = "cpu"
|
|
||||||
|
|
||||||
types = args.dtype
|
types = args.dtype
|
||||||
if not types:
|
if not types:
|
||||||
@@ -475,8 +460,5 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "selu":
|
elif args.benchmark == "selu":
|
||||||
print(bench(selu, x))
|
print(bench(selu, x))
|
||||||
|
|
||||||
elif args.benchmark == "sum_and_add":
|
|
||||||
print(bench(sum_and_add, axis, *xs))
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||||
|
|||||||
@@ -30,16 +30,6 @@ MLX is also available on conda-forge. To install MLX with conda do:
|
|||||||
|
|
||||||
conda install conda-forge::mlx
|
conda install conda-forge::mlx
|
||||||
|
|
||||||
CUDA
|
|
||||||
^^^^
|
|
||||||
|
|
||||||
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
|
|
||||||
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
|
|
||||||
|
|
||||||
.. code-block:: shell
|
|
||||||
|
|
||||||
pip install mlx-cuda
|
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
@@ -75,8 +65,6 @@ Build Requirements
|
|||||||
Python API
|
Python API
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
||||||
.. _python install:
|
|
||||||
|
|
||||||
To build and install the MLX python library from source, first, clone MLX from
|
To build and install the MLX python library from source, first, clone MLX from
|
||||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||||
|
|
||||||
@@ -119,8 +107,6 @@ IDE:
|
|||||||
C++ API
|
C++ API
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
.. _cpp install:
|
|
||||||
|
|
||||||
Currently, MLX must be built and installed from source.
|
Currently, MLX must be built and installed from source.
|
||||||
|
|
||||||
Similarly to the python library, to build and install the MLX C++ library start
|
Similarly to the python library, to build and install the MLX C++ library start
|
||||||
@@ -199,7 +185,6 @@ should point to the path to the built metal library.
|
|||||||
|
|
||||||
xcrun -sdk macosx --show-sdk-version
|
xcrun -sdk macosx --show-sdk-version
|
||||||
|
|
||||||
|
|
||||||
Binary Size Minimization
|
Binary Size Minimization
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -228,50 +213,6 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
|
|||||||
application. Once a kernel is compiled, it will be cached by the system. The
|
application. Once a kernel is compiled, it will be cached by the system. The
|
||||||
Metal kernel cache persists across reboots.
|
Metal kernel cache persists across reboots.
|
||||||
|
|
||||||
Linux
|
|
||||||
^^^^^
|
|
||||||
|
|
||||||
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
|
||||||
For example on Ubuntu, run the following:
|
|
||||||
|
|
||||||
.. code-block:: shell
|
|
||||||
|
|
||||||
apt-get update -y
|
|
||||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
|
||||||
|
|
||||||
From here follow the instructions to install either the :ref:`Python <python
|
|
||||||
install>` or :ref:`C++ <cpp install>` APIs.
|
|
||||||
|
|
||||||
CUDA
|
|
||||||
^^^^
|
|
||||||
|
|
||||||
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
|
||||||
and the CUDA toolkit. For example on Ubuntu, run the following:
|
|
||||||
|
|
||||||
.. code-block:: shell
|
|
||||||
|
|
||||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
|
||||||
dpkg -i cuda-keyring_1.1-1_all.deb
|
|
||||||
apt-get update -y
|
|
||||||
apt-get -y install cuda-toolkit-12-9
|
|
||||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
|
||||||
|
|
||||||
|
|
||||||
When building either the Python or C++ APIs make sure to pass the cmake flag
|
|
||||||
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
|
||||||
|
|
||||||
.. code-block:: shell
|
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
|
||||||
|
|
||||||
To build the C++ package run:
|
|
||||||
|
|
||||||
.. code-block:: shell
|
|
||||||
|
|
||||||
mkdir -p build && cd build
|
|
||||||
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ void print_constant(std::ostream& os, const array& x) {
|
|||||||
return print_float_constant<float16_t>(os, x);
|
return print_float_constant<float16_t>(os, x);
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return print_float_constant<bfloat16_t>(os, x);
|
return print_float_constant<bfloat16_t>(os, x);
|
||||||
case float64:
|
|
||||||
return print_float_constant<double>(os, x);
|
|
||||||
case complex64:
|
case complex64:
|
||||||
return print_complex_constant<complex64_t>(os, x);
|
return print_complex_constant<complex64_t>(os, x);
|
||||||
case int8:
|
case int8:
|
||||||
@@ -52,8 +50,6 @@ std::string get_type_string(Dtype d) {
|
|||||||
return "float16_t";
|
return "float16_t";
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return "bfloat16_t";
|
return "bfloat16_t";
|
||||||
case float64:
|
|
||||||
return "double";
|
|
||||||
case complex64:
|
case complex64:
|
||||||
return "complex64_t";
|
return "complex64_t";
|
||||||
case bool_:
|
case bool_:
|
||||||
|
|||||||
@@ -18,12 +18,8 @@ std::string get_type_string(Dtype d);
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
void print_float_constant(std::ostream& os, const array& x) {
|
void print_float_constant(std::ostream& os, const array& x) {
|
||||||
auto old_precision = os.precision();
|
auto old_precision = os.precision();
|
||||||
if constexpr (std::is_same_v<T, double>) {
|
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
|
||||||
os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
|
<< x.item<T>() << std::setprecision(old_precision);
|
||||||
} else {
|
|
||||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
|
|
||||||
}
|
|
||||||
os << x.item<T>() << std::setprecision(old_precision);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|||||||
@@ -5,9 +5,11 @@
|
|||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||||
Shape shape,
|
const array& x,
|
||||||
Strides strides,
|
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
|
auto shape = x.shape();
|
||||||
|
auto strides = x.strides();
|
||||||
|
|
||||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||||
int a = axes[i];
|
int a = axes[i];
|
||||||
shape.erase(shape.begin() + a);
|
shape.erase(shape.begin() + a);
|
||||||
@@ -17,15 +19,6 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
|
|||||||
return std::make_pair(shape, strides);
|
return std::make_pair(shape, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
|
||||||
const array& x,
|
|
||||||
const std::vector<int>& axes) {
|
|
||||||
auto shape = x.shape();
|
|
||||||
auto strides = x.strides();
|
|
||||||
return shapes_without_reduction_axes(
|
|
||||||
std::move(shape), std::move(strides), axes);
|
|
||||||
}
|
|
||||||
|
|
||||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||||
// The data is all there and we are reducing over everything
|
// The data is all there and we are reducing over everything
|
||||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||||
|
|||||||
@@ -51,9 +51,5 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
|
|||||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||||
const array& x,
|
const array& x,
|
||||||
const std::vector<int>& axes);
|
const std::vector<int>& axes);
|
||||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
|
||||||
Shape shape,
|
|
||||||
Strides strides,
|
|
||||||
const std::vector<int>& axes);
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -199,15 +199,12 @@ Dims get_2d_grid_dims_common(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
|
||||||
throw std::runtime_error("Unable to safely factor shape.");
|
throw std::runtime_error("Unable to safely factor shape.");
|
||||||
}
|
}
|
||||||
if (grid_y > grid_x) {
|
if (grid_y > grid_x) {
|
||||||
std::swap(grid_x, grid_y);
|
std::swap(grid_x, grid_y);
|
||||||
}
|
}
|
||||||
if (divisor > 1) {
|
|
||||||
grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
|
|
||||||
}
|
|
||||||
return std::make_tuple(
|
return std::make_tuple(
|
||||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ target_sources(
|
|||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
||||||
@@ -29,12 +28,12 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
#include "mlx/backend/cuda/allocator.h"
|
#include "mlx/backend/cuda/allocator.h"
|
||||||
#include "mlx/backend/cuda/utils.h"
|
#include "mlx/backend/cuda/utils.h"
|
||||||
#include "mlx/backend/cuda/worker.h"
|
#include "mlx/backend/cuda/worker.h"
|
||||||
#include "mlx/utils.h"
|
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
@@ -15,11 +14,9 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace cu {
|
namespace cu {
|
||||||
|
|
||||||
constexpr int page_size = 16384;
|
|
||||||
|
|
||||||
CudaAllocator::CudaAllocator()
|
CudaAllocator::CudaAllocator()
|
||||||
: buffer_cache_(
|
: buffer_cache_(
|
||||||
page_size,
|
getpagesize(),
|
||||||
[](CudaBuffer* buf) { return buf->size; },
|
[](CudaBuffer* buf) { return buf->size; },
|
||||||
[this](CudaBuffer* buf) {
|
[this](CudaBuffer* buf) {
|
||||||
cuda_free(buf->data);
|
cuda_free(buf->data);
|
||||||
@@ -34,14 +31,7 @@ CudaAllocator::CudaAllocator()
|
|||||||
|
|
||||||
Buffer CudaAllocator::malloc(size_t size) {
|
Buffer CudaAllocator::malloc(size_t size) {
|
||||||
// Find available buffer from cache.
|
// Find available buffer from cache.
|
||||||
auto orig_size = size;
|
|
||||||
std::unique_lock lock(mutex_);
|
std::unique_lock lock(mutex_);
|
||||||
if (size < page_size) {
|
|
||||||
size = next_power_of_2(size);
|
|
||||||
} else {
|
|
||||||
size = page_size * ((size + page_size - 1) / page_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||||
@@ -116,6 +106,7 @@ void CudaAllocator::cuda_free(void* buf) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cudaFree(buf);
|
cudaFree(buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -125,12 +125,13 @@ constexpr bool supports_binary_op() {
|
|||||||
template <typename Op>
|
template <typename Op>
|
||||||
void binary_op_gpu_inplace(
|
void binary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
std::vector<array>& outputs,
|
||||||
std::string_view op,
|
std::string_view op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
assert(inputs.size() > 1);
|
assert(inputs.size() > 1);
|
||||||
const auto& a = inputs[0];
|
const auto& a = inputs[0];
|
||||||
const auto& b = inputs[1];
|
const auto& b = inputs[1];
|
||||||
|
auto& out = outputs[0];
|
||||||
if (out.size() == 0) {
|
if (out.size() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -145,6 +146,7 @@ void binary_op_gpu_inplace(
|
|||||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||||
using InType = cuda_type_t<CTYPE_IN>;
|
using InType = cuda_type_t<CTYPE_IN>;
|
||||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||||
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
auto bopt = get_binary_op_type(a, b);
|
||||||
if (bopt == BinaryOpType::General) {
|
if (bopt == BinaryOpType::General) {
|
||||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||||
@@ -217,6 +219,20 @@ void binary_op_gpu_inplace(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void binary_op_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs,
|
||||||
|
std::string_view op,
|
||||||
|
const Stream& s) {
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||||
|
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||||
|
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
@@ -227,7 +243,8 @@ void binary_op_gpu(
|
|||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
auto bopt = get_binary_op_type(a, b);
|
auto bopt = get_binary_op_type(a, b);
|
||||||
set_binary_op_output_data(a, b, out, bopt);
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
binary_op_gpu_inplace<Op>(inputs, out, op, s);
|
std::vector<array> outputs{out};
|
||||||
|
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define BINARY_GPU(func) \
|
#define BINARY_GPU(func) \
|
||||||
@@ -237,6 +254,14 @@ void binary_op_gpu(
|
|||||||
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define BINARY_GPU_MULTI(func) \
|
||||||
|
void func::eval_gpu( \
|
||||||
|
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||||
|
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||||
|
auto& s = outputs[0].primitive().stream(); \
|
||||||
|
binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
|
||||||
|
}
|
||||||
|
|
||||||
BINARY_GPU(Add)
|
BINARY_GPU(Add)
|
||||||
BINARY_GPU(ArcTan2)
|
BINARY_GPU(ArcTan2)
|
||||||
BINARY_GPU(Divide)
|
BINARY_GPU(Divide)
|
||||||
|
|||||||
@@ -1,248 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/common/binary.h"
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
|
||||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
|
||||||
#include "mlx/dtype_utils.h"
|
|
||||||
#include "mlx/primitives.h"
|
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
|
||||||
#include <nvtx3/nvtx3.hpp>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
namespace cu {
|
|
||||||
|
|
||||||
namespace cg = cooperative_groups;
|
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
|
||||||
__global__ void
|
|
||||||
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
|
||||||
if (index < size) {
|
|
||||||
auto out = Op{}(a[0], b[0]);
|
|
||||||
out_a[0] = out[0];
|
|
||||||
out_b[0] = out[1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
|
||||||
__global__ void
|
|
||||||
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
|
||||||
if (index < size) {
|
|
||||||
auto out = Op{}(a[0], b[index]);
|
|
||||||
out_a[index] = out[0];
|
|
||||||
out_b[index] = out[1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
|
||||||
__global__ void
|
|
||||||
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
|
||||||
if (index < size) {
|
|
||||||
auto out = Op{}(a[index], b[0]);
|
|
||||||
out_a[index] = out[0];
|
|
||||||
out_b[index] = out[1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
|
||||||
__global__ void
|
|
||||||
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
|
||||||
if (index < size) {
|
|
||||||
auto out = Op{}(a[index], b[index]);
|
|
||||||
out_a[index] = out[0];
|
|
||||||
out_b[index] = out[1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
|
||||||
__global__ void binary_g_nd(
|
|
||||||
const In* a,
|
|
||||||
const In* b,
|
|
||||||
Out* out_a,
|
|
||||||
Out* out_b,
|
|
||||||
IdxT size,
|
|
||||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
|
||||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
|
||||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
|
||||||
if (index < size) {
|
|
||||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
|
||||||
index, shape.data(), a_strides.data(), b_strides.data());
|
|
||||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
|
||||||
out_a[index] = out[0];
|
|
||||||
out_b[index] = out[1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
|
||||||
__global__ void binary_g(
|
|
||||||
const In* a,
|
|
||||||
const In* b,
|
|
||||||
Out* out_a,
|
|
||||||
Out* out_b,
|
|
||||||
IdxT size,
|
|
||||||
const __grid_constant__ Shape shape,
|
|
||||||
const __grid_constant__ Strides a_strides,
|
|
||||||
const __grid_constant__ Strides b_strides,
|
|
||||||
int ndim) {
|
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
|
||||||
if (index < size) {
|
|
||||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
|
||||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
|
||||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
|
||||||
out_a[index] = out[0];
|
|
||||||
out_b[index] = out[1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out>
|
|
||||||
constexpr bool supports_binary_op() {
|
|
||||||
if (std::is_same_v<Op, DivMod>) {
|
|
||||||
return std::is_same_v<In, Out> &&
|
|
||||||
(std::is_integral_v<Out> || is_floating_v<Out>);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace cu
|
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void binary_op_gpu_inplace(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
std::vector<array>& outputs,
|
|
||||||
std::string_view op,
|
|
||||||
const Stream& s) {
|
|
||||||
assert(inputs.size() > 1);
|
|
||||||
const auto& a = inputs[0];
|
|
||||||
const auto& b = inputs[1];
|
|
||||||
auto& out_a = outputs[0];
|
|
||||||
auto& out_b = outputs[1];
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
set_binary_op_output_data(a, b, out_a, bopt);
|
|
||||||
set_binary_op_output_data(a, b, out_b, bopt);
|
|
||||||
|
|
||||||
if (out_a.size() == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
|
||||||
encoder.set_input_array(a);
|
|
||||||
encoder.set_input_array(b);
|
|
||||||
encoder.set_output_array(out_a);
|
|
||||||
encoder.set_output_array(out_b);
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
|
||||||
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
|
|
||||||
MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, {
|
|
||||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
|
||||||
using InType = cuda_type_t<CTYPE_IN>;
|
|
||||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
|
||||||
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
if (bopt == BinaryOpType::General) {
|
|
||||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
|
|
||||||
auto& a_strides = strides[0];
|
|
||||||
auto& b_strides = strides[1];
|
|
||||||
bool large = a.data_size() > INT32_MAX ||
|
|
||||||
b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX;
|
|
||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
|
||||||
int ndim = shape.size();
|
|
||||||
if (ndim <= 3) {
|
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
|
||||||
auto kernel =
|
|
||||||
cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
|
|
||||||
auto [num_blocks, block_dims] =
|
|
||||||
get_launch_args(kernel, out_a, large);
|
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
|
||||||
a.data<InType>(),
|
|
||||||
b.data<InType>(),
|
|
||||||
out_a.data<OutType>(),
|
|
||||||
out_b.data<OutType>(),
|
|
||||||
out_a.size(),
|
|
||||||
const_param<NDIM>(shape),
|
|
||||||
const_param<NDIM>(a_strides),
|
|
||||||
const_param<NDIM>(b_strides));
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
|
||||||
auto [num_blocks, block_dims] =
|
|
||||||
get_launch_args(kernel, out_a, large);
|
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
|
||||||
a.data<InType>(),
|
|
||||||
b.data<InType>(),
|
|
||||||
out_a.data<OutType>(),
|
|
||||||
out_b.data<OutType>(),
|
|
||||||
out_a.size(),
|
|
||||||
const_param(shape),
|
|
||||||
const_param(a_strides),
|
|
||||||
const_param(b_strides),
|
|
||||||
ndim);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, {
|
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
|
||||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
|
||||||
if (bopt == BinaryOpType::ScalarVector) {
|
|
||||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
|
||||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
|
||||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
|
||||||
} else if (bopt == BinaryOpType::VectorVector) {
|
|
||||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
|
||||||
}
|
|
||||||
auto [num_blocks, block_dims] = get_launch_args(
|
|
||||||
kernel,
|
|
||||||
out_a.data_size(),
|
|
||||||
out_a.shape(),
|
|
||||||
out_a.strides(),
|
|
||||||
LARGE);
|
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
|
||||||
a.data<InType>(),
|
|
||||||
b.data<InType>(),
|
|
||||||
out_a.data<OutType>(),
|
|
||||||
out_b.data<OutType>(),
|
|
||||||
out_a.data_size());
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error(fmt::format(
|
|
||||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
|
||||||
op,
|
|
||||||
dtype_to_string(a.dtype()),
|
|
||||||
dtype_to_string(out_a.dtype())));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void binary_op_gpu(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
std::vector<array>& outputs,
|
|
||||||
std::string_view op,
|
|
||||||
const Stream& s) {
|
|
||||||
auto& a = inputs[0];
|
|
||||||
auto& b = inputs[1];
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
|
||||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
|
||||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
|
||||||
}
|
|
||||||
|
|
||||||
void DivMod::eval_gpu(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
std::vector<array>& outputs) {
|
|
||||||
nvtx3::scoped_range r("DivMod::eval_gpu");
|
|
||||||
auto& s = outputs[0].primitive().stream();
|
|
||||||
binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
@@ -24,6 +24,7 @@ void copy_gpu_inplace(
|
|||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
encoder.set_input_array(in);
|
encoder.set_input_array(in);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
|
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
|
||||||
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
|
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -63,30 +63,25 @@ void copy_general(
|
|||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
size_t data_size = 1;
|
|
||||||
for (auto& s : shape)
|
|
||||||
data_size *= s;
|
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
|
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
|
||||||
auto [num_blocks, block_dims] =
|
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||||
get_launch_args(kernel, data_size, shape, out.strides(), large);
|
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
data_size,
|
out.size(),
|
||||||
const_param<NDIM>(shape),
|
const_param<NDIM>(shape),
|
||||||
const_param<NDIM>(strides_in),
|
const_param<NDIM>(strides_in),
|
||||||
const_param<NDIM>(strides_out));
|
const_param<NDIM>(strides_out));
|
||||||
});
|
});
|
||||||
} else { // ndim >= 4
|
} else { // ndim >= 4
|
||||||
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
||||||
auto [num_blocks, block_dims] =
|
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||||
get_launch_args(kernel, data_size, shape, out.strides(), large);
|
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
data_size,
|
out.size(),
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(strides_in),
|
const_param(strides_in),
|
||||||
const_param(strides_out),
|
const_param(strides_out),
|
||||||
|
|||||||
@@ -6,7 +6,6 @@
|
|||||||
|
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <nvtx3/nvtx3.hpp>
|
#include <nvtx3/nvtx3.hpp>
|
||||||
#include <future>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@@ -108,16 +107,6 @@ void CommandEncoder::commit() {
|
|||||||
worker_.commit(stream_.last_cuda_stream());
|
worker_.commit(stream_.last_cuda_stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::synchronize() {
|
|
||||||
stream().synchronize();
|
|
||||||
auto p = std::make_shared<std::promise<void>>();
|
|
||||||
std::future<void> f = p->get_future();
|
|
||||||
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
|
||||||
worker_.end_batch();
|
|
||||||
commit();
|
|
||||||
f.wait();
|
|
||||||
}
|
|
||||||
|
|
||||||
Device& device(mlx::core::Device device) {
|
Device& device(mlx::core::Device device) {
|
||||||
static std::unordered_map<int, Device> devices;
|
static std::unordered_map<int, Device> devices;
|
||||||
auto it = devices.find(device.index);
|
auto it = devices.find(device.index);
|
||||||
|
|||||||
@@ -123,9 +123,6 @@ class CommandEncoder {
|
|||||||
return has_gpu_work_;
|
return has_gpu_work_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait until kernels and completion handlers are finished
|
|
||||||
void synchronize();
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Device& device_;
|
Device& device_;
|
||||||
DeviceStream& stream_;
|
DeviceStream& stream_;
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ struct FloorDivide {
|
|||||||
if constexpr (cuda::std::is_integral_v<T>) {
|
if constexpr (cuda::std::is_integral_v<T>) {
|
||||||
return x / y;
|
return x / y;
|
||||||
} else {
|
} else {
|
||||||
return truncf(x / y);
|
return trunc(x / y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -132,7 +132,7 @@ struct LogAddExp {
|
|||||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||||
}
|
}
|
||||||
float inf = cuda::std::numeric_limits<float>::infinity();
|
constexpr float inf = cuda::std::numeric_limits<float>::infinity();
|
||||||
auto maxval = x > y ? x : y;
|
auto maxval = x > y ? x : y;
|
||||||
auto minval = x < y ? x : y;
|
auto minval = x < y ? x : y;
|
||||||
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
|
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
// The maximum dimensions of shape/strides passed as kernel parameters.
|
// The maximum dimensions of shape/strides passed as kernel parameters.
|
||||||
#define MAX_NDIM 10
|
#define MAX_NDIM 8
|
||||||
|
|
||||||
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
||||||
// warpSize variable exists, using it would prevent compile-time optimizations.
|
// warpSize variable exists, using it would prevent compile-time optimizations.
|
||||||
|
|||||||
@@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = NDIM - 1; i >= 0; --i) {
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
a_loc += dim_idx * a_strides[i];
|
||||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
b_loc += dim_idx * b_strides[i];
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc);
|
return cuda::std::make_tuple(a_loc, b_loc);
|
||||||
@@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = NDIM - 1; i >= 0; --i) {
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
a_loc += dim_idx * a_strides[i];
|
||||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
b_loc += dim_idx * b_strides[i];
|
||||||
c_loc += dim_idx * IdxT(c_strides[i]);
|
c_loc += dim_idx * c_strides[i];
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
@@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
|||||||
IdxT b_loc = 0;
|
IdxT b_loc = 0;
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
for (int i = ndim - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
a_loc += dim_idx * a_strides[i];
|
||||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
b_loc += dim_idx * b_strides[i];
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc);
|
return cuda::std::make_tuple(a_loc, b_loc);
|
||||||
@@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
|||||||
IdxT c_loc = 0;
|
IdxT c_loc = 0;
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
for (int i = ndim - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
a_loc += dim_idx * a_strides[i];
|
||||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
b_loc += dim_idx * b_strides[i];
|
||||||
c_loc += dim_idx * IdxT(c_strides[i]);
|
c_loc += dim_idx * c_strides[i];
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ void finalize(Stream s) {
|
|||||||
|
|
||||||
void synchronize(Stream s) {
|
void synchronize(Stream s) {
|
||||||
nvtx3::scoped_range r("gpu::synchronize");
|
nvtx3::scoped_range r("gpu::synchronize");
|
||||||
cu::get_command_encoder(s).synchronize();
|
cu::get_stream(s).synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::gpu
|
} // namespace mlx::core::gpu
|
||||||
|
|||||||
@@ -37,46 +37,36 @@ void check_cu_error(const char* name, CUresult err) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the location of the CUDA toolkit.
|
// Return the location of the CUDA toolkit.
|
||||||
const std::string& cuda_home() {
|
const char* cuda_home() {
|
||||||
static std::string home = []() -> std::string {
|
const char* home = std::getenv("CUDA_HOME");
|
||||||
const char* home = std::getenv("CUDA_HOME");
|
if (home) {
|
||||||
if (home) {
|
return home;
|
||||||
return home;
|
}
|
||||||
}
|
home = std::getenv("CUDA_PATH");
|
||||||
home = std::getenv("CUDA_PATH");
|
if (home) {
|
||||||
if (home) {
|
return home;
|
||||||
return home;
|
}
|
||||||
}
|
|
||||||
#if defined(__linux__)
|
#if defined(__linux__)
|
||||||
home = "/usr/local/cuda";
|
home = "/usr/local/cuda";
|
||||||
if (std::filesystem::exists(home)) {
|
if (std::filesystem::exists(home)) {
|
||||||
return home;
|
return home;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"Environment variable CUDA_HOME or CUDA_PATH is not set.");
|
"Environment variable CUDA_HOME or CUDA_PATH is not set.");
|
||||||
}();
|
|
||||||
return home;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the cache directory for storing compiled results.
|
// Get the cache directory for storing compiled results.
|
||||||
const std::filesystem::path& ptx_cache_dir() {
|
bool get_ptx_cache_dir(std::filesystem::path* result) {
|
||||||
static std::filesystem::path cache = []() -> std::filesystem::path {
|
auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
|
||||||
std::filesystem::path cache;
|
if (!std::filesystem::is_directory(path)) {
|
||||||
if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
|
std::error_code error;
|
||||||
cache = c;
|
if (!std::filesystem::create_directories(path, error)) {
|
||||||
} else {
|
return false;
|
||||||
cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
|
|
||||||
}
|
}
|
||||||
if (!std::filesystem::exists(cache)) {
|
}
|
||||||
std::error_code error;
|
*result = path;
|
||||||
if (!std::filesystem::create_directories(cache, error)) {
|
return true;
|
||||||
return std::filesystem::path();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cache;
|
|
||||||
}();
|
|
||||||
return cache;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
|
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
|
||||||
@@ -85,10 +75,6 @@ bool read_cached_ptx(
|
|||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
std::vector<char>* ptx,
|
std::vector<char>* ptx,
|
||||||
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
|
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
|
||||||
if (cache_dir.empty()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto ptx_path = cache_dir / (module_name + ".ptx");
|
auto ptx_path = cache_dir / (module_name + ".ptx");
|
||||||
std::error_code error;
|
std::error_code error;
|
||||||
auto ptx_size = std::filesystem::file_size(ptx_path, error);
|
auto ptx_size = std::filesystem::file_size(ptx_path, error);
|
||||||
@@ -119,10 +105,6 @@ void write_cached_ptx(
|
|||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
const std::vector<char>& ptx,
|
const std::vector<char>& ptx,
|
||||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||||
if (cache_dir.empty()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
|
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
|
||||||
if (!ptx.empty()) {
|
if (!ptx.empty()) {
|
||||||
ptx_file.write(&ptx.front(), ptx.size());
|
ptx_file.write(&ptx.front(), ptx.size());
|
||||||
@@ -202,9 +184,11 @@ JitModule::JitModule(
|
|||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
const KernelBuilder& builder) {
|
const KernelBuilder& builder) {
|
||||||
// Check cache.
|
// Check cache.
|
||||||
|
std::filesystem::path cache_dir;
|
||||||
std::vector<char> ptx;
|
std::vector<char> ptx;
|
||||||
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
||||||
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
|
if (!get_ptx_cache_dir(&cache_dir) ||
|
||||||
|
!read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
|
||||||
// Create program.
|
// Create program.
|
||||||
auto [source_code, kernel_names] = builder();
|
auto [source_code, kernel_names] = builder();
|
||||||
nvrtcProgram prog;
|
nvrtcProgram prog;
|
||||||
@@ -262,7 +246,7 @@ JitModule::JitModule(
|
|||||||
} else {
|
} else {
|
||||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||||
}
|
}
|
||||||
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
|
write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load module.
|
// Load module.
|
||||||
|
|||||||
@@ -162,15 +162,11 @@ class MatMul {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void* workspace_ptr = nullptr;
|
array workspace(
|
||||||
if (heuristic_.workspaceSize > 0) {
|
allocator::malloc(heuristic_.workspaceSize),
|
||||||
array workspace(
|
{static_cast<int>(heuristic_.workspaceSize)},
|
||||||
allocator::malloc(heuristic_.workspaceSize),
|
int8);
|
||||||
{static_cast<int>(heuristic_.workspaceSize)},
|
encoder.add_temporary(workspace);
|
||||||
int8);
|
|
||||||
encoder.add_temporary(workspace);
|
|
||||||
workspace_ptr = workspace.data<void>();
|
|
||||||
}
|
|
||||||
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
||||||
@@ -187,8 +183,8 @@ class MatMul {
|
|||||||
out,
|
out,
|
||||||
out_desc_,
|
out_desc_,
|
||||||
&heuristic_.algo,
|
&heuristic_.algo,
|
||||||
workspace_ptr,
|
workspace.data<void>(),
|
||||||
heuristic_.workspaceSize,
|
workspace.nbytes(),
|
||||||
stream));
|
stream));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -362,18 +358,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
a_batch_strides.back(),
|
a_batch_strides.back(),
|
||||||
b_batch_strides.back());
|
b_batch_strides.back());
|
||||||
|
|
||||||
encoder.set_input_array(a);
|
|
||||||
encoder.set_input_array(b);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
auto nbatch = batch_count / batch_shape.back();
|
|
||||||
if (nbatch == 1) {
|
|
||||||
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
for (size_t i = 0; i < nbatch; ++i) {
|
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
||||||
matmul.run(
|
matmul.run(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||||
@@ -457,28 +444,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
b_batch_strides.back(),
|
b_batch_strides.back(),
|
||||||
c_batch_strides.back());
|
c_batch_strides.back());
|
||||||
|
|
||||||
encoder.set_input_array(a);
|
|
||||||
encoder.set_input_array(b);
|
|
||||||
encoder.set_input_array(c);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
|
|
||||||
auto nbatch = batch_count / batch_shape.back();
|
|
||||||
if (nbatch == 1) {
|
|
||||||
matmul.run(
|
|
||||||
encoder,
|
|
||||||
out.data<int8_t>(),
|
|
||||||
a.data<int8_t>(),
|
|
||||||
b.data<int8_t>(),
|
|
||||||
c.data<int8_t>(),
|
|
||||||
alpha_,
|
|
||||||
beta_);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||||
for (size_t i = 0; i < nbatch; ++i) {
|
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
||||||
matmul.run(
|
matmul.run(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||||
|
|||||||
@@ -43,17 +43,6 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool fast::ScaledDotProductAttention::use_fallback(
|
|
||||||
const array& q,
|
|
||||||
const array& k,
|
|
||||||
const array& v,
|
|
||||||
bool has_mask,
|
|
||||||
bool has_arr_mask,
|
|
||||||
bool do_causal,
|
|
||||||
Stream s) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
#define NO_GPU_MULTI(func) \
|
#define NO_GPU_MULTI(func) \
|
||||||
void func::eval_gpu( \
|
void func::eval_gpu( \
|
||||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||||
@@ -71,8 +60,10 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
|||||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NO_GPU(ArgPartition)
|
||||||
NO_GPU(BlockMaskedMM)
|
NO_GPU(BlockMaskedMM)
|
||||||
NO_GPU(Convolution)
|
NO_GPU(Convolution)
|
||||||
|
NO_GPU_MULTI(DivMod)
|
||||||
NO_GPU(DynamicSlice)
|
NO_GPU(DynamicSlice)
|
||||||
NO_GPU(DynamicSliceUpdate)
|
NO_GPU(DynamicSliceUpdate)
|
||||||
NO_GPU(FFT)
|
NO_GPU(FFT)
|
||||||
@@ -81,6 +72,7 @@ NO_GPU(GatherQMM)
|
|||||||
NO_GPU(Hadamard)
|
NO_GPU(Hadamard)
|
||||||
NO_GPU(Load)
|
NO_GPU(Load)
|
||||||
NO_GPU_MULTI(LUF)
|
NO_GPU_MULTI(LUF)
|
||||||
|
NO_GPU(Partition)
|
||||||
NO_GPU_MULTI(QRF)
|
NO_GPU_MULTI(QRF)
|
||||||
NO_GPU(QuantizedMatmul)
|
NO_GPU(QuantizedMatmul)
|
||||||
NO_GPU(Scan)
|
NO_GPU(Scan)
|
||||||
@@ -91,7 +83,6 @@ NO_GPU_MULTI(Eig)
|
|||||||
NO_GPU_MULTI(Eigh)
|
NO_GPU_MULTI(Eigh)
|
||||||
|
|
||||||
namespace fast {
|
namespace fast {
|
||||||
NO_GPU(ScaledDotProductAttention)
|
|
||||||
NO_GPU_MULTI(AffineQuantize)
|
NO_GPU_MULTI(AffineQuantize)
|
||||||
NO_GPU_MULTI(CustomKernel)
|
NO_GPU_MULTI(CustomKernel)
|
||||||
} // namespace fast
|
} // namespace fast
|
||||||
|
|||||||
@@ -21,11 +21,28 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(!axes_.empty());
|
assert(!axes_.empty());
|
||||||
assert(out.size() != in.size());
|
assert(out.size() != in.size());
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
auto& s = stream();
|
auto& s = stream();
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
// Fill out with init value.
|
||||||
if (in.size() == 0) {
|
if (in.size() == 0) {
|
||||||
init_reduce(encoder, in, out, reduce_type_);
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
|
MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
|
||||||
|
using InType = cuda_type_t<CTYPE>;
|
||||||
|
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||||
|
thrust::fill_n(
|
||||||
|
cu::thrust_policy(stream),
|
||||||
|
thrust::device_pointer_cast(out.data<OutType>()),
|
||||||
|
out.data_size(),
|
||||||
|
cu::ReduceInit<OP, InType>::value());
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -34,19 +51,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
// If it is a general reduce then copy the input to a contiguous array and
|
// If it is a general reduce then copy the input to a contiguous array and
|
||||||
// recompute the plan.
|
// recompute the plan.
|
||||||
//
|
if (plan.type == GeneralReduce) {
|
||||||
// TODO: Instead of copying we can use elem-to-loc to deal with broadcasting
|
|
||||||
// like we do in Metal. When it comes to broadcasted reduction axes
|
|
||||||
// some can be ignored eg for min/max.
|
|
||||||
bool broadcasted = false;
|
|
||||||
for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) {
|
|
||||||
if (j < axes_.size() && axes_[j] == i) {
|
|
||||||
j++;
|
|
||||||
} else {
|
|
||||||
broadcasted = in.strides(i) == 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
|
|
||||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
||||||
copy_gpu(in, in_copy, CopyType::General, s);
|
copy_gpu(in, in_copy, CopyType::General, s);
|
||||||
encoder.add_temporary(in_copy);
|
encoder.add_temporary(in_copy);
|
||||||
@@ -54,8 +59,9 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
plan = get_reduction_plan(in, axes_);
|
plan = get_reduction_plan(in, axes_);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (plan.type == ContiguousAllReduce) {
|
if ((plan.type == ContiguousAllReduce) ||
|
||||||
all_reduce(encoder, in, out, reduce_type_);
|
(plan.type == ContiguousReduce && plan.shape.size() == 1)) {
|
||||||
|
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,150 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
|
||||||
#include <cooperative_groups/reduce.h>
|
|
||||||
#include <cub/block/block_load.cuh>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
namespace cu {
|
|
||||||
|
|
||||||
namespace cg = cooperative_groups;
|
|
||||||
|
|
||||||
template <typename T, typename U, typename ReduceOp, int N = 4>
|
|
||||||
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
|
|
||||||
// TODO: Process multiple "rows" in each thread
|
|
||||||
constexpr int M = 1;
|
|
||||||
|
|
||||||
auto grid = cg::this_grid();
|
|
||||||
auto block = cg::this_thread_block();
|
|
||||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
|
||||||
|
|
||||||
const U init = cu::ReduceInit<ReduceOp, T>::value();
|
|
||||||
ReduceOp op;
|
|
||||||
|
|
||||||
T vals[N];
|
|
||||||
U accs[M];
|
|
||||||
accs[0] = init;
|
|
||||||
|
|
||||||
size_t start = grid.block_rank() * block_step;
|
|
||||||
size_t end = start + block_step;
|
|
||||||
size_t check = min(end, size);
|
|
||||||
|
|
||||||
size_t i = start;
|
|
||||||
for (; i + block.size() * N <= check; i += block.size() * N) {
|
|
||||||
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
|
|
||||||
for (int j = 0; j < N; j++) {
|
|
||||||
accs[0] = op(accs[0], __cast<U, T>(vals[j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (i < check) {
|
|
||||||
cub::LoadDirectBlocked(
|
|
||||||
block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
|
|
||||||
for (int i = 0; i < N; i++) {
|
|
||||||
accs[0] = op(accs[0], __cast<U, T>(vals[i]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__shared__ U shared_accumulators[32];
|
|
||||||
block_reduce(block, warp, accs, shared_accumulators, op, init);
|
|
||||||
|
|
||||||
if (block.thread_rank() == 0) {
|
|
||||||
out[grid.block_rank()] = accs[0];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace cu
|
|
||||||
|
|
||||||
void all_reduce(
|
|
||||||
cu::CommandEncoder& encoder,
|
|
||||||
const array& in,
|
|
||||||
array& out,
|
|
||||||
Reduce::ReduceType reduce_type) {
|
|
||||||
constexpr int N_READS = 8;
|
|
||||||
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
|
||||||
|
|
||||||
auto get_args = [](size_t size, int N) {
|
|
||||||
int threads = std::min(512UL, (size + N - 1) / N);
|
|
||||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
|
||||||
int reductions_per_step = threads * N;
|
|
||||||
size_t steps_needed =
|
|
||||||
(size + reductions_per_step - 1) / reductions_per_step;
|
|
||||||
|
|
||||||
int blocks;
|
|
||||||
if (steps_needed < 32) {
|
|
||||||
blocks = 1;
|
|
||||||
} else if (steps_needed < 128) {
|
|
||||||
blocks = 32;
|
|
||||||
} else if (steps_needed < 512) {
|
|
||||||
blocks = 128;
|
|
||||||
} else if (steps_needed < 1024) {
|
|
||||||
blocks = 512;
|
|
||||||
} else {
|
|
||||||
blocks = 1024;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
|
|
||||||
size_t block_step = steps_per_block * reductions_per_step;
|
|
||||||
|
|
||||||
return std::make_tuple(blocks, threads, block_step);
|
|
||||||
};
|
|
||||||
|
|
||||||
int blocks, threads;
|
|
||||||
size_t block_step;
|
|
||||||
size_t insize = in.size();
|
|
||||||
Dtype dt = in.dtype();
|
|
||||||
|
|
||||||
// Cub doesn't like const pointers for load (sigh).
|
|
||||||
void* indata = const_cast<void*>(in.data<void>());
|
|
||||||
|
|
||||||
// Large array so allocate an intermediate and accumulate there
|
|
||||||
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
|
|
||||||
encoder.set_input_array(in);
|
|
||||||
if (blocks > 1) {
|
|
||||||
array intermediate({blocks}, out.dtype(), nullptr, {});
|
|
||||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
|
||||||
encoder.add_temporary(intermediate);
|
|
||||||
encoder.set_output_array(intermediate);
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
|
||||||
MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
|
|
||||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
|
||||||
using T = cuda_type_t<CTYPE>;
|
|
||||||
using U = cu::ReduceResult<OP, T>::type;
|
|
||||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
|
||||||
kernel<<<blocks, threads, 0, stream>>>(
|
|
||||||
static_cast<T*>(indata),
|
|
||||||
intermediate.data<U>(),
|
|
||||||
block_step,
|
|
||||||
insize);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
// Set the input for the next step and recalculate the blocks
|
|
||||||
indata = intermediate.data<void>();
|
|
||||||
dt = intermediate.dtype();
|
|
||||||
insize = intermediate.size();
|
|
||||||
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
|
|
||||||
encoder.set_input_array(intermediate);
|
|
||||||
}
|
|
||||||
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
|
||||||
MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
|
|
||||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
|
||||||
using T = cuda_type_t<CTYPE>;
|
|
||||||
using U = cu::ReduceResult<OP, T>::type;
|
|
||||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
|
||||||
kernel<<<blocks, threads, 0, stream>>>(
|
|
||||||
static_cast<T*>(indata), out.data<U>(), block_step, insize);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
@@ -1,7 +1,5 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include <numeric>
|
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
@@ -38,36 +36,19 @@ struct ColReduceArgs {
|
|||||||
const array& in,
|
const array& in,
|
||||||
const ReductionPlan& plan,
|
const ReductionPlan& plan,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
using ShapeVector = decltype(plan.shape);
|
|
||||||
using StridesVector = decltype(plan.strides);
|
|
||||||
|
|
||||||
ShapeVector shape_vec;
|
|
||||||
StridesVector strides_vec;
|
|
||||||
|
|
||||||
assert(!plan.shape.empty());
|
assert(!plan.shape.empty());
|
||||||
reduction_size = plan.shape.back();
|
reduction_size = plan.shape.back();
|
||||||
reduction_stride = plan.strides.back();
|
reduction_stride = plan.strides.back();
|
||||||
|
|
||||||
int64_t stride_back = 1;
|
int64_t stride_back = 1;
|
||||||
std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes);
|
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||||
while (!shape_vec.empty() && stride_back < reduction_stride) {
|
while (!shape_vec.empty() && stride_back < reduction_stride) {
|
||||||
stride_back *= shape_vec.back();
|
stride_back *= shape_vec.back();
|
||||||
shape_vec.pop_back();
|
shape_vec.pop_back();
|
||||||
strides_vec.pop_back();
|
strides_vec.pop_back();
|
||||||
}
|
}
|
||||||
std::vector<int> indices(shape_vec.size());
|
|
||||||
std::iota(indices.begin(), indices.end(), 0);
|
|
||||||
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
|
|
||||||
return strides_vec[left] > strides_vec[right];
|
|
||||||
});
|
|
||||||
ShapeVector sorted_shape;
|
|
||||||
StridesVector sorted_strides;
|
|
||||||
for (auto idx : indices) {
|
|
||||||
sorted_shape.push_back(shape_vec[idx]);
|
|
||||||
sorted_strides.push_back(strides_vec[idx]);
|
|
||||||
}
|
|
||||||
std::tie(shape_vec, strides_vec) =
|
std::tie(shape_vec, strides_vec) =
|
||||||
collapse_contiguous_dims(sorted_shape, sorted_strides);
|
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||||
shape = const_param(shape_vec);
|
shape = const_param(shape_vec);
|
||||||
strides = const_param(strides_vec);
|
strides = const_param(strides_vec);
|
||||||
ndim = shape_vec.size();
|
ndim = shape_vec.size();
|
||||||
@@ -83,6 +64,86 @@ struct ColReduceArgs {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||||
|
__global__ void col_reduce_small(
|
||||||
|
const T* in,
|
||||||
|
U* out,
|
||||||
|
const __grid_constant__ ColReduceArgs args) {
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
|
||||||
|
int column =
|
||||||
|
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||||
|
if (column * N_READS >= args.reduction_stride) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||||
|
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||||
|
|
||||||
|
Op op;
|
||||||
|
U totals[N_READS];
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
totals[i] = ReduceInit<Op, T>::value();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read input to local.
|
||||||
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
|
loop.next(
|
||||||
|
block.thread_index().y,
|
||||||
|
args.reduce_shape.data(),
|
||||||
|
args.reduce_strides.data());
|
||||||
|
for (size_t r = block.thread_index().y;
|
||||||
|
r < args.non_col_reductions * args.reduction_size;
|
||||||
|
r += block.dim_threads().y) {
|
||||||
|
U vals[N_READS];
|
||||||
|
cub::LoadDirectBlocked(
|
||||||
|
column,
|
||||||
|
make_cast_iterator<U>(in + loop.location()),
|
||||||
|
vals,
|
||||||
|
args.reduction_stride,
|
||||||
|
ReduceInit<Op, T>::value());
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
totals[i] = op(vals[i], totals[i]);
|
||||||
|
}
|
||||||
|
loop.next(
|
||||||
|
block.dim_threads().y,
|
||||||
|
args.reduce_shape.data(),
|
||||||
|
args.reduce_strides.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do block reduce when each column has more than 1 element to reduce.
|
||||||
|
if (block.dim_threads().y > 1) {
|
||||||
|
__shared__ U shared_vals[32 * 8 * N_READS];
|
||||||
|
size_t col =
|
||||||
|
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
shared_vals[col * N_READS + i] = totals[i];
|
||||||
|
}
|
||||||
|
block.sync();
|
||||||
|
if (block.thread_index().y == 0) {
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
|
||||||
|
}
|
||||||
|
for (int j = 1; j < block.dim_threads().y; j++) {
|
||||||
|
col = j * block.dim_threads().x + block.thread_index().x;
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write result.
|
||||||
|
if (block.thread_index().y == 0) {
|
||||||
|
cub::StoreDirectBlocked(
|
||||||
|
column,
|
||||||
|
out + out_idx * args.reduction_stride,
|
||||||
|
totals,
|
||||||
|
args.reduction_stride);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename U,
|
typename U,
|
||||||
@@ -91,94 +152,67 @@ template <
|
|||||||
int BM,
|
int BM,
|
||||||
int BN,
|
int BN,
|
||||||
int N_READS = 4>
|
int N_READS = 4>
|
||||||
__global__ void
|
__global__ void col_reduce_looped(
|
||||||
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
const T* in,
|
||||||
|
U* out,
|
||||||
|
const __grid_constant__ ColReduceArgs args) {
|
||||||
auto grid = cg::this_grid();
|
auto grid = cg::this_grid();
|
||||||
auto block = cg::this_thread_block();
|
auto block = cg::this_thread_block();
|
||||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
|
|
||||||
constexpr int threads_per_row = BN / N_READS;
|
constexpr int n_warps = BN / N_READS;
|
||||||
|
|
||||||
// Compute the indices for the tile
|
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||||
size_t tile_idx = grid.block_rank();
|
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||||
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
|
|
||||||
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
|
|
||||||
|
|
||||||
// Compute the indices for the thread within the tile
|
|
||||||
short thread_x = block.thread_rank() % threads_per_row;
|
|
||||||
short thread_y = block.thread_rank() / threads_per_row;
|
|
||||||
|
|
||||||
// Move the input pointer
|
|
||||||
in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) +
|
|
||||||
tile_x * BN;
|
|
||||||
|
|
||||||
// Initialize the running totals
|
|
||||||
Op op;
|
Op op;
|
||||||
U totals[N_READS];
|
U totals[N_READS];
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
totals[i] = ReduceInit<Op, T>::value();
|
totals[i] = ReduceInit<Op, T>::value();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read input to local.
|
||||||
|
int r = block.thread_rank() / n_warps;
|
||||||
|
int column = block.thread_rank() % n_warps;
|
||||||
|
int in_offset = grid.block_index().x * BN;
|
||||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
|
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
size_t total = args.non_col_reductions * args.reduction_size;
|
for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
|
||||||
if (tile_x * BN + BN <= args.reduction_stride) {
|
U vals[N_READS];
|
||||||
if (args.reduction_stride % N_READS == 0) {
|
cub::LoadDirectBlocked(
|
||||||
for (size_t r = thread_y; r < total; r += BM) {
|
column,
|
||||||
T vals[N_READS];
|
make_cast_iterator<U>(in + loop.location() + in_offset),
|
||||||
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
|
vals,
|
||||||
for (int i = 0; i < N_READS; i++) {
|
args.reduction_stride - in_offset,
|
||||||
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
|
ReduceInit<Op, T>::value());
|
||||||
}
|
for (int i = 0; i < N_READS; i++) {
|
||||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
totals[i] = op(vals[i], totals[i]);
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (size_t r = thread_y; r < total; r += BM) {
|
|
||||||
T vals[N_READS];
|
|
||||||
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
|
|
||||||
for (int i = 0; i < N_READS; i++) {
|
|
||||||
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
|
|
||||||
}
|
|
||||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (size_t r = thread_y; r < total; r += BM) {
|
|
||||||
T vals[N_READS];
|
|
||||||
cub::LoadDirectBlocked(
|
|
||||||
thread_x,
|
|
||||||
in + loop.location(),
|
|
||||||
vals,
|
|
||||||
args.reduction_stride - tile_x * BN,
|
|
||||||
__cast<T, U>(ReduceInit<Op, T>::value()));
|
|
||||||
for (int i = 0; i < N_READS; i++) {
|
|
||||||
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
|
|
||||||
}
|
|
||||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
|
||||||
}
|
}
|
||||||
|
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do warp reduce for each output.
|
// Do warp reduce for each output.
|
||||||
constexpr int n_outputs = BN / threads_per_row;
|
constexpr int n_outputs = BN / n_warps;
|
||||||
static_assert(BM == 32 && n_outputs == N_READS);
|
static_assert(BM == 32 && n_outputs == N_READS);
|
||||||
__shared__ U shared_vals[BM * BN];
|
__shared__ U shared_vals[BM * BN];
|
||||||
short s_idx = thread_y * BN + thread_x * N_READS;
|
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
shared_vals[s_idx + i] = totals[i];
|
shared_vals[col + i] = totals[i];
|
||||||
}
|
}
|
||||||
block.sync();
|
block.sync();
|
||||||
s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
|
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
|
||||||
for (int i = 0; i < n_outputs; i++) {
|
for (int i = 0; i < n_outputs; i++) {
|
||||||
totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op);
|
totals[i] = cg::reduce(warp, shared_vals[col + i], op);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write result.
|
// Write result.
|
||||||
if (warp.thread_rank() == 0) {
|
if (warp.thread_rank() == 0) {
|
||||||
|
size_t out_offset = grid.block_index().x * BN;
|
||||||
cub::StoreDirectBlocked(
|
cub::StoreDirectBlocked(
|
||||||
warp.meta_group_rank(),
|
warp.meta_group_rank(),
|
||||||
out + tile_y * args.reduction_stride + tile_x * BN,
|
out + out_idx * args.reduction_stride + out_offset,
|
||||||
totals,
|
totals,
|
||||||
args.reduction_stride - tile_x * BN);
|
args.reduction_stride - out_offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -186,55 +220,14 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
|
|
||||||
inline auto output_grid_for_col_reduce(
|
inline auto output_grid_for_col_reduce(
|
||||||
const array& out,
|
const array& out,
|
||||||
const cu::ColReduceArgs& args,
|
const cu::ColReduceArgs& args) {
|
||||||
int bn) {
|
auto out_shape = out.shape();
|
||||||
int gx, gy = 1;
|
auto out_strides = out.strides();
|
||||||
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
|
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
|
||||||
size_t n_outer_blocks = out.size() / args.reduction_stride;
|
out_shape.pop_back();
|
||||||
size_t n_blocks = n_outer_blocks * n_inner_blocks;
|
out_strides.pop_back();
|
||||||
while (n_blocks / gy > INT32_MAX) {
|
|
||||||
gy *= 2;
|
|
||||||
}
|
}
|
||||||
gx = cuda::ceil_div(n_blocks, gy);
|
return get_2d_grid_dims(out_shape, out_strides);
|
||||||
|
|
||||||
return dim3(gx, gy, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
void col_reduce_looped(
|
|
||||||
cu::CommandEncoder& encoder,
|
|
||||||
const array& in,
|
|
||||||
array& out,
|
|
||||||
Reduce::ReduceType reduce_type,
|
|
||||||
const std::vector<int>& axes,
|
|
||||||
const ReductionPlan& plan,
|
|
||||||
cu::ColReduceArgs args) {
|
|
||||||
// Allocate data for the output using in's layout to access them as
|
|
||||||
// contiguously as possible.
|
|
||||||
allocate_same_layout(out, in, axes);
|
|
||||||
|
|
||||||
encoder.set_input_array(in);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
|
||||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
|
||||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
|
||||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
|
||||||
using T = cuda_type_t<CTYPE>;
|
|
||||||
using U = cu::ReduceResult<OP, T>::type;
|
|
||||||
|
|
||||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
|
||||||
T* indata = const_cast<T*>(in.data<T>());
|
|
||||||
|
|
||||||
constexpr int N_READS = 4;
|
|
||||||
constexpr int BM = 32;
|
|
||||||
constexpr int BN = 32;
|
|
||||||
dim3 grid = output_grid_for_col_reduce(out, args, BN);
|
|
||||||
int blocks = BM * BN / N_READS;
|
|
||||||
auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
|
|
||||||
kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void col_reduce(
|
void col_reduce(
|
||||||
@@ -244,23 +237,42 @@ void col_reduce(
|
|||||||
Reduce::ReduceType reduce_type,
|
Reduce::ReduceType reduce_type,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const ReductionPlan& plan) {
|
const ReductionPlan& plan) {
|
||||||
// Current col reduce options
|
|
||||||
//
|
|
||||||
// - col_reduce_looped
|
|
||||||
//
|
|
||||||
// It is a general strided reduce. Each threadblock computes the output for
|
|
||||||
// a subrow of the fast moving axis. For instance 32 elements.
|
|
||||||
//
|
|
||||||
// Notes: As in row reduce we opt to read as much in order as possible and
|
|
||||||
// leave transpositions as they are (contrary to our Metal backend).
|
|
||||||
//
|
|
||||||
// Moreover we need different kernels for short rows and tuning
|
|
||||||
|
|
||||||
// Make the args struct to help route to the best kernel
|
|
||||||
cu::ColReduceArgs args(in, plan, axes);
|
cu::ColReduceArgs args(in, plan, axes);
|
||||||
|
|
||||||
// Fallback col reduce
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
|
using InType = cuda_type_t<CTYPE>;
|
||||||
|
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||||
|
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||||
|
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||||
|
constexpr int N_READS = 4;
|
||||||
|
dim3 block_dims;
|
||||||
|
dim3 num_blocks = output_grid_for_col_reduce(out, args);
|
||||||
|
num_blocks.z = num_blocks.y;
|
||||||
|
num_blocks.y = num_blocks.x;
|
||||||
|
auto kernel =
|
||||||
|
cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||||
|
size_t total = args.non_col_reductions * args.reduction_size;
|
||||||
|
if (total < 32) {
|
||||||
|
size_t stride_blocks =
|
||||||
|
cuda::ceil_div(args.reduction_stride, N_READS);
|
||||||
|
block_dims.x = std::min(stride_blocks, 32ul);
|
||||||
|
block_dims.y = std::min(total, 8ul);
|
||||||
|
num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
|
||||||
|
} else {
|
||||||
|
constexpr int BM = 32;
|
||||||
|
constexpr int BN = 32;
|
||||||
|
block_dims.x = BM * BN / N_READS;
|
||||||
|
num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
|
||||||
|
kernel = cu::
|
||||||
|
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
|
||||||
|
}
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
in.data<InType>(), out.data<OutType>(), args);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -1,50 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
namespace cu {
|
|
||||||
|
|
||||||
namespace cg = cooperative_groups;
|
|
||||||
|
|
||||||
template <typename T, typename U, typename Op>
|
|
||||||
__global__ void init_reduce(U* out, size_t size) {
|
|
||||||
auto index = cg::this_grid().thread_rank();
|
|
||||||
if (index < size) {
|
|
||||||
out[index] = ReduceInit<Op, T>::value();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace cu
|
|
||||||
|
|
||||||
void init_reduce(
|
|
||||||
cu::CommandEncoder& encoder,
|
|
||||||
const array& in,
|
|
||||||
array& out,
|
|
||||||
Reduce::ReduceType reduce_type) {
|
|
||||||
// Allocate if needed
|
|
||||||
if (out.data_shared_ptr() == nullptr) {
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
|
||||||
}
|
|
||||||
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
|
||||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
|
||||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
|
||||||
using T = cuda_type_t<CTYPE>;
|
|
||||||
using U = cu::ReduceResult<OP, T>::type;
|
|
||||||
auto kernel = cu::init_reduce<T, U, OP>;
|
|
||||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
|
||||||
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
|
||||||
grid.x = (grid.x + 1023) / 1024;
|
|
||||||
kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
@@ -47,11 +47,13 @@ namespace mlx::core {
|
|||||||
throw std::invalid_argument("Unknown reduce type."); \
|
throw std::invalid_argument("Unknown reduce type."); \
|
||||||
}
|
}
|
||||||
|
|
||||||
void all_reduce(
|
void segmented_reduce(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
const array& in,
|
const array& in,
|
||||||
array& out,
|
array& out,
|
||||||
Reduce::ReduceType reduce_type);
|
Reduce::ReduceType reduce_type,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const ReductionPlan& plan);
|
||||||
|
|
||||||
void row_reduce(
|
void row_reduce(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
@@ -69,10 +71,4 @@ void col_reduce(
|
|||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const ReductionPlan& plan);
|
const ReductionPlan& plan);
|
||||||
|
|
||||||
void init_reduce(
|
|
||||||
cu::CommandEncoder& encoder,
|
|
||||||
const array& in,
|
|
||||||
array& out,
|
|
||||||
Reduce::ReduceType reduce_type);
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -3,89 +3,48 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce_utils.cuh"
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
// Reduce ops.
|
// Reduce ops.
|
||||||
struct And {
|
struct And {
|
||||||
__device__ __forceinline__ bool operator()(bool a, bool b) {
|
__device__ bool operator()(bool a, bool b) {
|
||||||
return a && b;
|
return a && b;
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ void atomic_update(bool* x, bool y) {
|
|
||||||
atomic_reduce<bool, And>(x, y);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Or {
|
struct Or {
|
||||||
__device__ __forceinline__ bool operator()(bool a, bool b) {
|
__device__ bool operator()(bool a, bool b) {
|
||||||
return a || b;
|
return a || b;
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ void atomic_update(bool* x, bool y) {
|
|
||||||
atomic_reduce<bool, Or>(x, y);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Sum {
|
struct Sum {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ __forceinline__ T operator()(T a, T b) {
|
__device__ T operator()(T a, T b) {
|
||||||
return a + b;
|
return a + b;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__device__ void atomic_update(T* x, T y) {
|
|
||||||
atomic_reduce<T, Sum>(x, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
|
|
||||||
atomicAdd(x, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ void atomic_update(int* x, int y) {
|
|
||||||
atomicAdd(x, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ void atomic_update(float* x, float y) {
|
|
||||||
atomicAdd(x, y);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Prod {
|
struct Prod {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ __forceinline__ T operator()(T a, T b) {
|
__device__ T operator()(T a, T b) {
|
||||||
return a * b;
|
return a * b;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__device__ void atomic_update(T* x, T y) {
|
|
||||||
atomic_reduce<T, Prod>(x, y);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Min {
|
struct Min {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ __forceinline__ T operator()(T a, T b) {
|
__device__ T operator()(T a, T b) {
|
||||||
return a < b ? a : b;
|
return a < b ? a : b;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__device__ void atomic_update(T* x, T y) {
|
|
||||||
atomic_reduce<T, Min>(x, y);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Max {
|
struct Max {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ __forceinline__ T operator()(T a, T b) {
|
__device__ T operator()(T a, T b) {
|
||||||
return a > b ? a : b;
|
return a > b ? a : b;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__device__ void atomic_update(T* x, T y) {
|
|
||||||
atomic_reduce<T, Max>(x, y);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Traits to get the result type of reduce op.
|
// Traits to get the result type of reduce op.
|
||||||
@@ -161,7 +120,7 @@ template <typename T>
|
|||||||
struct ReduceInit<Prod, T> {
|
struct ReduceInit<Prod, T> {
|
||||||
static constexpr __host__ __device__ auto value() {
|
static constexpr __host__ __device__ auto value() {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
return T{1, 0};
|
return T{1, 1};
|
||||||
} else {
|
} else {
|
||||||
return typename ReduceResult<Prod, T>::type{1};
|
return typename ReduceResult<Prod, T>::type{1};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,158 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <numeric>
|
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
|
||||||
#include <cooperative_groups/reduce.h>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
namespace cu {
|
|
||||||
|
|
||||||
namespace cg = cooperative_groups;
|
|
||||||
|
|
||||||
template <size_t N>
|
|
||||||
struct uint_by_size;
|
|
||||||
template <>
|
|
||||||
struct uint_by_size<2> {
|
|
||||||
using type = uint16_t;
|
|
||||||
};
|
|
||||||
template <>
|
|
||||||
struct uint_by_size<4> {
|
|
||||||
using type = uint32_t;
|
|
||||||
};
|
|
||||||
template <>
|
|
||||||
struct uint_by_size<8> {
|
|
||||||
using type = unsigned long long int;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T, typename Op>
|
|
||||||
__device__ void atomic_reduce(T* x, T y) {
|
|
||||||
if constexpr (sizeof(T) == 1) {
|
|
||||||
using U = uint16_t;
|
|
||||||
U* x_int = (U*)((char*)x - ((size_t)x % 2));
|
|
||||||
int shift = ((char*)x - (char*)x_int) * 8;
|
|
||||||
int mask = 0xff << shift;
|
|
||||||
U old_val, new_val;
|
|
||||||
do {
|
|
||||||
old_val = *x_int;
|
|
||||||
T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
|
|
||||||
new_val = (old_val & ~mask) | (result << shift);
|
|
||||||
} while (atomicCAS(x_int, old_val, new_val) != old_val);
|
|
||||||
} else {
|
|
||||||
using U = typename uint_by_size<sizeof(T)>::type;
|
|
||||||
U* x_int = (U*)(x);
|
|
||||||
U old_val, new_val;
|
|
||||||
do {
|
|
||||||
old_val = *x_int;
|
|
||||||
T result = Op{}(*((T*)&old_val), y);
|
|
||||||
new_val = *((U*)&result);
|
|
||||||
} while (atomicCAS(x_int, old_val, new_val) != old_val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Should make a custom complex type
|
|
||||||
template <typename U, typename T>
|
|
||||||
inline __device__ U __cast(T x) {
|
|
||||||
return static_cast<U>(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
|
|
||||||
return x.x != 0 && x.y != 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
|
|
||||||
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, int N, typename Block, typename Warp, typename Op>
|
|
||||||
inline __device__ void
|
|
||||||
block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
|
|
||||||
// First reduce in the current warp
|
|
||||||
for (int i = 0; i < N; i++) {
|
|
||||||
vals[i] = cg::reduce(warp, vals[i], op);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reduce across warps
|
|
||||||
if (warp.meta_group_size() > 1) {
|
|
||||||
if (warp.thread_rank() == 0) {
|
|
||||||
for (int i = 0; i < N; i++) {
|
|
||||||
smem[warp.meta_group_rank() * N + i] = vals[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
block.sync();
|
|
||||||
if (warp.thread_rank() < warp.meta_group_size()) {
|
|
||||||
for (int i = 0; i < N; i++) {
|
|
||||||
vals[i] = smem[warp.thread_rank() * N + i];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int i = 0; i < N; i++) {
|
|
||||||
vals[i] = init;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (int i = 0; i < N; i++) {
|
|
||||||
vals[i] = cg::reduce(warp, vals[i], op);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace cu
|
|
||||||
|
|
||||||
inline void allocate_same_layout(
|
|
||||||
array& out,
|
|
||||||
const array& in,
|
|
||||||
const std::vector<int>& axes) {
|
|
||||||
if (in.flags().row_contiguous) {
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (out.ndim() < in.ndim()) {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"Reduction without keepdims only supported for row-contiguous inputs");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the transpositions applied to in in order to apply them to out.
|
|
||||||
std::vector<int> axis_order(in.ndim());
|
|
||||||
std::iota(axis_order.begin(), axis_order.end(), 0);
|
|
||||||
std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) {
|
|
||||||
return in.strides(left) > in.strides(right);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Transpose the shape and calculate the strides
|
|
||||||
Shape out_shape(in.ndim());
|
|
||||||
Strides out_strides(in.ndim(), 1);
|
|
||||||
for (int i = 0; i < in.ndim(); i++) {
|
|
||||||
out_shape[i] = out.shape(axis_order[i]);
|
|
||||||
}
|
|
||||||
for (int i = in.ndim() - 2; i >= 0; i--) {
|
|
||||||
out_strides[i] = out_shape[i + 1] * out_strides[i + 1];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reverse the axis order to get the final strides
|
|
||||||
Strides final_strides(in.ndim());
|
|
||||||
for (int i = 0; i < in.ndim(); i++) {
|
|
||||||
final_strides[axis_order[i]] = out_strides[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the resulting contiguity and do the memory allocation
|
|
||||||
auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides);
|
|
||||||
auto fl = in.flags();
|
|
||||||
fl.row_contiguous = rc;
|
|
||||||
fl.col_contiguous = cc;
|
|
||||||
fl.contiguous = true;
|
|
||||||
out.set_data(
|
|
||||||
allocator::malloc(out.nbytes()),
|
|
||||||
data_size,
|
|
||||||
final_strides,
|
|
||||||
fl,
|
|
||||||
allocator::free);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
@@ -1,7 +1,5 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include <numeric>
|
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
@@ -57,108 +55,84 @@ struct RowReduceArgs {
|
|||||||
non_row_reductions *= reduce_shape[i];
|
non_row_reductions *= reduce_shape[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert shape and strides as if in was contiguous
|
|
||||||
void sort_access_pattern(const array& in, const std::vector<int>& axes) {
|
|
||||||
auto shape_vec = in.shape();
|
|
||||||
auto strides_vec = in.strides();
|
|
||||||
std::tie(shape_vec, strides_vec) =
|
|
||||||
shapes_without_reduction_axes(shape_vec, strides_vec, axes);
|
|
||||||
std::vector<int> indices(shape_vec.size());
|
|
||||||
std::iota(indices.begin(), indices.end(), 0);
|
|
||||||
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
|
|
||||||
return strides_vec[left] > strides_vec[right];
|
|
||||||
});
|
|
||||||
decltype(shape_vec) sorted_shape;
|
|
||||||
decltype(strides_vec) sorted_strides;
|
|
||||||
for (auto idx : indices) {
|
|
||||||
sorted_shape.push_back(shape_vec[idx]);
|
|
||||||
sorted_strides.push_back(strides_vec[idx]);
|
|
||||||
}
|
|
||||||
std::tie(shape_vec, strides_vec) =
|
|
||||||
collapse_contiguous_dims(sorted_shape, sorted_strides);
|
|
||||||
shape = const_param(shape_vec);
|
|
||||||
strides = const_param(strides_vec);
|
|
||||||
ndim = shape_vec.size();
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
|
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||||
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
__global__ void row_reduce_small(
|
||||||
|
const T* in,
|
||||||
|
U* out,
|
||||||
|
size_t out_size,
|
||||||
|
const __grid_constant__ RowReduceArgs args) {
|
||||||
|
size_t out_idx = cg::this_grid().thread_rank();
|
||||||
|
if (out_idx >= out_size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Op op;
|
||||||
|
|
||||||
|
U total_val = ReduceInit<Op, T>::value();
|
||||||
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
|
|
||||||
|
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||||
|
|
||||||
|
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||||
|
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||||
|
U vals[N_READS];
|
||||||
|
cub::LoadDirectBlocked(
|
||||||
|
r,
|
||||||
|
make_cast_iterator<U>(in + loop.location()),
|
||||||
|
vals,
|
||||||
|
args.row_size,
|
||||||
|
ReduceInit<Op, T>::value());
|
||||||
|
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||||
|
}
|
||||||
|
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
out[out_idx] = total_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||||
|
__global__ void row_reduce_small_warp(
|
||||||
|
const T* in,
|
||||||
|
U* out,
|
||||||
|
size_t out_size,
|
||||||
|
const __grid_constant__ RowReduceArgs args) {
|
||||||
auto grid = cg::this_grid();
|
auto grid = cg::this_grid();
|
||||||
auto block = cg::this_thread_block();
|
auto block = cg::this_thread_block();
|
||||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
|
|
||||||
const U init = cu::ReduceInit<ReduceOp, T>::value();
|
size_t out_idx = grid.thread_rank() / WARP_SIZE;
|
||||||
ReduceOp op;
|
if (out_idx >= out_size) {
|
||||||
|
return;
|
||||||
T vals[M][N];
|
|
||||||
U accs[M];
|
|
||||||
for (int i = 0; i < M; i++) {
|
|
||||||
accs[i] = init;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t start_row =
|
Op op;
|
||||||
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
|
|
||||||
const size_t full_blocks = size / (block.size() * N);
|
|
||||||
const size_t final_offset = full_blocks * (block.size() * N);
|
|
||||||
in += start_row * size;
|
|
||||||
out += start_row;
|
|
||||||
|
|
||||||
if (size % N == 0) {
|
U total_val = ReduceInit<Op, T>::value();
|
||||||
for (size_t r = 0; r < full_blocks; r++) {
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
for (int k = 0; k < M; k++) {
|
|
||||||
cub::LoadDirectBlockedVectorized<T, N>(
|
|
||||||
block.thread_rank(),
|
|
||||||
in + k * size + r * (block.size() * N),
|
|
||||||
vals[k]);
|
|
||||||
for (int j = 0; j < N; j++) {
|
|
||||||
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (size_t r = 0; r < full_blocks; r++) {
|
|
||||||
for (int k = 0; k < M; k++) {
|
|
||||||
cub::LoadDirectBlocked(
|
|
||||||
block.thread_rank(),
|
|
||||||
in + k * size + r * (block.size() * N),
|
|
||||||
vals[k]);
|
|
||||||
for (int j = 0; j < N; j++) {
|
|
||||||
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (final_offset < size) {
|
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||||
for (int k = 0; k < M; k++) {
|
|
||||||
|
for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
|
||||||
|
n += WARP_SIZE) {
|
||||||
|
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||||
|
U vals[N_READS];
|
||||||
cub::LoadDirectBlocked(
|
cub::LoadDirectBlocked(
|
||||||
block.thread_rank(),
|
r,
|
||||||
in + k * size + final_offset,
|
make_cast_iterator<U>(in + loop.location()),
|
||||||
vals[k],
|
vals,
|
||||||
size,
|
args.row_size,
|
||||||
__cast<T, U>(init));
|
ReduceInit<Op, T>::value());
|
||||||
for (int j = 0; j < N; j++) {
|
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||||
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
__shared__ U shared_accumulators[32 * M];
|
total_val = cg::reduce(warp, total_val, op);
|
||||||
block_reduce(block, warp, accs, shared_accumulators, op, init);
|
|
||||||
|
|
||||||
if (block.thread_rank() == 0) {
|
if (warp.thread_rank() == 0) {
|
||||||
if (grid.block_rank() * M + M <= n_rows) {
|
out[out_idx] = total_val;
|
||||||
for (int i = 0; i < M; i++) {
|
|
||||||
out[i] = accs[i];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
short offset = grid.block_rank() * M + M - n_rows;
|
|
||||||
for (int i = offset; i < M; i++) {
|
|
||||||
out[i] = accs[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,165 +141,55 @@ template <
|
|||||||
typename U,
|
typename U,
|
||||||
typename Op,
|
typename Op,
|
||||||
int NDIM,
|
int NDIM,
|
||||||
int BLOCK_DIM,
|
int BLOCK_DIM_X,
|
||||||
int N_READS = 4>
|
int N_READS = 4>
|
||||||
__global__ void row_reduce_looped(
|
__global__ void row_reduce_looped(
|
||||||
T* in,
|
const T* in,
|
||||||
U* out,
|
U* out,
|
||||||
size_t out_size,
|
size_t out_size,
|
||||||
const __grid_constant__ RowReduceArgs args) {
|
const __grid_constant__ RowReduceArgs args) {
|
||||||
auto grid = cg::this_grid();
|
auto grid = cg::this_grid();
|
||||||
auto block = cg::this_thread_block();
|
auto block = cg::this_thread_block();
|
||||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
|
||||||
|
|
||||||
size_t out_idx = grid.block_rank();
|
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
|
||||||
|
if (out_idx >= out_size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
|
|
||||||
U total[1];
|
U total_val = ReduceInit<Op, T>::value();
|
||||||
U init = ReduceInit<Op, T>::value();
|
|
||||||
total[0] = init;
|
|
||||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
|
|
||||||
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
|
|
||||||
|
|
||||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||||
|
|
||||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||||
for (size_t r = 0; r < full_blocks; r++) {
|
for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
|
||||||
T vals[N_READS];
|
r++) {
|
||||||
cub::LoadDirectBlockedVectorized<T, N_READS>(
|
U vals[N_READS];
|
||||||
block.thread_rank(),
|
|
||||||
in + loop.location() + r * BLOCK_DIM * N_READS,
|
|
||||||
vals);
|
|
||||||
for (int i = 0; i < N_READS; i++) {
|
|
||||||
total[0] = op(total[0], __cast<U, T>(vals[i]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (final_offset < args.row_size) {
|
|
||||||
T vals[N_READS];
|
|
||||||
cub::LoadDirectBlocked(
|
cub::LoadDirectBlocked(
|
||||||
block.thread_rank(),
|
r * BLOCK_DIM_X + block.thread_index().x,
|
||||||
in + loop.location() + final_offset,
|
make_cast_iterator<U>(in + loop.location()),
|
||||||
vals,
|
vals,
|
||||||
args.row_size - final_offset,
|
args.row_size,
|
||||||
__cast<T, U>(init));
|
ReduceInit<Op, T>::value());
|
||||||
for (int i = 0; i < N_READS; i++) {
|
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||||
total[0] = op(total[0], __cast<U, T>(vals[i]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// TODO: Maybe block.sync() here?
|
|
||||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
__shared__ U shared_accumulators[32];
|
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
|
||||||
block_reduce(block, warp, total, shared_accumulators, op, init);
|
__shared__ typename BlockReduceT::TempStorage temp;
|
||||||
|
|
||||||
|
total_val = BlockReduceT(temp).Reduce(total_val, op);
|
||||||
|
|
||||||
if (block.thread_rank() == 0) {
|
if (block.thread_rank() == 0) {
|
||||||
out[out_idx] = total[0];
|
out[out_idx] = total_val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace cu
|
||||||
|
|
||||||
void row_reduce_simple(
|
|
||||||
cu::CommandEncoder& encoder,
|
|
||||||
const array& in,
|
|
||||||
array& out,
|
|
||||||
Reduce::ReduceType reduce_type,
|
|
||||||
const std::vector<int>& axes,
|
|
||||||
const ReductionPlan& plan) {
|
|
||||||
constexpr int N_READS = 8;
|
|
||||||
|
|
||||||
// Allocate data for the output using in's layout to avoid elem_to_loc in the
|
|
||||||
// kernel.
|
|
||||||
allocate_same_layout(out, in, axes);
|
|
||||||
|
|
||||||
// TODO: If out.size() < 1024 which will be a common case then write this in
|
|
||||||
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
|
|
||||||
encoder.set_input_array(in);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
|
||||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
|
||||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
|
||||||
using T = cuda_type_t<CTYPE>;
|
|
||||||
using U = cu::ReduceResult<OP, T>::type;
|
|
||||||
|
|
||||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
|
||||||
T* indata = const_cast<T*>(in.data<T>());
|
|
||||||
|
|
||||||
// Calculate the grid and block dims
|
|
||||||
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
|
|
||||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
|
||||||
int threads = std::min(1024UL, reductions);
|
|
||||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
|
||||||
dim3 block(threads, 1, 1);
|
|
||||||
|
|
||||||
// Pick the kernel
|
|
||||||
auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
|
|
||||||
if (grid.x >= 1024) {
|
|
||||||
grid.x = (grid.x + 1) / 2;
|
|
||||||
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Launch
|
|
||||||
kernel<<<grid, block, 0, stream>>>(
|
|
||||||
indata, out.data<U>(), out.size(), plan.shape.back());
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
void row_reduce_looped(
|
|
||||||
cu::CommandEncoder& encoder,
|
|
||||||
const array& in,
|
|
||||||
array& out,
|
|
||||||
Reduce::ReduceType reduce_type,
|
|
||||||
const std::vector<int>& axes,
|
|
||||||
const ReductionPlan& plan,
|
|
||||||
cu::RowReduceArgs args) {
|
|
||||||
constexpr int N_READS = 8;
|
|
||||||
|
|
||||||
// Allocate data for the output using in's layout to access them as
|
|
||||||
// contiguously as possible.
|
|
||||||
allocate_same_layout(out, in, axes);
|
|
||||||
|
|
||||||
encoder.set_input_array(in);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
|
||||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
|
||||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
|
||||||
using T = cuda_type_t<CTYPE>;
|
|
||||||
using U = cu::ReduceResult<OP, T>::type;
|
|
||||||
|
|
||||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
|
||||||
T* indata = const_cast<T*>(in.data<T>());
|
|
||||||
|
|
||||||
// Calculate the grid and block dims
|
|
||||||
args.sort_access_pattern(in, axes);
|
|
||||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
|
||||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
|
||||||
int threads = std::min(1024UL, reductions);
|
|
||||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
|
||||||
dim3 block(threads, 1, 1);
|
|
||||||
|
|
||||||
// Pick the kernel
|
|
||||||
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
|
|
||||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
|
||||||
MLX_SWITCH_BLOCK_DIM(threads, THREADS, {
|
|
||||||
kernel = cu::row_reduce_looped<T, U, OP, NDIM, THREADS, N_READS>;
|
|
||||||
block.x = THREADS;
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
// Launch
|
|
||||||
kernel<<<grid, block, 0, stream>>>(
|
|
||||||
indata, out.data<U>(), out.size(), args);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
void row_reduce(
|
void row_reduce(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
const array& in,
|
const array& in,
|
||||||
@@ -333,35 +197,54 @@ void row_reduce(
|
|||||||
Reduce::ReduceType reduce_type,
|
Reduce::ReduceType reduce_type,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const ReductionPlan& plan) {
|
const ReductionPlan& plan) {
|
||||||
// Current row reduction options
|
|
||||||
//
|
|
||||||
// - row_reduce_simple
|
|
||||||
//
|
|
||||||
// That means that we are simply reducing across the fastest moving axis.
|
|
||||||
// We are reducing 1 or 2 rows per threadblock depending on the size of
|
|
||||||
// output.
|
|
||||||
//
|
|
||||||
// - row_reduce_looped
|
|
||||||
//
|
|
||||||
// It is a general row reduction. We are computing 1 output per
|
|
||||||
// threadblock. We read the fastest moving axis vectorized and loop over
|
|
||||||
// the rest of the axes.
|
|
||||||
//
|
|
||||||
// Notes: We opt to read as much in order as possible and leave
|
|
||||||
// transpositions as they are (contrary to our Metal backend).
|
|
||||||
|
|
||||||
// Simple row reduce means that we have 1 axis that we are reducing over and
|
|
||||||
// it has stride 1.
|
|
||||||
if (plan.shape.size() == 1) {
|
|
||||||
row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make the args struct to help route to the best kernel
|
|
||||||
cu::RowReduceArgs args(in, plan, axes);
|
cu::RowReduceArgs args(in, plan, axes);
|
||||||
|
|
||||||
// Fallback row reduce
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
|
using InType = cuda_type_t<CTYPE>;
|
||||||
|
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||||
|
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||||
|
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||||
|
constexpr size_t N_READS = 4;
|
||||||
|
dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
|
dim3 block_dims, num_blocks;
|
||||||
|
auto kernel =
|
||||||
|
cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||||
|
if (args.row_size <= 64) {
|
||||||
|
if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
|
||||||
|
(args.non_row_reductions <= 8)) {
|
||||||
|
block_dims.x = std::min(out_dims.x, 1024u);
|
||||||
|
num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
|
||||||
|
num_blocks.y = out_dims.y;
|
||||||
|
} else {
|
||||||
|
block_dims.x = WARP_SIZE;
|
||||||
|
num_blocks.y = out_dims.x;
|
||||||
|
num_blocks.z = out_dims.y;
|
||||||
|
kernel =
|
||||||
|
cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
|
||||||
|
num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
|
||||||
|
MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
|
||||||
|
num_blocks.y = out_dims.x;
|
||||||
|
num_blocks.z = out_dims.y;
|
||||||
|
block_dims.x = BLOCK_DIM_X;
|
||||||
|
kernel = cu::row_reduce_looped<
|
||||||
|
InType,
|
||||||
|
OutType,
|
||||||
|
OP,
|
||||||
|
NDIM,
|
||||||
|
BLOCK_DIM_X,
|
||||||
|
N_READS>;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
in.data<InType>(), out.data<OutType>(), out.size(), args);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||||
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
|
|
||||||
|
#include <thrust/device_ptr.h>
|
||||||
|
#include <cub/device/device_reduce.cuh>
|
||||||
|
#include <cub/device/device_segmented_reduce.cuh>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||||
|
// Allocate temporary storage.
|
||||||
|
size_t size;
|
||||||
|
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
|
||||||
|
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||||
|
encoder.add_temporary(temp);
|
||||||
|
// Run op.
|
||||||
|
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||||
|
// Allocate temporary storage.
|
||||||
|
size_t size;
|
||||||
|
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
|
||||||
|
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||||
|
encoder.add_temporary(temp);
|
||||||
|
// Run op.
|
||||||
|
CHECK_CUDA_ERROR(
|
||||||
|
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MultiplyOp {
|
||||||
|
int factor;
|
||||||
|
__device__ int operator()(int i) {
|
||||||
|
return i * factor;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void segmented_reduce(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType reduce_type,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const ReductionPlan& plan) {
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
|
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||||
|
using InType = cuda_type_t<CTYPE>;
|
||||||
|
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||||
|
auto in_iter = cu::make_cast_iterator<OutType>(
|
||||||
|
thrust::device_pointer_cast(in.data<InType>()));
|
||||||
|
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||||
|
auto init = cu::ReduceInit<OP, InType>::value();
|
||||||
|
|
||||||
|
if (plan.type == ContiguousAllReduce) {
|
||||||
|
cub_all_reduce(
|
||||||
|
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
|
||||||
|
} else if (plan.type == ContiguousReduce) {
|
||||||
|
auto offsets = thrust::make_transform_iterator(
|
||||||
|
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
|
||||||
|
cub_segmented_reduce(
|
||||||
|
encoder,
|
||||||
|
in_iter,
|
||||||
|
out_ptr,
|
||||||
|
out.size(),
|
||||||
|
offsets,
|
||||||
|
offsets + 1,
|
||||||
|
OP(),
|
||||||
|
init,
|
||||||
|
stream);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Unsupported plan in segmented_reduce.");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
51
mlx/backend/cuda/scaled_dot_product_attention.cu
Normal file
51
mlx/backend/cuda/scaled_dot_product_attention.cu
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/fast_primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {} // namespace cu
|
||||||
|
|
||||||
|
namespace fast {
|
||||||
|
|
||||||
|
bool ScaledDotProductAttention::use_fallback(
|
||||||
|
const array& q,
|
||||||
|
const array& k,
|
||||||
|
const array& v,
|
||||||
|
bool has_mask,
|
||||||
|
bool has_arr_mask,
|
||||||
|
bool do_causal,
|
||||||
|
Stream s) {
|
||||||
|
if (detail::in_grad_tracing()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (s.device == Device::cpu) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int value_head_dim = v.shape(-1);
|
||||||
|
const int query_head_dim = q.shape(-1);
|
||||||
|
const int query_sequence_length = q.shape(2);
|
||||||
|
const int key_sequence_length = k.shape(2);
|
||||||
|
|
||||||
|
const bool sdpa_vector_supported_head_dim =
|
||||||
|
query_head_dim == value_head_dim &&
|
||||||
|
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
|
||||||
|
query_head_dim == 256);
|
||||||
|
const bool supports_sdpa_vector = (query_sequence_length <= 1) &&
|
||||||
|
(query_sequence_length <= key_sequence_length) &&
|
||||||
|
sdpa_vector_supported_head_dim;
|
||||||
|
|
||||||
|
return !supports_sdpa_vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ScaledDotProductAttention::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
array& out) {
|
||||||
|
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace fast
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -51,7 +51,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
|
|||||||
make_cast_iterator<AccT>(in),
|
make_cast_iterator<AccT>(in),
|
||||||
vals,
|
vals,
|
||||||
axis_size,
|
axis_size,
|
||||||
Limits<AccT>::min());
|
Limits<AccT>::finite_min());
|
||||||
prevmax = maxval;
|
prevmax = maxval;
|
||||||
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
||||||
// Online normalizer calculation for softmax:
|
// Online normalizer calculation for softmax:
|
||||||
@@ -79,7 +79,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
|
|||||||
block.sync();
|
block.sync();
|
||||||
maxval = warp.thread_rank() < warp.meta_group_size()
|
maxval = warp.thread_rank() < warp.meta_group_size()
|
||||||
? local_max[warp.thread_rank()]
|
? local_max[warp.thread_rank()]
|
||||||
: Limits<AccT>::min();
|
: Limits<AccT>::finite_min();
|
||||||
maxval = cg::reduce(warp, maxval, max_op);
|
maxval = cg::reduce(warp, maxval, max_op);
|
||||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||||
if (warp.thread_rank() == 0) {
|
if (warp.thread_rank() == 0) {
|
||||||
|
|||||||
@@ -79,10 +79,14 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
|
|||||||
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||||
array out = out_;
|
array out = out_;
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
if (axis < 0) {
|
if (axis < 0) {
|
||||||
axis += in.ndim();
|
axis += in.ndim();
|
||||||
}
|
}
|
||||||
int nsort = in.shape(axis);
|
int nsort = in.shape(axis);
|
||||||
|
int nsegments = in.data_size() / nsort;
|
||||||
int last_dim = in.ndim() - 1;
|
int last_dim = in.ndim() - 1;
|
||||||
|
|
||||||
// If we are not sorting the innermost dimension of a contiguous array,
|
// If we are not sorting the innermost dimension of a contiguous array,
|
||||||
@@ -96,15 +100,9 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||||
encoder.add_temporary(out);
|
encoder.add_temporary(out);
|
||||||
} else {
|
} else {
|
||||||
out.set_data(
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
allocator::malloc(in.data_size() * out.itemsize()),
|
|
||||||
in.data_size(),
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
encoder.set_input_array(in);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||||
@@ -136,7 +134,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
indices.data<uint32_t>(),
|
indices.data<uint32_t>(),
|
||||||
out.data<uint32_t>(),
|
out.data<uint32_t>(),
|
||||||
in.data_size(),
|
in.data_size(),
|
||||||
in.data_size() / nsort,
|
nsegments,
|
||||||
offsets,
|
offsets,
|
||||||
offsets + 1,
|
offsets + 1,
|
||||||
stream);
|
stream);
|
||||||
@@ -146,7 +144,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
in.data<Type>(),
|
in.data<Type>(),
|
||||||
out.data<Type>(),
|
out.data<Type>(),
|
||||||
in.data_size(),
|
in.data_size(),
|
||||||
in.data_size() / nsort,
|
nsegments,
|
||||||
offsets,
|
offsets,
|
||||||
offsets + 1,
|
offsets + 1,
|
||||||
stream);
|
stream);
|
||||||
@@ -179,14 +177,4 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|
||||||
nvtx3::scoped_range r("ArgPartition::eval_gpu");
|
|
||||||
gpu_sort(stream(), inputs[0], out, axis_, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|
||||||
nvtx3::scoped_range r("Partition::eval_gpu");
|
|
||||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -80,9 +80,7 @@ void Worker::thread_fn() {
|
|||||||
}
|
}
|
||||||
worker_tasks_.erase(worker_tasks_.begin(), end);
|
worker_tasks_.erase(worker_tasks_.begin(), end);
|
||||||
}
|
}
|
||||||
// Make sure tasks are cleared before the next wait
|
for (auto& task : tasks) {
|
||||||
for (int i = 0; i < tasks.size(); ++i) {
|
|
||||||
auto task = std::move(tasks[i]);
|
|
||||||
task();
|
task();
|
||||||
}
|
}
|
||||||
worker_event_.wait(batch + 1);
|
worker_event_.wait(batch + 1);
|
||||||
|
|||||||
@@ -245,30 +245,6 @@ void merge(array& dst, array& src, ParentsMap& parents_map) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Any parent in the divider will continue to refer to `x` but any parent not
|
|
||||||
// in the divider will refer to a copy of the operation.
|
|
||||||
array split_one(
|
|
||||||
const array& x,
|
|
||||||
ParentsMap& parents_map,
|
|
||||||
const std::unordered_set<uintptr_t>& divider) {
|
|
||||||
array y(x.shape(), x.dtype(), x.primitive_ptr(), x.inputs());
|
|
||||||
|
|
||||||
auto& x_parents = parents_map[x.id()];
|
|
||||||
auto& y_parents = parents_map[y.id()];
|
|
||||||
|
|
||||||
for (auto it = x_parents.begin(); it != x_parents.end();) {
|
|
||||||
if (divider.find(it->first.id()) != divider.end()) {
|
|
||||||
it->first.inputs()[it->second] = y;
|
|
||||||
y_parents.emplace_back(std::move(*it));
|
|
||||||
it = x_parents.erase(it);
|
|
||||||
} else {
|
|
||||||
it++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::move(y);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename... U>
|
template <typename T, typename... U>
|
||||||
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
||||||
using FunType = T (*)(U...);
|
using FunType = T (*)(U...);
|
||||||
@@ -693,16 +669,10 @@ void compile_fuse(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Arrays with a mix of parents outside the compilable section
|
// Arrays with a mix of parents outside the compilable section
|
||||||
// are not fusable except for broadcast which we can split to avoid
|
// are not fusable
|
||||||
// stopping fusion
|
|
||||||
if (!all_parents_in) {
|
if (!all_parents_in) {
|
||||||
if (a.has_primitive() && is_broadcast(a.primitive())) {
|
// Possible input
|
||||||
array b = split_one(a, parents_map, cache);
|
input_set.insert(a.id());
|
||||||
recurse(b, depth, s, shape);
|
|
||||||
} else {
|
|
||||||
// Possible input
|
|
||||||
input_set.insert(a.id());
|
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -546,7 +546,7 @@ class GELU(Module):
|
|||||||
|
|
||||||
See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
|
See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
|
||||||
functional equivalents and information regarding error bounds.
|
functional equivalents and information regarding error bounds.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
|
approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
|
||||||
@@ -554,19 +554,20 @@ class GELU(Module):
|
|||||||
|
|
||||||
def __init__(self, approx="none"):
|
def __init__(self, approx="none"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._approx = approx
|
|
||||||
allowed = ["none", "precise", "tanh", "fast"]
|
if approx == "none":
|
||||||
if approx not in allowed:
|
self._act = gelu
|
||||||
|
elif approx == "precise" or approx == "tanh":
|
||||||
|
self._act = gelu_approx
|
||||||
|
elif approx == "fast":
|
||||||
|
self._act = gelu_fast_approx
|
||||||
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The approximation should be in {allowed} but '{approx}' was given"
|
f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
if self._approx == "none":
|
return self._act(x)
|
||||||
return gelu(x)
|
|
||||||
elif self._approx in ["precise", "tanh"]:
|
|
||||||
return gelu_approx(x)
|
|
||||||
return gelu_fast_approx(x)
|
|
||||||
|
|
||||||
|
|
||||||
@_make_activation_module(tanh)
|
@_make_activation_module(tanh)
|
||||||
|
|||||||
@@ -404,7 +404,7 @@ class Module(dict):
|
|||||||
dst[k] = new_value
|
dst[k] = new_value
|
||||||
elif isinstance(current_value, (dict, list)):
|
elif isinstance(current_value, (dict, list)):
|
||||||
apply(current_value, new_value)
|
apply(current_value, new_value)
|
||||||
elif strict and new_value != {}:
|
elif strict:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Received invalid type: {type(new_value).__name__}."
|
f"Received invalid type: {type(new_value).__name__}."
|
||||||
)
|
)
|
||||||
@@ -413,14 +413,14 @@ class Module(dict):
|
|||||||
f'Module does not have sub-module named "{k}".'
|
f'Module does not have sub-module named "{k}".'
|
||||||
)
|
)
|
||||||
elif isinstance(modules, list):
|
elif isinstance(modules, list):
|
||||||
for i in range(len(modules)):
|
for i in range(len(dst)):
|
||||||
current_value = dst[i]
|
current_value = dst[i]
|
||||||
new_value = modules[i]
|
new_value = modules[i]
|
||||||
if self.is_module(current_value) and self.is_module(new_value):
|
if self.is_module(current_value) and self.is_module(new_value):
|
||||||
dst[i] = new_value
|
dst[i] = new_value
|
||||||
elif isinstance(current_value, (dict, list)):
|
elif isinstance(current_value, (dict, list)):
|
||||||
apply(current_value, new_value)
|
apply(current_value, new_value)
|
||||||
elif strict and new_value != {}:
|
elif strict:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Received invalid type: {type(new_value).__name__}."
|
f"Received invalid type: {type(new_value).__name__}."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
auditwheel repair dist/* \
|
|
||||||
--plat manylinux_2_35_x86_64 \
|
|
||||||
--exclude libcublas* \
|
|
||||||
--exclude libnvrtc*
|
|
||||||
|
|
||||||
cd wheelhouse
|
|
||||||
repaired_wheel=$(find . -name "*.whl" -print -quit)
|
|
||||||
unzip -q "${repaired_wheel}"
|
|
||||||
core_so=$(find mlx -name "core*.so" -print -quit)
|
|
||||||
rpath=$(patchelf --print-rpath "${core_so}")
|
|
||||||
rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib
|
|
||||||
patchelf --force-rpath --set-rpath "$rpath" "$core_so"
|
|
||||||
|
|
||||||
# Re-zip the repaired wheel
|
|
||||||
zip -r -q "${repaired_wheel}" .
|
|
||||||
@@ -205,8 +205,6 @@ nb::object to_scalar(mx::array& a) {
|
|||||||
return nb::cast(static_cast<float>(a.item<mx::bfloat16_t>()));
|
return nb::cast(static_cast<float>(a.item<mx::bfloat16_t>()));
|
||||||
case mx::complex64:
|
case mx::complex64:
|
||||||
return nb::cast(a.item<std::complex<float>>());
|
return nb::cast(a.item<std::complex<float>>());
|
||||||
case mx::float64:
|
|
||||||
return nb::cast(a.item<double>());
|
|
||||||
default:
|
default:
|
||||||
throw nb::type_error("type cannot be converted to Python scalar.");
|
throw nb::type_error("type cannot be converted to Python scalar.");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,37 @@
|
|||||||
cuda_skip = {
|
cuda_skip = {
|
||||||
"TestArray.test_api",
|
"TestArray.test_api",
|
||||||
|
"TestAutograd.test_update_state",
|
||||||
"TestBF16.test_arg_reduction_ops",
|
"TestBF16.test_arg_reduction_ops",
|
||||||
|
"TestBF16.test_reduction_ops",
|
||||||
"TestBlas.test_complex_gemm",
|
"TestBlas.test_complex_gemm",
|
||||||
|
"TestCompile.test_compile_dynamic_dims",
|
||||||
"TestEinsum.test_ellipses",
|
"TestEinsum.test_ellipses",
|
||||||
"TestEinsum.test_opt_einsum_test_cases",
|
"TestEinsum.test_opt_einsum_test_cases",
|
||||||
"TestLoad.test_load_f8_e4m3",
|
"TestLoad.test_load_f8_e4m3",
|
||||||
|
"TestMemory.test_memory_info",
|
||||||
"TestLayers.test_group_norm",
|
"TestLayers.test_group_norm",
|
||||||
"TestLayers.test_pooling",
|
"TestLayers.test_pooling",
|
||||||
"TestLayers.test_quantized_embedding",
|
"TestLayers.test_quantized_embedding",
|
||||||
"TestLayers.test_sin_pe",
|
"TestLayers.test_sin_pe",
|
||||||
"TestLayers.test_upsample",
|
"TestLayers.test_upsample",
|
||||||
|
"TestOps.test_array_equal",
|
||||||
"TestOps.test_complex_ops",
|
"TestOps.test_complex_ops",
|
||||||
"TestOps.test_dynamic_slicing",
|
"TestOps.test_dynamic_slicing",
|
||||||
|
"TestOps.test_softmax",
|
||||||
|
"TestOps.test_sort",
|
||||||
|
"TestOps.test_tile",
|
||||||
|
"TestReduce.test_axis_permutation_sums",
|
||||||
"TestReduce.test_dtypes",
|
"TestReduce.test_dtypes",
|
||||||
|
"TestReduce.test_expand_sums",
|
||||||
|
"TestReduce.test_many_reduction_axes",
|
||||||
"TestUpsample.test_torch_upsample",
|
"TestUpsample.test_torch_upsample",
|
||||||
|
# DivMod NYI
|
||||||
|
"TestOps.test_divmod",
|
||||||
|
"TestEval.test_multi_output_eval_during_transform",
|
||||||
|
# Partition NYI
|
||||||
|
"TestAutograd.test_topk_grad",
|
||||||
|
"TestOps.test_argpartition",
|
||||||
|
"TestOps.test_partition",
|
||||||
# Block masked matmul NYI
|
# Block masked matmul NYI
|
||||||
"TestBlas.test_block_masked_matmul",
|
"TestBlas.test_block_masked_matmul",
|
||||||
# Gather matmul NYI
|
# Gather matmul NYI
|
||||||
|
|||||||
@@ -2,10 +2,8 @@
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import io
|
import io
|
||||||
import math
|
|
||||||
import unittest
|
import unittest
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from io import StringIO
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
@@ -981,39 +979,6 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertEqual(mem_pre, mem_post)
|
self.assertEqual(mem_pre, mem_post)
|
||||||
|
|
||||||
def test_double_constant(self):
|
|
||||||
with mx.stream(mx.cpu):
|
|
||||||
x = mx.array(1.0, dtype=mx.float64)
|
|
||||||
|
|
||||||
def fun(x):
|
|
||||||
return (x + math.pi) * 2.0
|
|
||||||
|
|
||||||
y = fun(x).item()
|
|
||||||
y_compiled = mx.compile(fun)(x).item()
|
|
||||||
self.assertEqual(y, y_compiled)
|
|
||||||
|
|
||||||
def test_shared_broadcast(self):
|
|
||||||
def fun(x, y, z):
|
|
||||||
yy = mx.broadcast_to(y, z.shape)
|
|
||||||
return (x + yy * z), yy.sum()
|
|
||||||
|
|
||||||
a = mx.random.normal((10, 10))
|
|
||||||
b = mx.array(0.1)
|
|
||||||
c = mx.random.normal((10, 10))
|
|
||||||
mx.eval(a, b, c)
|
|
||||||
fc = mx.compile(fun)
|
|
||||||
d = fc(a, b, c)
|
|
||||||
|
|
||||||
s = StringIO()
|
|
||||||
mx.export_to_dot(s, a=a, b=b, c=c, d1=d[0], d2=d[1])
|
|
||||||
s.seek(0)
|
|
||||||
s = s.read()
|
|
||||||
|
|
||||||
self.assertTrue("CompiledBroadcastMultiplyAdd" in s)
|
|
||||||
d_hat = fun(a, b, c)
|
|
||||||
self.assertTrue(mx.allclose(d[0], d_hat[0]))
|
|
||||||
self.assertTrue(mx.allclose(d[1], d_hat[1]))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
|||||||
@@ -259,21 +259,6 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
m = m.update_modules({"list": ["hi"]})
|
m = m.update_modules({"list": ["hi"]})
|
||||||
|
|
||||||
# Allow updating a strict subset
|
|
||||||
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
|
|
||||||
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
|
|
||||||
self.assertEqual(m.layers[1].weight.shape, (4, 3))
|
|
||||||
|
|
||||||
# Using leaf_modules in the update should always work
|
|
||||||
class MyModel(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.stuff = [nn.Linear(2, 2), 0, nn.Linear(2, 2)]
|
|
||||||
self.more_stuff = {"hi": nn.Linear(2, 2), "bye": 0}
|
|
||||||
|
|
||||||
m = MyModel()
|
|
||||||
m.update_modules(m.leaf_modules())
|
|
||||||
|
|
||||||
|
|
||||||
class TestLayers(mlx_tests.MLXTestCase):
|
class TestLayers(mlx_tests.MLXTestCase):
|
||||||
def test_identity(self):
|
def test_identity(self):
|
||||||
|
|||||||
8
setup.py
8
setup.py
@@ -174,26 +174,20 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
package_dir = {"": "python"}
|
package_dir = {"": "python"}
|
||||||
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
|
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
|
||||||
install_requires = []
|
|
||||||
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
|
|
||||||
if build_cuda:
|
|
||||||
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="mlx-cuda" if build_cuda else "mlx",
|
name="mlx",
|
||||||
version=get_version(),
|
version=get_version(),
|
||||||
author="MLX Contributors",
|
author="MLX Contributors",
|
||||||
author_email="mlx@group.apple.com",
|
author_email="mlx@group.apple.com",
|
||||||
description="A framework for machine learning on Apple silicon.",
|
description="A framework for machine learning on Apple silicon.",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
license="MIT",
|
|
||||||
url="https://github.com/ml-explore/mlx",
|
url="https://github.com/ml-explore/mlx",
|
||||||
packages=packages,
|
packages=packages,
|
||||||
package_dir=package_dir,
|
package_dir=package_dir,
|
||||||
package_data=package_data,
|
package_data=package_data,
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
install_requires=install_requires,
|
|
||||||
extras_require={
|
extras_require={
|
||||||
"dev": [
|
"dev": [
|
||||||
"nanobind==2.4.0",
|
"nanobind==2.4.0",
|
||||||
|
|||||||
Reference in New Issue
Block a user