Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 01:17:26 +08:00)

Compare commits: 0ce20290b9 ... fd1d0821d2 (14 commits)

Commits in this range:

- fd1d0821d2
- 818e8e663e
- cc4b995723
- 664d8e42b8
- abdb21f27c
- 880751a084
- cd523ffd9f
- 4d2b682a13
- b70a964cde
- 9cf7ef1068
- ab7c310914
- 5adf185f86
- c9a9180584
- 76831ed83d
@@ -16,6 +16,9 @@ parameters:
   linux_release:
     type: boolean
     default: false
+  cuda_release:
+    type: boolean
+    default: false

 jobs:
   build_documentation:

@@ -104,7 +107,7 @@ jobs:
           command: |
             echo "stubs"
             pip install typing_extensions
             python setup.py generate_stubs
       - run:
           name: Run Python tests
           command: |

@@ -162,7 +165,7 @@ jobs:
           command: |
             source env/bin/activate
             pip install typing_extensions
             python setup.py generate_stubs
       - run:
           name: Run Python tests
           command: |

@@ -223,7 +226,6 @@ jobs:
           command: |
             sudo apt-get update
             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
-            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
             python -m venv env
             source env/bin/activate
             CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \

@@ -283,7 +285,7 @@ jobs:
           command: |
             source env/bin/activate
             pip install typing_extensions
             python setup.py generate_stubs
       - run:
           name: Build Python package
           command: |

@@ -342,7 +344,7 @@ jobs:
             CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
             pip install . -v
             pip install typing_extensions
             python setup.py generate_stubs
             << parameters.extra_env >> \
             CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
             python -m build --wheel

@@ -356,6 +358,48 @@ jobs:
       - store_artifacts:
           path: wheelhouse/

+  build_cuda_release:
+    parameters:
+      python_version:
+        type: string
+        default: "3.9"
+      extra_env:
+        type: string
+        default: "DEV_RELEASE=1"
+    machine:
+      image: linux-cuda-12:default
+    resource_class: gpu.nvidia.small.gen2
+    steps:
+      - checkout
+      - run:
+          name: Build wheel
+          command: |
+            sudo apt-get update
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            python -m venv env
+            source env/bin/activate
+            pip install auditwheel
+            pip install patchelf
+            pip install build
+            pip install twine
+            << parameters.extra_env >> \
+            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+            pip install ".[dev]" -v
+            python setup.py generate_stubs
+            << parameters.extra_env >> \
+            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+            python -m build --wheel
+            bash python/scripts/repair_cuda.sh
+      - run:
+          name: Upload package
+          command: |
+            source env/bin/activate
+            twine upload wheelhouse/*.whl
+      - store_artifacts:
+          path: wheelhouse/
+
 workflows:
   build_and_test:
     when:

@@ -625,3 +669,14 @@ workflows:
             parameters:
               python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
               extra_env: ["PYPI_RELEASE=1"]
+  cuda_test_release:
+    when:
+      and:
+        - equal: [ main, << pipeline.git.branch >> ]
+        - << pipeline.parameters.cuda_release >>
+    jobs:
+      - build_cuda_release:
+          matrix:
+            parameters:
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              extra_env: ["PYPI_RELEASE=1"]
@@ -5,6 +5,7 @@ import os
 import time

 import torch
+import torch.cuda
 import torch.mps



@@ -44,8 +45,10 @@ def bench(f, *args):


 def sync_if_needed(x):
-    if x.device != torch.device("cpu"):
+    if x.device == torch.device("mps"):
         torch.mps.synchronize()
+    elif x.device == torch.device("cuda"):
+        torch.cuda.synchronize()


 @torch.no_grad()
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
     sync_if_needed(x)


+@torch.no_grad()
+def sum_and_add(axis, x, y):
+    z = x.sum(axis=axis, keepdims=True)
+    for i in range(50):
+        z = (z + y).sum(axis=axis, keepdims=True)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def softmax(axis, x):
     ys = []

@@ -340,7 +351,11 @@ if __name__ == "__main__":
         args.axis.pop(0)

     torch.set_num_threads(1)
-    device = "cpu" if args.cpu else "mps"
+    device = "mps"
+    if torch.cuda.is_available():
+        device = "cuda"
+    if args.cpu:
+        device = "cpu"

     types = args.dtype
     if not types:

@@ -460,5 +475,8 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))

+    elif args.benchmark == "sum_and_add":
+        print(bench(sum_and_add, axis, *xs))
+
     else:
         raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
@@ -30,6 +30,16 @@ MLX is also available on conda-forge. To install MLX with conda do:

    conda install conda-forge::mlx

+CUDA
+^^^^
+
+MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
+and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
+
+.. code-block:: shell
+
+   pip install mlx-cuda
+

 Troubleshooting
 ^^^^^^^^^^^^^^^

@@ -65,6 +75,8 @@ Build Requirements
 Python API
 ^^^^^^^^^^

+.. _python install:
+
 To build and install the MLX python library from source, first, clone MLX from
 `its GitHub repo <https://github.com/ml-explore/mlx>`_:


@@ -107,6 +119,8 @@ IDE:
 C++ API
 ^^^^^^^

+.. _cpp install:
+
 Currently, MLX must be built and installed from source.

 Similarly to the python library, to build and install the MLX C++ library start

@@ -185,6 +199,7 @@ should point to the path to the built metal library.

    xcrun -sdk macosx --show-sdk-version

+
 Binary Size Minimization
 ~~~~~~~~~~~~~~~~~~~~~~~~


@@ -213,6 +228,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
 application. Once a kernel is compiled, it will be cached by the system. The
 Metal kernel cache persists across reboots.

+Linux
+^^^^^
+
+To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
+For example on Ubuntu, run the following:
+
+.. code-block:: shell
+
+   apt-get update -y
+   apt-get install libblas-dev liblapack-dev liblapacke-dev -y
+
+From here follow the instructions to install either the :ref:`Python <python
+install>` or :ref:`C++ <cpp install>` APIs.
+
+CUDA
+^^^^
+
+To build from source on Linux with CUDA, install the BLAS and LAPACK headers
+and the CUDA toolkit. For example on Ubuntu, run the following:
+
+.. code-block:: shell
+
+   wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
+   dpkg -i cuda-keyring_1.1-1_all.deb
+   apt-get update -y
+   apt-get -y install cuda-toolkit-12-9
+   apt-get install libblas-dev liblapack-dev liblapacke-dev -y
+
+
+When building either the Python or C++ APIs make sure to pass the cmake flag
+``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
+
+.. code-block:: shell
+
+   CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
+
+To build the C++ package run:
+
+.. code-block:: shell
+
+   mkdir -p build && cd build
+   cmake .. -DMLX_BUILD_CUDA=ON && make -j
+
+
 Troubleshooting
 ^^^^^^^^^^^^^^^

@@ -5,11 +5,9 @@
 namespace mlx::core {

 std::pair<Shape, Strides> shapes_without_reduction_axes(
-    const array& x,
+    Shape shape,
+    Strides strides,
     const std::vector<int>& axes) {
-  auto shape = x.shape();
-  auto strides = x.strides();
-
   for (int i = axes.size() - 1; i >= 0; i--) {
     int a = axes[i];
     shape.erase(shape.begin() + a);

@@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
   return std::make_pair(shape, strides);
 }

+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    const array& x,
+    const std::vector<int>& axes) {
+  auto shape = x.shape();
+  auto strides = x.strides();
+  return shapes_without_reduction_axes(
+      std::move(shape), std::move(strides), axes);
+}
+
 ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
   // The data is all there and we are reducing over everything
   if (x.size() == x.data_size() && axes.size() == x.ndim() &&

@@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
 std::pair<Shape, Strides> shapes_without_reduction_axes(
     const array& x,
     const std::vector<int>& axes);
+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    Shape shape,
+    Strides strides,
+    const std::vector<int>& axes);

 } // namespace mlx::core
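For readers skimming the refactor above: the new overload strips reduction axes from an explicit shape/strides pair instead of always reading them off an `array`. Below is a minimal, self-contained sketch of the same idea, with `Shape` and `Strides` approximated as plain vectors; the aliases, the `main` driver, and the `strides.erase` line (not visible in the hunk) are illustrative assumptions, not MLX's actual definitions.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

using Shape = std::vector<int>;       // stand-in for mlx::core::Shape
using Strides = std::vector<int64_t>; // stand-in for mlx::core::Strides

// Remove the reduced axes from an explicit (shape, strides) pair. Axes are
// assumed sorted ascending, so erasing from the back keeps earlier indices valid.
std::pair<Shape, Strides> shapes_without_reduction_axes(
    Shape shape,
    Strides strides,
    const std::vector<int>& axes) {
  for (int i = static_cast<int>(axes.size()) - 1; i >= 0; i--) {
    int a = axes[i];
    shape.erase(shape.begin() + a);
    strides.erase(strides.begin() + a);
  }
  return {std::move(shape), std::move(strides)};
}

int main() {
  // Reduce axes {0, 2} of a (2, 3, 4) row-contiguous array.
  auto [shape, strides] =
      shapes_without_reduction_axes({2, 3, 4}, {12, 4, 1}, {0, 2});
  std::printf("shape=(%d), strides=(%lld)\n", shape[0], (long long)strides[0]);
  // Prints: shape=(3), strides=(4)
}

The `(const array&, axes)` overload kept in the header simply forwards `x.shape()` and `x.strides()` to this one, which is what lets the new `RowReduceArgs::convert_shapes_to_contiguous` further down feed in synthetic contiguous strides.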
@@ -29,9 +29,10 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
-  ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
@@ -3,6 +3,7 @@
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
 #include "mlx/backend/cuda/worker.h"
+#include "mlx/utils.h"

 #include <cuda_runtime.h>
 #include <fmt/format.h>

@@ -14,9 +15,11 @@ namespace mlx::core {

 namespace cu {

+constexpr int page_size = 16384;
+
 CudaAllocator::CudaAllocator()
     : buffer_cache_(
-          getpagesize(),
+          page_size,
           [](CudaBuffer* buf) { return buf->size; },
           [this](CudaBuffer* buf) {
             cuda_free(buf->data);

@@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator()

 Buffer CudaAllocator::malloc(size_t size) {
   // Find available buffer from cache.
+  auto orig_size = size;
   std::unique_lock lock(mutex_);
+  if (size < page_size) {
+    size = next_power_of_2(size);
+  } else {
+    size = page_size * ((size + page_size - 1) / page_size);
+  }
+
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
   if (!buf) {
     // If we have a lot of memory pressure or are over the maximum cache size,
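To make the cache bucketing above concrete, here is a standalone sketch of the rounding policy applied before the buffer cache is consulted. `next_power_of_2` here is a local stand-in for the MLX utility pulled in via `mlx/utils.h`, and `bucket_size` is a name invented for this example.

#include <cstddef>
#include <cstdio>

constexpr size_t page_size = 16384;

// Round up to the next power of two (returns n unchanged if it already is one).
size_t next_power_of_2(size_t n) {
  size_t p = 1;
  while (p < n) {
    p <<= 1;
  }
  return p;
}

// Bucketing used before consulting the buffer cache: small allocations share
// power-of-two buckets, large ones are padded to whole 16 KB pages.
size_t bucket_size(size_t size) {
  if (size < page_size) {
    return next_power_of_2(size);
  }
  return page_size * ((size + page_size - 1) / page_size);
}

int main() {
  for (size_t s : {100, 5000, 16384, 20000, 70000}) {
    std::printf("%zu -> %zu\n", s, bucket_size(s));
  }
  // 100 -> 128, 5000 -> 8192, 16384 -> 16384, 20000 -> 32768, 70000 -> 81920
}

Collapsing requests into a small set of bucket sizes makes cached buffers far more likely to be reusable, at the cost of some internal fragmentation; the diff keeps `orig_size` around, presumably so the buffer can still report the originally requested size.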
@@ -157,7 +157,7 @@ void binary_op_gpu_inplace(
   if (ndim <= 3) {
     MLX_SWITCH_1_2_3(ndim, NDIM, {
       auto kernel =
-          &cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
+          cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
       auto [num_blocks, block_dims] =
           get_launch_args(kernel, out_a, large);
       kernel<<<num_blocks, block_dims, 0, stream>>>(

@@ -24,7 +24,6 @@ void copy_gpu_inplace(
   auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
-
   if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
     copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
     return;
@@ -114,7 +114,7 @@ void CommandEncoder::synchronize() {
   std::future<void> f = p->get_future();
   add_completed_handler([p = std::move(p)]() { p->set_value(); });
   worker_.end_batch();
-  worker_.commit();
+  commit();
   f.wait();
 }

@@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
 #pragma unroll
   for (int i = NDIM - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc);

@@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
 #pragma unroll
   for (int i = NDIM - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
-    c_loc += dim_idx * c_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
+    c_loc += dim_idx * IdxT(c_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc, c_loc);

@@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
   IdxT b_loc = 0;
   for (int i = ndim - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc);

@@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
   IdxT c_loc = 0;
   for (int i = ndim - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
-    c_loc += dim_idx * c_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
+    c_loc += dim_idx * IdxT(c_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc, c_loc);
@@ -162,11 +162,15 @@ class MatMul {
       }
     }

-    array workspace(
-        allocator::malloc(heuristic_.workspaceSize),
-        {static_cast<int>(heuristic_.workspaceSize)},
-        int8);
-    encoder.add_temporary(workspace);
+    void* workspace_ptr = nullptr;
+    if (heuristic_.workspaceSize > 0) {
+      array workspace(
+          allocator::malloc(heuristic_.workspaceSize),
+          {static_cast<int>(heuristic_.workspaceSize)},
+          int8);
+      encoder.add_temporary(workspace);
+      workspace_ptr = workspace.data<void>();
+    }

     encoder.launch_kernel([&](cudaStream_t stream) {
       CHECK_CUBLAS_ERROR(cublasLtMatmul(

@@ -183,8 +187,8 @@ class MatMul {
           out,
           out_desc_,
           &heuristic_.algo,
-          workspace.data<void>(),
-          workspace.nbytes(),
+          workspace_ptr,
+          heuristic_.workspaceSize,
           stream));
     });
   }

@@ -358,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       a_batch_strides.back(),
       b_batch_strides.back());

+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  auto nbatch = batch_count / batch_shape.back();
+  if (nbatch == 1) {
+    matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
+    return;
+  }
+
   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
-  for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
+  for (size_t i = 0; i < nbatch; ++i) {
     matmul.run(
         encoder,
         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,

@@ -444,10 +457,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       b_batch_strides.back(),
       c_batch_strides.back());

+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+
+  auto nbatch = batch_count / batch_shape.back();
+  if (nbatch == 1) {
+    matmul.run(
+        encoder,
+        out.data<int8_t>(),
+        a.data<int8_t>(),
+        b.data<int8_t>(),
+        c.data<int8_t>(),
+        alpha_,
+        beta_);
+    return;
+  }
+
   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
   ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
-  for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
+  for (size_t i = 0; i < nbatch; ++i) {
     matmul.run(
         encoder,
         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
@@ -21,28 +21,11 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(!axes_.empty());
   assert(out.size() != in.size());

-  out.set_data(allocator::malloc(out.nbytes()));
-
   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
-  encoder.set_input_array(in);
-  encoder.set_output_array(out);

-  // Fill out with init value.
   if (in.size() == 0) {
-    encoder.launch_kernel([&](cudaStream_t stream) {
-      MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-        MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
-          using InType = cuda_type_t<CTYPE>;
-          using OutType = cu::ReduceResult<OP, InType>::type;
-          thrust::fill_n(
-              cu::thrust_policy(stream),
-              thrust::device_pointer_cast(out.data<OutType>()),
-              out.data_size(),
-              cu::ReduceInit<OP, InType>::value());
-        });
-      });
-    });
+    init_reduce(encoder, in, out, reduce_type_);
     return;
   }


@@ -59,9 +42,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     plan = get_reduction_plan(in, axes_);
   }

-  if ((plan.type == ContiguousAllReduce) ||
-      (plan.type == ContiguousReduce && plan.shape.size() == 1)) {
-    segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
+  if (plan.type == ContiguousAllReduce) {
+    all_reduce(encoder, in, out, reduce_type_);
     return;
   }

mlx/backend/cuda/reduce/all_reduce.cu (new file, 140 lines)

@@ -0,0 +1,140 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/reduce/reduce.cuh"
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <cub/block/block_load.cuh>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename T, typename U, typename ReduceOp, int N = 4>
+__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
+  // TODO: Process multiple "rows" in each thread
+  constexpr int M = 1;
+
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  const U init = cu::ReduceInit<ReduceOp, T>::value();
+  ReduceOp op;
+
+  T vals[N];
+  U accs[M];
+  accs[0] = init;
+
+  size_t start = grid.block_rank() * block_step;
+  size_t end = start + block_step;
+  size_t check = min(end, size);
+
+  for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) {
+    cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
+    for (int j = 0; j < N; j++) {
+      accs[0] = op(accs[0], __cast<U, T>(vals[j]));
+    }
+  }
+
+  if (end > size) {
+    size_t offset = end - block.size() * N;
+    int block_end = size - offset;
+    cub::LoadDirectBlocked(
+        block.thread_rank(), in + offset, vals, block_end, __cast<T, U>(init));
+    for (int i = 0; i < N; i++) {
+      accs[0] = op(accs[0], __cast<U, T>(vals[i]));
+    }
+  }
+
+  __shared__ U shared_accumulators[32];
+  block_reduce(block, warp, accs, shared_accumulators, op, init);
+
+  if (block.thread_rank() == 0) {
+    out[grid.block_rank()] = accs[0];
+  }
+}
+
+} // namespace cu
+
+void all_reduce(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type) {
+  constexpr int N_READS = 8;
+
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto get_args = [](size_t size, int N) {
+    size_t reductions = size / N;
+    int threads = 512;
+    size_t full_blocks = (reductions + threads - 1) / threads;
+    int blocks;
+    if (full_blocks < 32) {
+      blocks = 1;
+    } else if (full_blocks < 128) {
+      blocks = 32;
+    } else if (full_blocks < 512) {
+      blocks = 128;
+    } else if (full_blocks < 1024) {
+      blocks = 512;
+    } else {
+      blocks = 1024;
+    }
+    size_t reductions_per_block = std::max(
+        static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
+    size_t block_step = reductions_per_block * N;
+
+    return std::make_tuple(blocks, threads, block_step);
+  };
+
+  int blocks, threads;
+  size_t block_step;
+  array x = in;
+
+  // Large array so allocate an intermediate and accumulate there
+  std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
+  if (blocks > 1) {
+    std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
+    array intermediate({blocks}, out.dtype(), nullptr, {});
+    intermediate.set_data(allocator::malloc(intermediate.nbytes()));
+    encoder.add_temporary(intermediate);
+    encoder.set_input_array(x);
+    encoder.set_output_array(intermediate);
+    encoder.launch_kernel([&](cudaStream_t stream) {
+      MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+        MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+          using T = cuda_type_t<CTYPE>;
+          using U = cu::ReduceResult<OP, T>::type;
+          auto kernel = cu::all_reduce<T, U, OP, N_READS>;
+          kernel<<<blocks, threads, 0, stream>>>(
+              x.data<T>(), intermediate.data<U>(), block_step, x.size());
+        });
+      });
+    });
+
+    // Set the input for the next step and recalculate the blocks
+    x = intermediate;
+    std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
+  }
+
+  encoder.set_input_array(x);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using T = cuda_type_t<CTYPE>;
+        using U = cu::ReduceResult<OP, T>::type;
+        auto kernel = cu::all_reduce<T, U, OP, N_READS>;
+        kernel<<<blocks, threads, 0, stream>>>(
+            x.data<T>(), out.data<U>(), block_step, x.size());
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
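To see how the launch geometry in `all_reduce` scales with input size, here is a host-only re-implementation of its `get_args` lambda with a tiny driver; the sizes in `main` are arbitrary examples. When the first call returns more than one block, the file above runs the kernel twice: once into a `blocks`-sized intermediate, then once more to collapse that intermediate into the final value.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <tuple>

// Mirror of the sizing logic in all_reduce: pick a coarse block count from the
// number of per-thread reductions, then give each block a contiguous slice.
std::tuple<int, int, size_t> get_args(size_t size, int n_reads) {
  size_t reductions = size / n_reads;
  int threads = 512;
  size_t full_blocks = (reductions + threads - 1) / threads;
  int blocks;
  if (full_blocks < 32) {
    blocks = 1;
  } else if (full_blocks < 128) {
    blocks = 32;
  } else if (full_blocks < 512) {
    blocks = 128;
  } else if (full_blocks < 1024) {
    blocks = 512;
  } else {
    blocks = 1024;
  }
  size_t reductions_per_block = std::max(
      static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
  return {blocks, threads, reductions_per_block * n_reads};
}

int main() {
  for (size_t size : {size_t(1) << 12, size_t(1) << 20, size_t(1) << 28}) {
    auto [blocks, threads, block_step] = get_args(size, 8);
    std::printf(
        "size=%zu -> blocks=%d threads=%d block_step=%zu\n",
        size, blocks, threads, block_step);
  }
}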
@@ -64,86 +64,6 @@ struct ColReduceArgs {
   }
 };

-template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
-__global__ void col_reduce_small(
-    const T* in,
-    U* out,
-    const __grid_constant__ ColReduceArgs args) {
-  auto grid = cg::this_grid();
-  auto block = cg::this_thread_block();
-
-  int column =
-      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
-  if (column * N_READS >= args.reduction_stride) {
-    return;
-  }
-
-  int out_idx = grid.block_rank() / grid.dim_blocks().x;
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
-
-  Op op;
-  U totals[N_READS];
-  for (int i = 0; i < N_READS; i++) {
-    totals[i] = ReduceInit<Op, T>::value();
-  }
-
-  // Read input to local.
-  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-  loop.next(
-      block.thread_index().y,
-      args.reduce_shape.data(),
-      args.reduce_strides.data());
-  for (size_t r = block.thread_index().y;
-       r < args.non_col_reductions * args.reduction_size;
-       r += block.dim_threads().y) {
-    U vals[N_READS];
-    cub::LoadDirectBlocked(
-        column,
-        make_cast_iterator<U>(in + loop.location()),
-        vals,
-        args.reduction_stride,
-        ReduceInit<Op, T>::value());
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = op(vals[i], totals[i]);
-    }
-    loop.next(
-        block.dim_threads().y,
-        args.reduce_shape.data(),
-        args.reduce_strides.data());
-  }
-
-  // Do block reduce when each column has more than 1 element to reduce.
-  if (block.dim_threads().y > 1) {
-    __shared__ U shared_vals[32 * 8 * N_READS];
-    size_t col =
-        block.thread_index().y * block.dim_threads().x + block.thread_index().x;
-    for (int i = 0; i < N_READS; i++) {
-      shared_vals[col * N_READS + i] = totals[i];
-    }
-    block.sync();
-    if (block.thread_index().y == 0) {
-      for (int i = 0; i < N_READS; i++) {
-        totals[i] = shared_vals[block.thread_index().x * N_READS + i];
-      }
-      for (int j = 1; j < block.dim_threads().y; j++) {
-        col = j * block.dim_threads().x + block.thread_index().x;
-        for (int i = 0; i < N_READS; i++) {
-          totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
-        }
-      }
-    }
-  }
-
-  // Write result.
-  if (block.thread_index().y == 0) {
-    cub::StoreDirectBlocked(
-        column,
-        out + out_idx * args.reduction_stride,
-        totals,
-        args.reduction_stride);
-  }
-}
-
 template <
     typename T,
     typename U,

@@ -152,67 +72,83 @@ template <
     int BM,
     int BN,
     int N_READS = 4>
-__global__ void col_reduce_looped(
-    const T* in,
-    U* out,
-    const __grid_constant__ ColReduceArgs args) {
+__global__ void
+col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);

-  constexpr int n_warps = BN / N_READS;
+  constexpr int threads_per_row = BN / N_READS;

-  int out_idx = grid.block_rank() / grid.dim_blocks().x;
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
+  // Compute the indices for the tile
+  size_t tile_idx = grid.block_rank();
+  size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
+  size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
+
+  // Compute the indices for the thread within the tile
+  short thread_x = block.thread_rank() % threads_per_row;
+  short thread_y = block.thread_rank() / threads_per_row;
+
+  // Move the input pointer
+  in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) +
+      tile_x * BN;
+
+  // Initialize the running totals
   Op op;
   U totals[N_READS];
   for (int i = 0; i < N_READS; i++) {
     totals[i] = ReduceInit<Op, T>::value();
   }

-  // Read input to local.
-  int r = block.thread_rank() / n_warps;
-  int column = block.thread_rank() % n_warps;
-  int in_offset = grid.block_index().x * BN;
   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-  loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
-  for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
-    U vals[N_READS];
-    cub::LoadDirectBlocked(
-        column,
-        make_cast_iterator<U>(in + loop.location()),
-        vals,
-        args.reduction_stride - in_offset,
-        ReduceInit<Op, T>::value());
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = op(vals[i], totals[i]);
+  loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
+  size_t total = args.non_col_reductions * args.reduction_size;
+  if (tile_x * BN + BN <= args.reduction_stride) {
+    for (size_t r = thread_y; r < total; r += BM) {
+      T vals[N_READS];
+      cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
+      for (int i = 0; i < N_READS; i++) {
+        totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+      }
+      loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
+    }
+  } else {
+    for (size_t r = thread_y; r < total; r += BM) {
+      T vals[N_READS];
+      cub::LoadDirectBlocked(
+          thread_x,
+          in + loop.location(),
+          vals,
+          args.reduction_stride - tile_x * BN,
+          __cast<T, U>(ReduceInit<Op, T>::value()));
+      for (int i = 0; i < N_READS; i++) {
+        totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+      }
+      loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
     }
-    loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
   }

   // Do warp reduce for each output.
-  constexpr int n_outputs = BN / n_warps;
+  constexpr int n_outputs = BN / threads_per_row;
   static_assert(BM == 32 && n_outputs == N_READS);
   __shared__ U shared_vals[BM * BN];
-  size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
+  short s_idx = thread_y * BN + thread_x * N_READS;
   for (int i = 0; i < N_READS; i++) {
-    shared_vals[col + i] = totals[i];
+    shared_vals[s_idx + i] = totals[i];
   }
   block.sync();
-  col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
+  s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
   for (int i = 0; i < n_outputs; i++) {
-    totals[i] = cg::reduce(warp, shared_vals[col + i], op);
+    totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op);
   }

   // Write result.
   if (warp.thread_rank() == 0) {
-    size_t out_offset = grid.block_index().x * BN;
     cub::StoreDirectBlocked(
         warp.meta_group_rank(),
-        out + out_idx * args.reduction_stride + out_offset,
+        out + tile_y * args.reduction_stride + tile_x * BN,
         totals,
-        args.reduction_stride - out_offset);
+        args.reduction_stride - tile_x * BN);
   }
 }

@@ -230,6 +166,53 @@ inline auto output_grid_for_col_reduce(
   return get_2d_grid_dims(out_shape, out_strides);
 }

+void col_reduce_looped(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    cu::ColReduceArgs args) {
+  // Allocate data for the output using in's layout to access them as
+  // contiguously as possible.
+  allocate_same_layout(out, in, axes);
+
+  // Just a way to get out of the constness because cub doesn't like it ...
+  // (sigh)
+  array x = in;
+
+  encoder.set_input_array(x);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
+          using T = cuda_type_t<CTYPE>;
+          using U = cu::ReduceResult<OP, T>::type;
+
+          constexpr int N_READS = 4;
+          constexpr int BM = 32;
+          constexpr int BN = 32;
+          dim3 grid = output_grid_for_col_reduce(out, args);
+          size_t extra_blocks = cuda::ceil_div(args.reduction_stride, BN);
+          if (grid.x * extra_blocks < INT32_MAX) {
+            grid.x *= extra_blocks;
+          } else if (grid.y * extra_blocks < 65536) {
+            grid.y *= extra_blocks;
+          } else {
+            throw std::runtime_error(
+                "[col_reduce_looped] Need to factorize reduction_stride");
+          }
+          int blocks = BM * BN / N_READS;
+          auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
+          kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);
+        });
+      });
+    });
+  });
+}
+
 void col_reduce(
     cu::CommandEncoder& encoder,
     const array& in,

@@ -237,42 +220,24 @@ void col_reduce(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan) {
+  // Current col reduce options
+  //
+  // - col_reduce_looped
+  //
+  // It is a general strided reduce. Each threadblock computes the output for
+  // a subrow of the fast moving axis. For instance 32 elements.
+  //
+  // Notes: As in row reduce we opt to read as much in order as possible and
+  // leave
+  // transpositions as they are (contrary to our Metal backend).
+  //
+  // Moreover we need different kernels for short rows and tuning
+
+  // Make the args struct to help route to the best kernel
   cu::ColReduceArgs args(in, plan, axes);

-  encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      using InType = cuda_type_t<CTYPE>;
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using OutType = cu::ReduceResult<OP, InType>::type;
-        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
-          constexpr int N_READS = 4;
-          dim3 block_dims;
-          dim3 num_blocks = output_grid_for_col_reduce(out, args);
-          num_blocks.z = num_blocks.y;
-          num_blocks.y = num_blocks.x;
-          auto kernel =
-              cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
-          size_t total = args.non_col_reductions * args.reduction_size;
-          if (total < 32) {
-            size_t stride_blocks =
-                cuda::ceil_div(args.reduction_stride, N_READS);
-            block_dims.x = std::min(stride_blocks, 32ul);
-            block_dims.y = std::min(total, 8ul);
-            num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
-          } else {
-            constexpr int BM = 32;
-            constexpr int BN = 32;
-            block_dims.x = BM * BN / N_READS;
-            num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
-            kernel = cu::
-                col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
-          }
-          kernel<<<num_blocks, block_dims, 0, stream>>>(
-              in.data<InType>(), out.data<OutType>(), args);
-        });
-      });
-    });
-  });
+  // Fallback col reduce
+  col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
 }

 } // namespace mlx::core
mlx/backend/cuda/reduce/init_reduce.cu (new file, 51 lines)

@@ -0,0 +1,51 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/reduce/reduce.cuh"
+
+#include <cooperative_groups.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename T, typename U, typename Op>
+__global__ void init_reduce(U* out, size_t size) {
+  auto index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = ReduceInit<Op, T>::value();
+  }
+}
+
+} // namespace cu
+
+void init_reduce(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type) {
+  // Allocate if needed
+  if (out.data_shared_ptr() == nullptr) {
+    out.set_data(allocator::malloc(out.nbytes()));
+  }
+
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using T = cuda_type_t<CTYPE>;
+        using U = cu::ReduceResult<OP, T>::type;
+        auto kernel = cu::init_reduce<T, U, OP>;
+        dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
+        dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
+        grid.x = (grid.x + 1023) / 1024;
+        kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
@@ -47,13 +47,11 @@ namespace mlx::core {
     throw std::invalid_argument("Unknown reduce type."); \
   }

-void segmented_reduce(
+void all_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
     array& out,
-    Reduce::ReduceType reduce_type,
-    const std::vector<int>& axes,
-    const ReductionPlan& plan);
+    Reduce::ReduceType reduce_type);

 void row_reduce(
     cu::CommandEncoder& encoder,

@@ -71,4 +69,10 @@ void col_reduce(
     const std::vector<int>& axes,
     const ReductionPlan& plan);

+void init_reduce(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type);
+
 } // namespace mlx::core
@@ -3,48 +3,89 @@
 #pragma once

 #include "mlx/backend/cuda/device/utils.cuh"
+#include "mlx/backend/cuda/reduce/reduce_utils.cuh"

 namespace mlx::core::cu {

 // Reduce ops.
 struct And {
-  __device__ bool operator()(bool a, bool b) {
+  __device__ __forceinline__ bool operator()(bool a, bool b) {
     return a && b;
   }
+
+  __device__ void atomic_update(bool* x, bool y) {
+    atomic_reduce<bool, And>(x, y);
+  }
 };

 struct Or {
-  __device__ bool operator()(bool a, bool b) {
+  __device__ __forceinline__ bool operator()(bool a, bool b) {
     return a || b;
   }
+
+  __device__ void atomic_update(bool* x, bool y) {
+    atomic_reduce<bool, Or>(x, y);
+  }
 };

 struct Sum {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a + b;
   }
+
+  template <typename T>
+  __device__ void atomic_update(T* x, T y) {
+    atomic_reduce<T, Sum>(x, y);
+  }
+
+  __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
+    atomicAdd(x, y);
+  }
+
+  __device__ void atomic_update(int* x, int y) {
+    atomicAdd(x, y);
+  }
+
+  __device__ void atomic_update(float* x, float y) {
+    atomicAdd(x, y);
+  }
 };

 struct Prod {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a * b;
   }
+
+  template <typename T>
+  __device__ void atomic_update(T* x, T y) {
+    atomic_reduce<T, Prod>(x, y);
+  }
 };

 struct Min {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a < b ? a : b;
   }
+
+  template <typename T>
+  __device__ void atomic_update(T* x, T y) {
+    atomic_reduce<T, Min>(x, y);
+  }
 };

 struct Max {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a > b ? a : b;
   }
+
+  template <typename T>
+  __device__ void atomic_update(T* x, T y) {
+    atomic_reduce<T, Max>(x, y);
+  }
 };

 // Traits to get the result type of reduce op.

@@ -120,7 +161,7 @@ template <typename T>
 struct ReduceInit<Prod, T> {
   static constexpr __host__ __device__ auto value() {
     if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return T{1, 1};
+      return T{1, 0};
     } else {
       return typename ReduceResult<Prod, T>::type{1};
     }
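The `atomic_update` hooks added above all funnel into a compare-and-swap loop (`atomic_reduce`, defined in the new `reduce_utils.cuh` that follows). For readers unfamiliar with the pattern, here is the same idea in portable host C++, with `std::atomic` standing in for `atomicCAS`; the function and variable names are chosen for this example only.

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// Generic "atomic op" via a compare-exchange loop: read the current value,
// combine it with y, and retry if another thread updated it in the meantime.
template <typename T, typename Op>
void atomic_reduce(std::atomic<T>& x, T y, Op op) {
  T old_val = x.load();
  while (!x.compare_exchange_weak(old_val, op(old_val, y))) {
    // On failure, old_val has been refreshed with the current value; retry.
  }
}

int main() {
  std::atomic<float> total{0.0f};
  std::vector<std::thread> workers;
  for (int t = 0; t < 4; t++) {
    workers.emplace_back([&] {
      for (int i = 0; i < 1000; i++) {
        atomic_reduce(total, 1.0f, [](float a, float b) { return a + b; });
      }
    });
  }
  for (auto& w : workers) {
    w.join();
  }
  std::printf("%f\n", total.load()); // 4000.000000
}

Hardware `atomicAdd` is used directly for the types that support it (float, int, bfloat16); everything else, including the ordering-sensitive Min/Max/Prod cases, goes through the CAS loop.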
134
mlx/backend/cuda/reduce/reduce_utils.cuh
Normal file
134
mlx/backend/cuda/reduce/reduce_utils.cuh
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <cooperative_groups/reduce.h>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <size_t N>
|
||||||
|
struct uint_by_size;
|
||||||
|
template <>
|
||||||
|
struct uint_by_size<2> {
|
||||||
|
using type = uint16_t;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct uint_by_size<4> {
|
||||||
|
using type = uint32_t;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct uint_by_size<8> {
|
||||||
|
using type = unsigned long long int;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename Op>
|
||||||
|
__device__ void atomic_reduce(T* x, T y) {
|
||||||
|
if constexpr (sizeof(T) == 1) {
|
||||||
|
using U = uint16_t;
|
||||||
|
U* x_int = (U*)((char*)x - ((size_t)x % 2));
|
||||||
|
int shift = ((char*)x - (char*)x_int) * 8;
|
||||||
|
int mask = 0xff << shift;
|
||||||
|
U old_val, new_val;
|
||||||
|
do {
|
||||||
|
old_val = *x_int;
|
||||||
|
T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
|
||||||
|
new_val = (old_val & ~mask) | (result << shift);
|
||||||
|
} while (atomicCAS(x_int, old_val, new_val) != old_val);
|
||||||
|
} else {
|
||||||
|
using U = typename uint_by_size<sizeof(T)>::type;
|
||||||
|
U* x_int = (U*)(x);
|
||||||
|
U old_val, new_val;
|
||||||
|
do {
|
||||||
|
old_val = *x_int;
|
||||||
|
T result = Op{}(*((T*)&old_val), y);
|
||||||
|
new_val = *((U*)&result);
|
||||||
|
} while (atomicCAS(x_int, old_val, new_val) != old_val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: Should make a custom complex type
template <typename U, typename T>
inline __device__ U __cast(T x) {
  return static_cast<U>(x);
}

template <>
inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
  return x.x != 0 && x.y != 0;
}

template <>
inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
  return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
}

template <typename T, int N, typename Block, typename Warp, typename Op>
inline __device__ void
block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
  // First reduce in the current warp
  for (int i = 0; i < N; i++) {
    vals[i] = cg::reduce(warp, vals[i], op);
  }

  // Reduce across warps
  if (warp.meta_group_size() > 1) {
    if (warp.thread_rank() == 0) {
      for (int i = 0; i < N; i++) {
        smem[warp.meta_group_rank() * N + i] = vals[i];
      }
    }
    block.sync();
    if (warp.thread_rank() < warp.meta_group_size()) {
      for (int i = 0; i < N; i++) {
        vals[i] = smem[warp.thread_rank() * N + i];
      }
    } else {
      for (int i = 0; i < N; i++) {
        vals[i] = init;
      }
    }
    for (int i = 0; i < N; i++) {
      vals[i] = cg::reduce(warp, vals[i], op);
    }
  }
}

} // namespace cu

inline void allocate_same_layout(
    array& out,
    const array& in,
    const std::vector<int>& axes) {
  // Initialize out so that it matches in's layout: keep any transpositions as
  // they are, which lets us either skip computing the output location that
  // matches the input or do simple contiguous reads and writes.
  auto out_strides = in.strides();
  for (auto ax : axes) {
    for (auto& s : out_strides) {
      if (s > in.strides(ax)) {
        s /= in.shape(ax);
      }
    }
  }
  auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
  auto fl = in.flags();
  fl.row_contiguous = rc;
  fl.col_contiguous = cc;
  fl.contiguous = data_size == out.size();
  out.set_data(
      allocator::malloc(out.nbytes()),
      data_size,
      out_strides,
      fl,
      allocator::free);
}

} // namespace mlx::core
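To see what the stride adjustment in allocate_same_layout does, here is a small self-contained C++ sketch of the same rule (hypothetical shapes and a standalone helper, not MLX API): every stride larger than the reduced axis's stride is divided by that axis's size, so a transposed input keeps its transposed output layout.

#include <cstdio>
#include <vector>

// Mirror of the stride rule in allocate_same_layout: for each reduced axis,
// divide any stride that is larger than that axis's stride by the axis size.
std::vector<long> output_strides(
    const std::vector<long>& shape,
    const std::vector<long>& strides,
    const std::vector<int>& axes) {
  std::vector<long> out = strides;
  for (int ax : axes) {
    for (auto& s : out) {
      if (s > strides[ax]) {
        s /= shape[ax];
      }
    }
  }
  return out;
}

int main() {
  // Row-contiguous (2, 3, 4), reduce the last axis: (12, 4, 1) -> (3, 1, 1),
  // so the output (reduced axis kept as size 1) is still row-contiguous.
  auto a = output_strides({2, 3, 4}, {12, 4, 1}, {2});
  std::printf("%ld %ld %ld\n", a[0], a[1], a[2]);

  // Transposed (3, 4) with strides (1, 3), reduce axis 1: strides stay (1, 3),
  // i.e. the transposition is preserved instead of being forced to row-major.
  auto b = output_strides({3, 4}, {1, 3}, {1});
  std::printf("%ld %ld\n", b[0], b[1]);
}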
|
@@ -55,84 +55,88 @@ struct RowReduceArgs {
       non_row_reductions *= reduce_shape[i];
     }
   }
 
+  // Convert shape and strides as if in was contiguous
+  void convert_shapes_to_contiguous(
+      const array& in,
+      const std::vector<int>& axes) {
+    auto shape_vec = in.shape();
+    auto strides_vec = in.strides();
+    size_t s = 1;
+    for (int i = in.ndim() - 1; i >= 0; i--) {
+      strides_vec[i] = s;
+      s *= shape_vec[i];
+    }
+    std::tie(shape_vec, strides_vec) =
+        shapes_without_reduction_axes(shape_vec, strides_vec, axes);
+    std::tie(shape_vec, strides_vec) =
+        collapse_contiguous_dims(shape_vec, strides_vec);
+    shape = const_param(shape_vec);
+    strides = const_param(strides_vec);
+    ndim = shape_vec.size();
+  }
 };
 
-template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
-__global__ void row_reduce_small(
-    const T* in,
-    U* out,
-    size_t out_size,
-    const __grid_constant__ RowReduceArgs args) {
-  size_t out_idx = cg::this_grid().thread_rank();
-  if (out_idx >= out_size) {
-    return;
-  }
-
-  Op op;
-
-  U total_val = ReduceInit<Op, T>::value();
-  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
-
-  for (size_t n = 0; n < args.non_row_reductions; n++) {
-    for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
-      U vals[N_READS];
-      cub::LoadDirectBlocked(
-          r,
-          make_cast_iterator<U>(in + loop.location()),
-          vals,
-          args.row_size,
-          ReduceInit<Op, T>::value());
-      total_val = op(total_val, cub::ThreadReduce(vals, op));
-    }
-    loop.next(args.reduce_shape.data(), args.reduce_strides.data());
-  }
-
-  out[out_idx] = total_val;
-}
-
-template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
-__global__ void row_reduce_small_warp(
-    const T* in,
-    U* out,
-    size_t out_size,
-    const __grid_constant__ RowReduceArgs args) {
+template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
+__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);
 
-  size_t out_idx = grid.thread_rank() / WARP_SIZE;
-  if (out_idx >= out_size) {
-    return;
+  const U init = cu::ReduceInit<ReduceOp, T>::value();
+  ReduceOp op;
+
+  T vals[M][N];
+  U accs[M];
+  for (int i = 0; i < M; i++) {
+    accs[i] = init;
   }
 
-  Op op;
-
-  U total_val = ReduceInit<Op, T>::value();
-  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
-
-  for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
-       n += WARP_SIZE) {
-    for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
-      U vals[N_READS];
-      cub::LoadDirectBlocked(
-          r,
-          make_cast_iterator<U>(in + loop.location()),
-          vals,
-          args.row_size,
-          ReduceInit<Op, T>::value());
-      total_val = op(total_val, cub::ThreadReduce(vals, op));
+  const size_t start_row =
+      min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
+  const size_t full_blocks = size / (block.size() * N);
+  const size_t final_offset = full_blocks * (block.size() * N);
+  in += start_row * size;
+  out += start_row;
+
+  for (size_t r = 0; r < full_blocks; r++) {
+    for (int k = 0; k < M; k++) {
+      cub::LoadDirectBlockedVectorized<T, N>(
+          block.thread_rank(), in + k * size + r * (block.size() * N), vals[k]);
+      for (int j = 0; j < N; j++) {
+        accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+      }
     }
-    loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
   }
 
-  total_val = cg::reduce(warp, total_val, op);
-
-  if (warp.thread_rank() == 0) {
-    out[out_idx] = total_val;
+  if (final_offset < size) {
+    for (int k = 0; k < M; k++) {
+      cub::LoadDirectBlocked(
+          block.thread_rank(),
+          in + k * size + final_offset,
+          vals[k],
+          size,
+          __cast<T, U>(init));
+      for (int j = 0; j < N; j++) {
+        accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+      }
+    }
+  }
+
+  __shared__ U shared_accumulators[32 * M];
+  block_reduce(block, warp, accs, shared_accumulators, op, init);
+
+  if (block.thread_rank() == 0) {
+    if (grid.block_rank() * M + M <= n_rows) {
+      for (int i = 0; i < M; i++) {
+        out[i] = accs[i];
+      }
+    } else {
+      short offset = grid.block_rank() * M + M - n_rows;
+      for (int i = offset; i < M; i++) {
+        out[i] = accs[i];
+      }
+    }
   }
 }
 
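The new kernel above splits each row into whole vectorized chunks plus a guarded tail. A small host-side C++ sketch of that index math, with the block size, N, and row length chosen purely for illustration:

#include <cstdio>

int main() {
  const int size = 4196;       // elements per row (illustrative)
  const int block_size = 128;  // threads per block (illustrative)
  const int N = 8;             // elements loaded per thread per iteration

  // Same partitioning as row_reduce_simple: whole (block_size * N)-sized
  // chunks are read with the vectorized load, the remainder goes through the
  // guarded LoadDirectBlocked path.
  const int chunk = block_size * N;          // 1024
  const int full_blocks = size / chunk;      // 4
  const int final_offset = full_blocks * chunk;  // 4096

  std::printf("full blocks: %d\n", full_blocks);
  std::printf("tail starts at %d, tail length %d\n",
              final_offset, size - final_offset);
}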
@@ -141,55 +145,167 @@ template <
     typename U,
     typename Op,
     int NDIM,
-    int BLOCK_DIM_X,
+    int BLOCK_DIM,
     int N_READS = 4>
 __global__ void row_reduce_looped(
-    const T* in,
+    T* in,
     U* out,
     size_t out_size,
     const __grid_constant__ RowReduceArgs args) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
 
-  size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
-  if (out_idx >= out_size) {
-    return;
-  }
+  size_t out_idx = grid.block_rank();
 
   Op op;
 
-  U total_val = ReduceInit<Op, T>::value();
+  U total[1];
+  U init = ReduceInit<Op, T>::value();
+  total[0] = init;
   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
+  size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
+  size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
 
   in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
 
   for (size_t n = 0; n < args.non_row_reductions; n++) {
-    for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
-         r++) {
-      U vals[N_READS];
-      cub::LoadDirectBlocked(
-          r * BLOCK_DIM_X + block.thread_index().x,
-          make_cast_iterator<U>(in + loop.location()),
-          vals,
-          args.row_size,
-          ReduceInit<Op, T>::value());
-      total_val = op(total_val, cub::ThreadReduce(vals, op));
+    for (size_t r = 0; r < full_blocks; r++) {
+      T vals[N_READS];
+      cub::LoadDirectBlockedVectorized<T, N_READS>(
+          block.thread_rank(),
+          in + loop.location() + r * BLOCK_DIM * N_READS,
+          vals);
+      for (int i = 0; i < N_READS; i++) {
+        total[0] = op(total[0], __cast<U, T>(vals[i]));
+      }
     }
+    if (final_offset < args.row_size) {
+      T vals[N_READS];
+      cub::LoadDirectBlocked(
+          block.thread_rank(),
+          in + loop.location() + final_offset,
+          vals,
+          args.row_size - final_offset,
+          __cast<T, U>(init));
+      for (int i = 0; i < N_READS; i++) {
+        total[0] = op(total[0], __cast<U, T>(vals[i]));
+      }
+    }
+    // TODO: Maybe block.sync() here?
     loop.next(args.reduce_shape.data(), args.reduce_strides.data());
   }
 
-  typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
-  __shared__ typename BlockReduceT::TempStorage temp;
-
-  total_val = BlockReduceT(temp).Reduce(total_val, op);
+  __shared__ U shared_accumulators[32];
+  block_reduce(block, warp, total, shared_accumulators, op, init);
 
   if (block.thread_rank() == 0) {
-    out[out_idx] = total_val;
+    out[out_idx] = total[0];
   }
 }
 
 } // namespace cu
 
+void row_reduce_simple(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan) {
+  constexpr int N_READS = 8;
+
+  // Allocate data for the output using in's layout to avoid elem_to_loc in the
+  // kernel.
+  allocate_same_layout(out, in, axes);
+
+  // Just a way to get out of the constness because cub doesn't like it ...
+  // (sigh)
+  array x = in;
+
+  // TODO: If out.size() < 1024 which will be a common case then write this in
+  // 2 passes. Something like 32 * out.size() and then do a warp reduce.
+  encoder.set_input_array(x);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using T = cuda_type_t<CTYPE>;
+        using U = cu::ReduceResult<OP, T>::type;
+
+        // Calculate the grid and block dims
+        size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
+        dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
+        int threads = std::min(1024UL, reductions);
+        threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+        dim3 block(threads, 1, 1);
+
+        // Pick the kernel
+        auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
+        if (grid.x >= 1024) {
+          grid.x = (grid.x + 1) / 2;
+          kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
+        }
+
+        // Launch
+        kernel<<<grid, block, 0, stream>>>(
+            x.data<T>(), out.data<U>(), out.size(), plan.shape.back());
+      });
+    });
+  });
+}
+
+void row_reduce_looped(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    cu::RowReduceArgs args) {
+  constexpr int N_READS = 8;
+
+  // Allocate data for the output using in's layout to access them as
+  // contiguously as possible.
+  allocate_same_layout(out, in, axes);
+
+  // Just a way to get out of the constness because cub doesn't like it ...
+  // (sigh)
+  array x = in;
+
+  encoder.set_input_array(x);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using T = cuda_type_t<CTYPE>;
+        using U = cu::ReduceResult<OP, T>::type;
+
+        // Calculate the grid and block dims
+        args.convert_shapes_to_contiguous(x, axes);
+        dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
+        size_t reductions = (args.row_size + N_READS - 1) / N_READS;
+        int threads = std::min(1024UL, reductions);
+        threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+        dim3 block(threads, 1, 1);
+
+        // Pick the kernel
+        auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
+        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
+          MLX_SWITCH_BLOCK_DIM(threads, THREADS, {
+            kernel = cu::row_reduce_looped<T, U, OP, NDIM, THREADS, N_READS>;
+            block.x = THREADS;
+          });
+        });
+
+        // Launch
+        kernel<<<grid, block, 0, stream>>>(
+            x.data<T>(), out.data<U>(), out.size(), args);
+      });
+    });
+  });
+}
+
 void row_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -197,54 +313,35 @@ void row_reduce(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan) {
+  // Current row reduction options
+  //
+  // - row_reduce_simple
+  //
+  //   That means that we are simply reducing across the fastest moving axis.
+  //   We are reducing 1 or 2 rows per threadblock depending on the size of
+  //   output.
+  //
+  // - row_reduce_looped
+  //
+  //   It is a general row reduction. We are computing 1 output per
+  //   threadblock. We read the fastest moving axis vectorized and loop over
+  //   the rest of the axes.
+  //
+  // Notes: We opt to read as much in order as possible and leave
+  //        transpositions as they are (contrary to our Metal backend).
+
+  // Simple row reduce means that we have 1 axis that we are reducing over and
+  // it has stride 1.
+  if (plan.shape.size() == 1) {
+    row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
+    return;
+  }
+
+  // Make the args struct to help route to the best kernel
   cu::RowReduceArgs args(in, plan, axes);
 
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      using InType = cuda_type_t<CTYPE>;
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using OutType = cu::ReduceResult<OP, InType>::type;
-        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
-          constexpr size_t N_READS = 4;
-          dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
-          dim3 block_dims, num_blocks;
-          auto kernel =
-              cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
-          if (args.row_size <= 64) {
-            if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
-                (args.non_row_reductions <= 8)) {
-              block_dims.x = std::min(out_dims.x, 1024u);
-              num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
-              num_blocks.y = out_dims.y;
-            } else {
-              block_dims.x = WARP_SIZE;
-              num_blocks.y = out_dims.x;
-              num_blocks.z = out_dims.y;
-              kernel =
-                  cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
-            }
-          } else {
-            size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
-            num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
-            MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
-              num_blocks.y = out_dims.x;
-              num_blocks.z = out_dims.y;
-              block_dims.x = BLOCK_DIM_X;
-              kernel = cu::row_reduce_looped<
-                  InType,
-                  OutType,
-                  OP,
-                  NDIM,
-                  BLOCK_DIM_X,
-                  N_READS>;
-            });
-          }
-          kernel<<<num_blocks, block_dims, 0, stream>>>(
-              in.data<InType>(), out.data<OutType>(), out.size(), args);
-        });
-      });
-    });
-  });
+  // Fallback row reduce
+  row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
 }
 
 } // namespace mlx::core
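The dispatch above routes a single stride-1 reduction axis to row_reduce_simple and everything else to row_reduce_looped. Both host launchers size their thread blocks the same way; a small standalone C++ sketch of that rule, with WARP_SIZE and N_READS values assumed to match the kernels above and row sizes chosen only for illustration:

#include <algorithm>
#include <cstdio>

int main() {
  const int WARP_SIZE = 32;  // warp width assumed by the kernels
  const int N_READS = 8;     // per-thread reads used by the row-reduce kernels

  for (long row_size : {7L, 100L, 5000L, 100000L}) {
    // Same rule as row_reduce_simple / row_reduce_looped: one thread per
    // N_READS elements, capped at 1024 and rounded up to a whole warp.
    long reductions = (row_size + N_READS - 1) / N_READS;
    long threads = std::min(1024L, reductions);
    threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
    std::printf("row_size=%6ld -> threads=%4ld\n", row_size, threads);
  }
}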
@@ -1,84 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/cast_op.cuh"
-#include "mlx/backend/cuda/reduce/reduce.cuh"
-
-#include <thrust/device_ptr.h>
-#include <cub/device/device_reduce.cuh>
-#include <cub/device/device_segmented_reduce.cuh>
-
-namespace mlx::core {
-
-template <typename... Args>
-void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
-  // Allocate temporary storage.
-  size_t size;
-  CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
-  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
-  encoder.add_temporary(temp);
-  // Run op.
-  CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
-}
-
-template <typename... Args>
-void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
-  // Allocate temporary storage.
-  size_t size;
-  CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
-  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
-  encoder.add_temporary(temp);
-  // Run op.
-  CHECK_CUDA_ERROR(
-      cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
-}
-
-struct MultiplyOp {
-  int factor;
-  __device__ int operator()(int i) {
-    return i * factor;
-  }
-};
-
-void segmented_reduce(
-    cu::CommandEncoder& encoder,
-    const array& in,
-    array& out,
-    Reduce::ReduceType reduce_type,
-    const std::vector<int>& axes,
-    const ReductionPlan& plan) {
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using InType = cuda_type_t<CTYPE>;
-        using OutType = cu::ReduceResult<OP, InType>::type;
-        auto in_iter = cu::make_cast_iterator<OutType>(
-            thrust::device_pointer_cast(in.data<InType>()));
-        auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
-        auto init = cu::ReduceInit<OP, InType>::value();
-
-        if (plan.type == ContiguousAllReduce) {
-          cub_all_reduce(
-              encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
-        } else if (plan.type == ContiguousReduce) {
-          auto offsets = thrust::make_transform_iterator(
-              thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
-          cub_segmented_reduce(
-              encoder,
-              in_iter,
-              out_ptr,
-              out.size(),
-              offsets,
-              offsets + 1,
-              OP(),
-              init,
-              stream);
-        } else {
-          throw std::runtime_error("Unsupported plan in segmented_reduce.");
-        }
-      });
-    });
-  });
-}
-
-} // namespace mlx::core
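The deleted segmented path built its segment boundaries on the fly: a counting iterator transformed by MultiplyOp yields offsets[i] = i * row_size, so segment i covers [i * row_size, (i + 1) * row_size). A plain C++ sketch of that reduction layout, without CUB or Thrust and with illustrative data only:

#include <cstdio>
#include <vector>

int main() {
  const int row_size = 4;                 // plan.shape.back() in the deleted code
  std::vector<int> in = {1, 2, 3, 4,      // segment 0
                         5, 6, 7, 8,      // segment 1
                         9, 10, 11, 12};  // segment 2
  const int num_segments = in.size() / row_size;

  // offsets[i] = i * row_size is what MultiplyOp produced from a counting
  // iterator; cub::DeviceSegmentedReduce then reduced [offsets[i], offsets[i+1]).
  std::vector<int> out(num_segments, 0);
  for (int i = 0; i < num_segments; i++) {
    int begin = i * row_size, end = (i + 1) * row_size;
    for (int j = begin; j < end; j++) {
      out[i] += in[j];
    }
  }

  for (int v : out) {
    std::printf("%d ", v);                // 10 26 42
  }
  std::printf("\n");
}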
@@ -51,7 +51,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
         make_cast_iterator<AccT>(in),
         vals,
         axis_size,
-        Limits<AccT>::finite_min());
+        Limits<AccT>::min());
     prevmax = maxval;
     maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
     // Online normalizer calculation for softmax:
@@ -79,7 +79,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
   block.sync();
   maxval = warp.thread_rank() < warp.meta_group_size()
       ? local_max[warp.thread_rank()]
-      : Limits<AccT>::finite_min();
+      : Limits<AccT>::min();
   maxval = cg::reduce(warp, maxval, max_op);
   normalizer = normalizer * softmax_exp(prevmax - maxval);
   if (warp.thread_rank() == 0) {
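The "online normalizer" referenced in the softmax kernel keeps a running maximum and rescales the running sum of exponentials whenever the maximum grows, so a row can be processed in chunks with a single pass. A minimal C++ sketch of that recurrence (the chunk size and data here are illustrative, not taken from the kernel):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  std::vector<float> row = {1.0f, 3.0f, -2.0f, 7.0f, 0.5f, 4.0f};

  // Running maximum and normalizer, updated one chunk at a time, as in the
  // online softmax recurrence the kernel's comment refers to.
  float maxval = -std::numeric_limits<float>::infinity();
  float normalizer = 0.0f;
  const size_t chunk = 2;
  for (size_t start = 0; start < row.size(); start += chunk) {
    float prevmax = maxval;
    for (size_t i = start; i < start + chunk && i < row.size(); i++) {
      maxval = std::max(maxval, row[i]);
    }
    normalizer *= std::exp(prevmax - maxval);  // rescale earlier contributions
    for (size_t i = start; i < start + chunk && i < row.size(); i++) {
      normalizer += std::exp(row[i] - maxval);
    }
  }

  // softmax(x_i) = exp(x_i - maxval) / normalizer
  for (float x : row) {
    std::printf("%.4f ", std::exp(x - maxval) / normalizer);
  }
  std::printf("\n");
}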
@@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
 void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
   array out = out_;
   auto& encoder = cu::get_command_encoder(s);
-  encoder.set_input_array(in);
-  encoder.set_output_array(out);
-
   if (axis < 0) {
     axis += in.ndim();
   }
@@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         in.flags());
   }
 
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
     MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
       if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
@@ -413,7 +413,7 @@ class Module(dict):
                         f'Module does not have sub-module named "{k}".'
                     )
             elif isinstance(modules, list):
-                for i in range(len(dst)):
+                for i in range(len(modules)):
                     current_value = dst[i]
                     new_value = modules[i]
                     if self.is_module(current_value) and self.is_module(new_value):
python/scripts/repair_cuda.sh (new file, 17 lines)
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+auditwheel repair dist/* \
+  --plat manylinux_2_35_x86_64 \
+  --exclude libcublas* \
+  --exclude libnvrtc*
+
+cd wheelhouse
+repaired_wheel=$(find . -name "*.whl" -print -quit)
+unzip -q "${repaired_wheel}"
+core_so=$(find mlx -name "core*.so" -print -quit)
+rpath=$(patchelf --print-rpath "${core_so}")
+rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib
+patchelf --force-rpath --set-rpath "$rpath" "$core_so"
+
+# Re-zip the repaired wheel
+zip -r -q "${repaired_wheel}" .
@@ -13,7 +13,6 @@ cuda_skip = {
     "TestLayers.test_upsample",
     "TestOps.test_complex_ops",
     "TestOps.test_dynamic_slicing",
-    "TestOps.test_softmax",
     "TestReduce.test_axis_permutation_sums",
     "TestReduce.test_dtypes",
     "TestReduce.test_expand_sums",
@@ -259,6 +259,11 @@ class TestBase(mlx_tests.MLXTestCase):
         with self.assertRaises(ValueError):
             m = m.update_modules({"list": ["hi"]})
 
+        # Allow updating a strict subset
+        m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
+        m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
+        self.assertEqual(m.layers[1].weight.shape, (4, 3))
+
 
 class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):
setup.py
@@ -174,20 +174,26 @@ if __name__ == "__main__":
     )
     package_dir = {"": "python"}
     package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
+    install_requires = []
+    build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
+    if build_cuda:
+        install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
+
     setup(
-        name="mlx",
+        name="mlx-cuda" if build_cuda else "mlx",
         version=get_version(),
         author="MLX Contributors",
         author_email="mlx@group.apple.com",
         description="A framework for machine learning on Apple silicon.",
         long_description=long_description,
         long_description_content_type="text/markdown",
+        license="MIT",
         url="https://github.com/ml-explore/mlx",
         packages=packages,
         package_dir=package_dir,
         package_data=package_data,
         include_package_data=True,
+        install_requires=install_requires,
         extras_require={
             "dev": [
                 "nanobind==2.4.0",