mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
14 Commits
cc4de6a607
...
v0.26.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
58f3860306 | ||
|
|
dd4f53db63 | ||
|
|
3d5e17e507 | ||
|
|
33bf1a244b | ||
|
|
772f471ff2 | ||
|
|
2c11d10f8d | ||
|
|
656ed7f780 | ||
|
|
81bb9a2a9e | ||
|
|
5adf185f86 | ||
|
|
c9a9180584 | ||
|
|
76831ed83d | ||
|
|
b3d7b85376 | ||
|
|
cad5c0241c | ||
|
|
b8022c578a |
@@ -16,6 +16,9 @@ parameters:
|
||||
linux_release:
|
||||
type: boolean
|
||||
default: false
|
||||
cuda_release:
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
build_documentation:
|
||||
@@ -104,7 +107,7 @@ jobs:
|
||||
command: |
|
||||
echo "stubs"
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
@@ -162,7 +165,7 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
@@ -223,7 +226,6 @@ jobs:
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||
python -m venv env
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
@@ -283,7 +285,7 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
@@ -342,7 +344,7 @@ jobs:
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
pip install . -v
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python -m build --wheel
|
||||
@@ -356,6 +358,48 @@ jobs:
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
build_cuda_release:
|
||||
parameters:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.9"
|
||||
extra_env:
|
||||
type: string
|
||||
default: "DEV_RELEASE=1"
|
||||
machine:
|
||||
image: linux-cuda-12:default
|
||||
resource_class: gpu.nvidia.small.gen2
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Build wheel
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
python -m venv env
|
||||
source env/bin/activate
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
pip install twine
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
pip install ".[dev]" -v
|
||||
python setup.py generate_stubs
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
python -m build --wheel
|
||||
bash python/scripts/repair_cuda.sh
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
twine upload wheelhouse/*.whl
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
workflows:
|
||||
build_and_test:
|
||||
when:
|
||||
@@ -625,3 +669,14 @@ workflows:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
extra_env: ["PYPI_RELEASE=1"]
|
||||
cuda_test_release:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.cuda_release >>
|
||||
jobs:
|
||||
- build_cuda_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
extra_env: ["PYPI_RELEASE=1"]
|
||||
|
||||
@@ -35,7 +35,6 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
||||
option(MLX_BUILD_ROCM "Build ROCm backend" OFF)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
@@ -89,10 +88,6 @@ if(MLX_BUILD_CUDA)
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_ROCM)
|
||||
enable_language(HIP)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.mps
|
||||
|
||||
|
||||
@@ -44,8 +45,10 @@ def bench(f, *args):
|
||||
|
||||
|
||||
def sync_if_needed(x):
|
||||
if x.device != torch.device("cpu"):
|
||||
if x.device == torch.device("mps"):
|
||||
torch.mps.synchronize()
|
||||
elif x.device == torch.device("cuda"):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sum_and_add(axis, x, y):
|
||||
z = x.sum(axis=axis, keepdims=True)
|
||||
for i in range(50):
|
||||
z = (z + y).sum(axis=axis, keepdims=True)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def softmax(axis, x):
|
||||
ys = []
|
||||
@@ -340,7 +351,11 @@ if __name__ == "__main__":
|
||||
args.axis.pop(0)
|
||||
|
||||
torch.set_num_threads(1)
|
||||
device = "cpu" if args.cpu else "mps"
|
||||
device = "mps"
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
if args.cpu:
|
||||
device = "cpu"
|
||||
|
||||
types = args.dtype
|
||||
if not types:
|
||||
@@ -460,5 +475,8 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "selu":
|
||||
print(bench(selu, x))
|
||||
|
||||
elif args.benchmark == "sum_and_add":
|
||||
print(bench(sum_and_add, axis, *xs))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||
|
||||
@@ -30,6 +30,16 @@ MLX is also available on conda-forge. To install MLX with conda do:
|
||||
|
||||
conda install conda-forge::mlx
|
||||
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
|
||||
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install mlx-cuda
|
||||
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
@@ -65,6 +75,8 @@ Build Requirements
|
||||
Python API
|
||||
^^^^^^^^^^
|
||||
|
||||
.. _python install:
|
||||
|
||||
To build and install the MLX python library from source, first, clone MLX from
|
||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||
|
||||
@@ -107,6 +119,8 @@ IDE:
|
||||
C++ API
|
||||
^^^^^^^
|
||||
|
||||
.. _cpp install:
|
||||
|
||||
Currently, MLX must be built and installed from source.
|
||||
|
||||
Similarly to the python library, to build and install the MLX C++ library start
|
||||
@@ -185,6 +199,7 @@ should point to the path to the built metal library.
|
||||
|
||||
xcrun -sdk macosx --show-sdk-version
|
||||
|
||||
|
||||
Binary Size Minimization
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -213,6 +228,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
|
||||
application. Once a kernel is compiled, it will be cached by the system. The
|
||||
Metal kernel cache persists across reboots.
|
||||
|
||||
Linux
|
||||
^^^^^
|
||||
|
||||
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
||||
For example on Ubuntu, run the following:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
apt-get update -y
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||
|
||||
From here follow the instructions to install either the :ref:`Python <python
|
||||
install>` or :ref:`C++ <cpp install>` APIs.
|
||||
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
||||
and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
apt-get update -y
|
||||
apt-get -y install cuda-toolkit-12-9
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||
|
||||
|
||||
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||
|
||||
To build the C++ package run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
||||
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
@@ -60,16 +60,7 @@ else()
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_ROCM)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm)
|
||||
else()
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm/no_rocm.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL
|
||||
OR MLX_BUILD_CUDA
|
||||
OR MLX_BUILD_ROCM)
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
||||
|
||||
@@ -14,6 +14,8 @@ void print_constant(std::ostream& os, const array& x) {
|
||||
return print_float_constant<float16_t>(os, x);
|
||||
case bfloat16:
|
||||
return print_float_constant<bfloat16_t>(os, x);
|
||||
case float64:
|
||||
return print_float_constant<double>(os, x);
|
||||
case complex64:
|
||||
return print_complex_constant<complex64_t>(os, x);
|
||||
case int8:
|
||||
@@ -50,6 +52,8 @@ std::string get_type_string(Dtype d) {
|
||||
return "float16_t";
|
||||
case bfloat16:
|
||||
return "bfloat16_t";
|
||||
case float64:
|
||||
return "double";
|
||||
case complex64:
|
||||
return "complex64_t";
|
||||
case bool_:
|
||||
|
||||
@@ -18,8 +18,12 @@ std::string get_type_string(Dtype d);
|
||||
template <typename T>
|
||||
void print_float_constant(std::ostream& os, const array& x) {
|
||||
auto old_precision = os.precision();
|
||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
|
||||
<< x.item<T>() << std::setprecision(old_precision);
|
||||
if constexpr (std::is_same_v<T, double>) {
|
||||
os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
|
||||
} else {
|
||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
|
||||
}
|
||||
os << x.item<T>() << std::setprecision(old_precision);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -5,11 +5,9 @@
|
||||
namespace mlx::core {
|
||||
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
Shape shape,
|
||||
Strides strides,
|
||||
const std::vector<int>& axes) {
|
||||
auto shape = x.shape();
|
||||
auto strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
shape.erase(shape.begin() + a);
|
||||
@@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
auto shape = x.shape();
|
||||
auto strides = x.strides();
|
||||
return shapes_without_reduction_axes(
|
||||
std::move(shape), std::move(strides), axes);
|
||||
}
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
|
||||
@@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes);
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
Shape shape,
|
||||
Strides strides,
|
||||
const std::vector<int>& axes);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -199,12 +199,15 @@ Dims get_2d_grid_dims_common(
|
||||
}
|
||||
}
|
||||
}
|
||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
|
||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
||||
throw std::runtime_error("Unable to safely factor shape.");
|
||||
}
|
||||
if (grid_y > grid_x) {
|
||||
std::swap(grid_x, grid_y);
|
||||
}
|
||||
if (divisor > 1) {
|
||||
grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
|
||||
}
|
||||
return std::make_tuple(
|
||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ target_sources(
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
||||
@@ -28,9 +29,10 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <fmt/format.h>
|
||||
@@ -14,9 +15,11 @@ namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
constexpr int page_size = 16384;
|
||||
|
||||
CudaAllocator::CudaAllocator()
|
||||
: buffer_cache_(
|
||||
getpagesize(),
|
||||
page_size,
|
||||
[](CudaBuffer* buf) { return buf->size; },
|
||||
[this](CudaBuffer* buf) {
|
||||
cuda_free(buf->data);
|
||||
@@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator()
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
// Find available buffer from cache.
|
||||
auto orig_size = size;
|
||||
std::unique_lock lock(mutex_);
|
||||
if (size < page_size) {
|
||||
size = next_power_of_2(size);
|
||||
} else {
|
||||
size = page_size * ((size + page_size - 1) / page_size);
|
||||
}
|
||||
|
||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
if (!buf) {
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
@@ -106,7 +116,6 @@ void CudaAllocator::cuda_free(void* buf) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
cudaFree(buf);
|
||||
}
|
||||
|
||||
|
||||
@@ -152,35 +152,29 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block_dims{BLOCK_DIM, 1, 1};
|
||||
auto kernel = &cu::arg_reduce_general<
|
||||
InType,
|
||||
cu::ArgMax<InType>,
|
||||
BLOCK_DIM,
|
||||
N_READS>;
|
||||
if (reduce_type_ == ArgReduce::ArgMin) {
|
||||
kernel = &cu::arg_reduce_general<
|
||||
InType,
|
||||
cu::ArgMin<InType>,
|
||||
BLOCK_DIM,
|
||||
N_READS>;
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(),
|
||||
out.data<uint32_t>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(in_strides),
|
||||
const_param(out_strides),
|
||||
ndim,
|
||||
axis_stride,
|
||||
axis_size);
|
||||
});
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
||||
auto kernel =
|
||||
cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
|
||||
if (reduce_type_ == ArgReduce::ArgMin) {
|
||||
kernel = cu::
|
||||
arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
|
||||
}
|
||||
kernel<<<num_blocks, block_dim(), 0, stream>>>(
|
||||
in.data<T>(),
|
||||
out.data<uint32_t>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(in_strides),
|
||||
const_param(out_strides),
|
||||
ndim,
|
||||
axis_stride,
|
||||
axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -125,13 +125,12 @@ constexpr bool supports_binary_op() {
|
||||
template <typename Op>
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
array& out,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -141,55 +140,64 @@ void binary_op_gpu_inplace(
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
|
||||
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
bool large = a.data_size() > INT32_MAX ||
|
||||
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel =
|
||||
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides));
|
||||
dispatch_bool(
|
||||
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
std::tie(shape, strides) =
|
||||
collapse_contiguous_dims(a, b, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel = cu::binary_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
||||
@@ -199,7 +207,7 @@ void binary_op_gpu_inplace(
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel, out.data_size(), out.shape(), out.strides(), large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
@@ -219,20 +227,6 @@ void binary_op_gpu_inplace(
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
@@ -243,8 +237,7 @@ void binary_op_gpu(
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
std::vector<array> outputs{out};
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
binary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
}
|
||||
|
||||
#define BINARY_GPU(func) \
|
||||
@@ -254,14 +247,6 @@ void binary_op_gpu(
|
||||
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
#define BINARY_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||
auto& s = outputs[0].primitive().stream(); \
|
||||
binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
|
||||
258
mlx/backend/cuda/binary_two.cu
Normal file
258
mlx/backend/cuda/binary_two.cu
Normal file
@@ -0,0 +1,258 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[0], b[0]);
|
||||
out_a[0] = out[0];
|
||||
out_b[0] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[0], b[index]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[index], b[0]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[index], b[index]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void binary_g_nd(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out_a,
|
||||
Out* out_b,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), a_strides.data(), b_strides.data());
|
||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_g(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out_a,
|
||||
Out* out_b,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_binary_op() {
|
||||
if (std::is_same_v<Op, DivMod>) {
|
||||
return std::is_same_v<In, Out> &&
|
||||
(std::is_integral_v<Out> || is_floating_v<Out>);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
auto& out_a = outputs[0];
|
||||
auto& out_b = outputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
|
||||
if (out_a.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
encoder.set_output_array(out_b);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
|
||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
dispatch_bool(
|
||||
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
out_a.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
std::tie(shape, strides) =
|
||||
collapse_contiguous_dims(a, b, out_a);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel = cu::binary_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out_a, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out_a, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel,
|
||||
out_a.data_size(),
|
||||
out_a.shape(),
|
||||
out_a.strides(),
|
||||
large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.data_size());
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||
op,
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(out_a.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
void DivMod::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("DivMod::eval_gpu");
|
||||
auto& s = outputs[0].primitive().stream();
|
||||
binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -24,7 +24,6 @@ void copy_gpu_inplace(
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
|
||||
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
|
||||
return;
|
||||
|
||||
@@ -10,15 +10,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||
using InType = cuda_type_t<CTYPE_IN>; \
|
||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||
__VA_ARGS__; \
|
||||
}); \
|
||||
})
|
||||
|
||||
void copy_contiguous(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
|
||||
@@ -36,19 +36,23 @@ void copy_contiguous(
|
||||
int64_t in_offset,
|
||||
int64_t out_offset) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = cu::copy_s<InType, OutType, IdxT>;
|
||||
if (ctype == CopyType::Vector) {
|
||||
kernel = cu::copy_v<InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>() + in_offset,
|
||||
out.data<OutType>() + out_offset,
|
||||
out.data_size());
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
auto kernel = cu::copy_s<InType, OutType, IdxT>;
|
||||
if (ctype == CopyType::Vector) {
|
||||
kernel = cu::copy_v<InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>() + in_offset,
|
||||
out.data<OutType>() + out_offset,
|
||||
out.data_size());
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -56,37 +56,48 @@ void copy_general(
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in),
|
||||
const_param<NDIM>(strides_out));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim);
|
||||
}
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(
|
||||
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
size_t data_size = 1;
|
||||
for (auto& s : shape)
|
||||
data_size *= s;
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||
auto kernel =
|
||||
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, data_size, shape, out.strides(), large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
data_size,
|
||||
const_param<ndim_constant()>(shape),
|
||||
const_param<ndim_constant()>(strides_in),
|
||||
const_param<ndim_constant()>(strides_out));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, data_size, shape, out.strides(), large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
data_size,
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -62,41 +62,52 @@ void copy_general_dynamic(
|
||||
const array& dynamic_offset_in,
|
||||
const array& dynamic_offset_out) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in),
|
||||
const_param<NDIM>(strides_out),
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim,
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
}
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(
|
||||
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel = cu::copy_gg_dynamic_nd<
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(strides_in),
|
||||
const_param<dims_constant()>(strides_out),
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim,
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -51,35 +51,43 @@ void copy_general_input(
|
||||
const Shape& shape,
|
||||
const Strides& strides_in) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_g<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
ndim);
|
||||
}
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(
|
||||
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(strides_in));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_g<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <future>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -107,6 +108,16 @@ void CommandEncoder::commit() {
|
||||
worker_.commit(stream_.last_cuda_stream());
|
||||
}
|
||||
|
||||
void CommandEncoder::synchronize() {
|
||||
stream().synchronize();
|
||||
auto p = std::make_shared<std::promise<void>>();
|
||||
std::future<void> f = p->get_future();
|
||||
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||
worker_.end_batch();
|
||||
commit();
|
||||
f.wait();
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device device) {
|
||||
static std::unordered_map<int, Device> devices;
|
||||
auto it = devices.find(device.index);
|
||||
|
||||
@@ -123,6 +123,9 @@ class CommandEncoder {
|
||||
return has_gpu_work_;
|
||||
}
|
||||
|
||||
// Wait until kernels and completion handlers are finished
|
||||
void synchronize();
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
DeviceStream& stream_;
|
||||
|
||||
@@ -22,7 +22,7 @@ struct FloorDivide {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x / y;
|
||||
} else {
|
||||
return trunc(x / y);
|
||||
return truncf(x / y);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -132,7 +132,7 @@ struct LogAddExp {
|
||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||
}
|
||||
constexpr float inf = cuda::std::numeric_limits<float>::infinity();
|
||||
float inf = cuda::std::numeric_limits<float>::infinity();
|
||||
auto maxval = x > y ? x : y;
|
||||
auto minval = x < y ? x : y;
|
||||
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#pragma once
|
||||
|
||||
// The maximum dimensions of shape/strides passed as kernel parameters.
|
||||
#define MAX_NDIM 8
|
||||
#define MAX_NDIM 10
|
||||
|
||||
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
||||
// warpSize variable exists, using it would prevent compile-time optimizations.
|
||||
|
||||
@@ -27,6 +27,8 @@ struct ArcCos {
|
||||
__device__ T operator()(T x) {
|
||||
return acos(x);
|
||||
}
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x);
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
@@ -41,6 +43,8 @@ struct ArcSin {
|
||||
__device__ T operator()(T x) {
|
||||
return asin(x);
|
||||
}
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x);
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
@@ -55,6 +59,8 @@ struct ArcTan {
|
||||
__device__ T operator()(T x) {
|
||||
return atan(x);
|
||||
}
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x);
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
@@ -261,13 +267,6 @@ struct Round {
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
@@ -333,6 +332,29 @@ struct Sqrt {
|
||||
__device__ T operator()(T x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x) {
|
||||
auto xr = cuCrealf(x);
|
||||
auto xi = cuCimagf(x);
|
||||
if (xr == 0.0f && xi == 0.0f) {
|
||||
return {0.0f, 0.0f};
|
||||
}
|
||||
auto r = cuCrealf(Abs{}(x));
|
||||
auto a = sqrt((r + xr) / 2.0f);
|
||||
auto b_abs = sqrt((r - xr) / 2.0f);
|
||||
auto b = copysign(b_abs, xi);
|
||||
return {a, b};
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
__device__ cuComplex operator()(cuComplex x) {
|
||||
return 1.0f / Sqrt{}(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
@@ -365,4 +387,22 @@ struct Tanh {
|
||||
}
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcCos::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0, 1.0};
|
||||
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
|
||||
return {cuCimagf(y), -cuCrealf(y)};
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcSin::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0f, 1.0f};
|
||||
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
|
||||
return {cuCimagf(y), -cuCrealf(y)};
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcTan::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0f, 1.0f};
|
||||
auto ix = i * x;
|
||||
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
@@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
@@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
c_loc += dim_idx * c_strides[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||
@@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
||||
IdxT b_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
@@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
||||
IdxT c_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
c_loc += dim_idx * c_strides[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||
|
||||
@@ -62,7 +62,7 @@ void finalize(Stream s) {
|
||||
|
||||
void synchronize(Stream s) {
|
||||
nvtx3::scoped_range r("gpu::synchronize");
|
||||
cu::get_stream(s).synchronize();
|
||||
cu::get_command_encoder(s).synchronize();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::gpu
|
||||
|
||||
@@ -37,36 +37,46 @@ void check_cu_error(const char* name, CUresult err) {
|
||||
}
|
||||
|
||||
// Return the location of the CUDA toolkit.
|
||||
const char* cuda_home() {
|
||||
const char* home = std::getenv("CUDA_HOME");
|
||||
if (home) {
|
||||
return home;
|
||||
}
|
||||
home = std::getenv("CUDA_PATH");
|
||||
if (home) {
|
||||
return home;
|
||||
}
|
||||
const std::string& cuda_home() {
|
||||
static std::string home = []() -> std::string {
|
||||
const char* home = std::getenv("CUDA_HOME");
|
||||
if (home) {
|
||||
return home;
|
||||
}
|
||||
home = std::getenv("CUDA_PATH");
|
||||
if (home) {
|
||||
return home;
|
||||
}
|
||||
#if defined(__linux__)
|
||||
home = "/usr/local/cuda";
|
||||
if (std::filesystem::exists(home)) {
|
||||
return home;
|
||||
}
|
||||
home = "/usr/local/cuda";
|
||||
if (std::filesystem::exists(home)) {
|
||||
return home;
|
||||
}
|
||||
#endif
|
||||
throw std::runtime_error(
|
||||
"Environment variable CUDA_HOME or CUDA_PATH is not set.");
|
||||
throw std::runtime_error(
|
||||
"Environment variable CUDA_HOME or CUDA_PATH is not set.");
|
||||
}();
|
||||
return home;
|
||||
}
|
||||
|
||||
// Get the cache directory for storing compiled results.
|
||||
bool get_ptx_cache_dir(std::filesystem::path* result) {
|
||||
auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
|
||||
if (!std::filesystem::is_directory(path)) {
|
||||
std::error_code error;
|
||||
if (!std::filesystem::create_directories(path, error)) {
|
||||
return false;
|
||||
const std::filesystem::path& ptx_cache_dir() {
|
||||
static std::filesystem::path cache = []() -> std::filesystem::path {
|
||||
std::filesystem::path cache;
|
||||
if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
|
||||
cache = c;
|
||||
} else {
|
||||
cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
|
||||
}
|
||||
}
|
||||
*result = path;
|
||||
return true;
|
||||
if (!std::filesystem::exists(cache)) {
|
||||
std::error_code error;
|
||||
if (!std::filesystem::create_directories(cache, error)) {
|
||||
return std::filesystem::path();
|
||||
}
|
||||
}
|
||||
return cache;
|
||||
}();
|
||||
return cache;
|
||||
}
|
||||
|
||||
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
|
||||
@@ -75,6 +85,10 @@ bool read_cached_ptx(
|
||||
const std::string& module_name,
|
||||
std::vector<char>* ptx,
|
||||
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
|
||||
if (cache_dir.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto ptx_path = cache_dir / (module_name + ".ptx");
|
||||
std::error_code error;
|
||||
auto ptx_size = std::filesystem::file_size(ptx_path, error);
|
||||
@@ -105,6 +119,10 @@ void write_cached_ptx(
|
||||
const std::string& module_name,
|
||||
const std::vector<char>& ptx,
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||
if (cache_dir.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
|
||||
if (!ptx.empty()) {
|
||||
ptx_file.write(&ptx.front(), ptx.size());
|
||||
@@ -184,11 +202,9 @@ JitModule::JitModule(
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder) {
|
||||
// Check cache.
|
||||
std::filesystem::path cache_dir;
|
||||
std::vector<char> ptx;
|
||||
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
||||
if (!get_ptx_cache_dir(&cache_dir) ||
|
||||
!read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
|
||||
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
|
||||
// Create program.
|
||||
auto [source_code, kernel_names] = builder();
|
||||
nvrtcProgram prog;
|
||||
@@ -246,7 +262,7 @@ JitModule::JitModule(
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||
}
|
||||
write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
|
||||
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
|
||||
}
|
||||
|
||||
// Load module.
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
@@ -17,60 +19,46 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Convert a number between 1~3 to constexpr.
|
||||
#define MLX_SWITCH_1_2_3(N, NDIM, ...) \
|
||||
switch (N) { \
|
||||
case 1: { \
|
||||
constexpr int NDIM = 1; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case 2: { \
|
||||
constexpr int NDIM = 2; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case 3: { \
|
||||
constexpr int NDIM = 3; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
template <typename F>
|
||||
void dispatch_1_2_3(int n, F&& f) {
|
||||
switch (n) {
|
||||
case 1:
|
||||
f(std::integral_constant<int, 1>{});
|
||||
break;
|
||||
case 2:
|
||||
f(std::integral_constant<int, 2>{});
|
||||
break;
|
||||
case 3:
|
||||
f(std::integral_constant<int, 3>{});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Like MLX_SWITCH_ALL_TYPES but for booleans.
|
||||
#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
|
||||
if (BOOL) { \
|
||||
constexpr bool BOOL_ALIAS = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool BOOL_ALIAS = false; \
|
||||
__VA_ARGS__; \
|
||||
template <typename F>
|
||||
void dispatch_bool(bool v, F&& f) {
|
||||
if (v) {
|
||||
f(std::true_type{});
|
||||
} else {
|
||||
f(std::false_type{});
|
||||
}
|
||||
}
|
||||
|
||||
// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2.
|
||||
#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \
|
||||
{ \
|
||||
uint32_t _num_threads = NUM_THREADS; \
|
||||
if (_num_threads <= WARP_SIZE) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 2) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 4) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 8) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 16) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
template <typename F>
|
||||
void dispatch_block_dim(int threads, F&& f) {
|
||||
if (threads <= WARP_SIZE) {
|
||||
f(std::integral_constant<int, WARP_SIZE>{});
|
||||
} else if (threads <= WARP_SIZE * 2) {
|
||||
f(std::integral_constant<int, WARP_SIZE * 2>{});
|
||||
} else if (threads <= WARP_SIZE * 4) {
|
||||
f(std::integral_constant<int, WARP_SIZE * 4>{});
|
||||
} else if (threads <= WARP_SIZE * 8) {
|
||||
f(std::integral_constant<int, WARP_SIZE * 8>{});
|
||||
} else if (threads <= WARP_SIZE * 16) {
|
||||
f(std::integral_constant<int, WARP_SIZE * 16>{});
|
||||
} else {
|
||||
f(std::integral_constant<int, WARP_SIZE * 32>{});
|
||||
}
|
||||
}
|
||||
|
||||
// Maps CPU types to CUDA types.
|
||||
template <typename T>
|
||||
|
||||
@@ -259,21 +259,22 @@ void LayerNorm::eval_gpu(
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::layer_norm<DataType, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride,
|
||||
b_stride);
|
||||
});
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride,
|
||||
b_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -341,8 +342,6 @@ void LayerNormVJP::eval_gpu(
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc(gb.nbytes()));
|
||||
|
||||
// Finish with the gradient for b in case we had a b.
|
||||
if (gb.ndim() == 1 && gb.size() == axis_size) {
|
||||
@@ -357,22 +356,27 @@ void LayerNormVJP::eval_gpu(
|
||||
encoder.set_output_array(gx);
|
||||
encoder.set_output_array(gw_temp);
|
||||
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BOOL(has_w, HAS_W, {
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
|
||||
dispatch_bool(has_w, [&](auto has_w_constant) {
|
||||
constexpr int N_READS = 4;
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::layer_norm_vjp<
|
||||
DataType,
|
||||
has_w_constant(),
|
||||
block_dim(),
|
||||
N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -144,14 +144,15 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -162,11 +162,15 @@ class MatMul {
|
||||
}
|
||||
}
|
||||
|
||||
array workspace(
|
||||
allocator::malloc(heuristic_.workspaceSize),
|
||||
{static_cast<int>(heuristic_.workspaceSize)},
|
||||
int8);
|
||||
encoder.add_temporary(workspace);
|
||||
void* workspace_ptr = nullptr;
|
||||
if (heuristic_.workspaceSize > 0) {
|
||||
array workspace(
|
||||
allocator::malloc(heuristic_.workspaceSize),
|
||||
{static_cast<int>(heuristic_.workspaceSize)},
|
||||
int8);
|
||||
encoder.add_temporary(workspace);
|
||||
workspace_ptr = workspace.data<void>();
|
||||
}
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
||||
@@ -183,8 +187,8 @@ class MatMul {
|
||||
out,
|
||||
out_desc_,
|
||||
&heuristic_.algo,
|
||||
workspace.data<void>(),
|
||||
workspace.nbytes(),
|
||||
workspace_ptr,
|
||||
heuristic_.workspaceSize,
|
||||
stream));
|
||||
});
|
||||
}
|
||||
@@ -358,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
auto nbatch = batch_count / batch_shape.back();
|
||||
if (nbatch == 1) {
|
||||
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||
@@ -444,10 +457,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
b_batch_strides.back(),
|
||||
c_batch_strides.back());
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto nbatch = batch_count / batch_shape.back();
|
||||
if (nbatch == 1) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
out.data<int8_t>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
alpha_,
|
||||
beta_);
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||
|
||||
@@ -28,7 +28,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&, this](cudaStream_t stream) {
|
||||
MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, {
|
||||
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
|
||||
using CTYPE = MLX_GET_TYPE(type_tag);
|
||||
using OutType = cuda_type_t<CTYPE>;
|
||||
CTYPE step =
|
||||
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
|
||||
@@ -71,10 +72,8 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(FFT)
|
||||
@@ -83,7 +82,6 @@ NO_GPU(GatherQMM)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Load)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(Scan)
|
||||
|
||||
@@ -21,28 +21,11 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(!axes_.empty());
|
||||
assert(out.size() != in.size());
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
// Fill out with init value.
|
||||
if (in.size() == 0) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
thrust::fill_n(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
out.data_size(),
|
||||
cu::ReduceInit<OP, InType>::value());
|
||||
});
|
||||
});
|
||||
});
|
||||
init_reduce(encoder, in, out, reduce_type_);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -51,7 +34,19 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// If it is a general reduce then copy the input to a contiguous array and
|
||||
// recompute the plan.
|
||||
if (plan.type == GeneralReduce) {
|
||||
//
|
||||
// TODO: Instead of copying we can use elem-to-loc to deal with broadcasting
|
||||
// like we do in Metal. When it comes to broadcasted reduction axes
|
||||
// some can be ignored eg for min/max.
|
||||
bool broadcasted = false;
|
||||
for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) {
|
||||
if (j < axes_.size() && axes_[j] == i) {
|
||||
j++;
|
||||
} else {
|
||||
broadcasted = in.strides(i) == 0;
|
||||
}
|
||||
}
|
||||
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
|
||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, in_copy, CopyType::General, s);
|
||||
encoder.add_temporary(in_copy);
|
||||
@@ -59,9 +54,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
plan = get_reduction_plan(in, axes_);
|
||||
}
|
||||
|
||||
if ((plan.type == ContiguousAllReduce) ||
|
||||
(plan.type == ContiguousReduce && plan.shape.size() == 1)) {
|
||||
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
all_reduce(encoder, in, out, reduce_type_);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
152
mlx/backend/cuda/reduce/all_reduce.cu
Normal file
152
mlx/backend/cuda/reduce/all_reduce.cu
Normal file
@@ -0,0 +1,152 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, typename U, typename ReduceOp, int N = 4>
|
||||
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
|
||||
// TODO: Process multiple "rows" in each thread
|
||||
constexpr int M = 1;
|
||||
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
const U init = cu::ReduceInit<ReduceOp, T>::value();
|
||||
ReduceOp op;
|
||||
|
||||
T vals[N];
|
||||
U accs[M];
|
||||
accs[0] = init;
|
||||
|
||||
size_t start = grid.block_rank() * block_step;
|
||||
size_t end = start + block_step;
|
||||
size_t check = min(end, size);
|
||||
|
||||
size_t i = start;
|
||||
for (; i + block.size() * N <= check; i += block.size() * N) {
|
||||
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[0] = op(accs[0], __cast<U, T>(vals[j]));
|
||||
}
|
||||
}
|
||||
|
||||
if (i < check) {
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
|
||||
for (int i = 0; i < N; i++) {
|
||||
accs[0] = op(accs[0], __cast<U, T>(vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
__shared__ U shared_accumulators[32];
|
||||
block_reduce(block, warp, accs, shared_accumulators, op, init);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[grid.block_rank()] = accs[0];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void all_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto get_args = [](size_t size, int N) {
|
||||
int threads = std::min(512UL, (size + N - 1) / N);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
int reductions_per_step = threads * N;
|
||||
size_t steps_needed =
|
||||
(size + reductions_per_step - 1) / reductions_per_step;
|
||||
|
||||
int blocks;
|
||||
if (steps_needed < 32) {
|
||||
blocks = 1;
|
||||
} else if (steps_needed < 128) {
|
||||
blocks = 32;
|
||||
} else if (steps_needed < 512) {
|
||||
blocks = 128;
|
||||
} else if (steps_needed < 1024) {
|
||||
blocks = 512;
|
||||
} else {
|
||||
blocks = 1024;
|
||||
}
|
||||
|
||||
size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
|
||||
size_t block_step = steps_per_block * reductions_per_step;
|
||||
|
||||
return std::make_tuple(blocks, threads, block_step);
|
||||
};
|
||||
|
||||
int blocks, threads;
|
||||
size_t block_step;
|
||||
size_t insize = in.size();
|
||||
Dtype dt = in.dtype();
|
||||
|
||||
// Cub doesn't like const pointers for load (sigh).
|
||||
void* indata = const_cast<void*>(in.data<void>());
|
||||
|
||||
// Large array so allocate an intermediate and accumulate there
|
||||
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
|
||||
encoder.set_input_array(in);
|
||||
if (blocks > 1) {
|
||||
array intermediate({blocks}, out.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
encoder.add_temporary(intermediate);
|
||||
encoder.set_output_array(intermediate);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(dt, [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
static_cast<T*>(indata),
|
||||
intermediate.data<U>(),
|
||||
block_step,
|
||||
insize);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Set the input for the next step and recalculate the blocks
|
||||
indata = intermediate.data<void>();
|
||||
dt = intermediate.dtype();
|
||||
insize = intermediate.size();
|
||||
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
|
||||
encoder.set_input_array(intermediate);
|
||||
}
|
||||
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(dt, [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
static_cast<T*>(indata), out.data<U>(), block_step, insize);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,5 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
@@ -36,19 +38,36 @@ struct ColReduceArgs {
|
||||
const array& in,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes) {
|
||||
using ShapeVector = decltype(plan.shape);
|
||||
using StridesVector = decltype(plan.strides);
|
||||
|
||||
ShapeVector shape_vec;
|
||||
StridesVector strides_vec;
|
||||
|
||||
assert(!plan.shape.empty());
|
||||
reduction_size = plan.shape.back();
|
||||
reduction_stride = plan.strides.back();
|
||||
|
||||
int64_t stride_back = 1;
|
||||
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||
std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes);
|
||||
while (!shape_vec.empty() && stride_back < reduction_stride) {
|
||||
stride_back *= shape_vec.back();
|
||||
shape_vec.pop_back();
|
||||
strides_vec.pop_back();
|
||||
}
|
||||
std::vector<int> indices(shape_vec.size());
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
|
||||
return strides_vec[left] > strides_vec[right];
|
||||
});
|
||||
ShapeVector sorted_shape;
|
||||
StridesVector sorted_strides;
|
||||
for (auto idx : indices) {
|
||||
sorted_shape.push_back(shape_vec[idx]);
|
||||
sorted_strides.push_back(strides_vec[idx]);
|
||||
}
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||
collapse_contiguous_dims(sorted_shape, sorted_strides);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
@@ -64,86 +83,6 @@ struct ColReduceArgs {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void col_reduce_small(
|
||||
const T* in,
|
||||
U* out,
|
||||
const __grid_constant__ ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
int column =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
if (column * N_READS >= args.reduction_stride) {
|
||||
return;
|
||||
}
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(
|
||||
block.thread_index().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
for (size_t r = block.thread_index().y;
|
||||
r < args.non_col_reductions * args.reduction_size;
|
||||
r += block.dim_threads().y) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.reduction_stride,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
loop.next(
|
||||
block.dim_threads().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do block reduce when each column has more than 1 element to reduce.
|
||||
if (block.dim_threads().y > 1) {
|
||||
__shared__ U shared_vals[32 * 8 * N_READS];
|
||||
size_t col =
|
||||
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col * N_READS + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
if (block.thread_index().y == 0) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
|
||||
}
|
||||
for (int j = 1; j < block.dim_threads().y; j++) {
|
||||
col = j * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (block.thread_index().y == 0) {
|
||||
cub::StoreDirectBlocked(
|
||||
column,
|
||||
out + out_idx * args.reduction_stride,
|
||||
totals,
|
||||
args.reduction_stride);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
@@ -152,67 +91,94 @@ template <
|
||||
int BM,
|
||||
int BN,
|
||||
int N_READS = 4>
|
||||
__global__ void col_reduce_looped(
|
||||
const T* in,
|
||||
U* out,
|
||||
const __grid_constant__ ColReduceArgs args) {
|
||||
__global__ void
|
||||
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
constexpr int n_warps = BN / N_READS;
|
||||
constexpr int threads_per_row = BN / N_READS;
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
// Compute the indices for the tile
|
||||
size_t tile_idx = grid.block_rank();
|
||||
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
|
||||
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
|
||||
|
||||
// Compute the indices for the thread within the tile
|
||||
short thread_x = block.thread_rank() % threads_per_row;
|
||||
short thread_y = block.thread_rank() / threads_per_row;
|
||||
|
||||
// Move the input pointer
|
||||
in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) +
|
||||
tile_x * BN;
|
||||
|
||||
// Initialize the running totals
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
int r = block.thread_rank() / n_warps;
|
||||
int column = block.thread_rank() % n_warps;
|
||||
int in_offset = grid.block_index().x * BN;
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location() + in_offset),
|
||||
vals,
|
||||
args.reduction_stride - in_offset,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
size_t total = args.non_col_reductions * args.reduction_size;
|
||||
if (tile_x * BN + BN <= args.reduction_stride) {
|
||||
if (args.reduction_stride % N_READS == 0) {
|
||||
for (size_t r = thread_y; r < total; r += BM) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
} else {
|
||||
for (size_t r = thread_y; r < total; r += BM) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t r = thread_y; r < total; r += BM) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
thread_x,
|
||||
in + loop.location(),
|
||||
vals,
|
||||
args.reduction_stride - tile_x * BN,
|
||||
__cast<T, U>(ReduceInit<Op, T>::value()));
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do warp reduce for each output.
|
||||
constexpr int n_outputs = BN / n_warps;
|
||||
constexpr int n_outputs = BN / threads_per_row;
|
||||
static_assert(BM == 32 && n_outputs == N_READS);
|
||||
__shared__ U shared_vals[BM * BN];
|
||||
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
|
||||
short s_idx = thread_y * BN + thread_x * N_READS;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col + i] = totals[i];
|
||||
shared_vals[s_idx + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
|
||||
s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
totals[i] = cg::reduce(warp, shared_vals[col + i], op);
|
||||
totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op);
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (warp.thread_rank() == 0) {
|
||||
size_t out_offset = grid.block_index().x * BN;
|
||||
cub::StoreDirectBlocked(
|
||||
warp.meta_group_rank(),
|
||||
out + out_idx * args.reduction_stride + out_offset,
|
||||
out + tile_y * args.reduction_stride + tile_x * BN,
|
||||
totals,
|
||||
args.reduction_stride - out_offset);
|
||||
args.reduction_stride - tile_x * BN);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,14 +186,57 @@ __global__ void col_reduce_looped(
|
||||
|
||||
inline auto output_grid_for_col_reduce(
|
||||
const array& out,
|
||||
const cu::ColReduceArgs& args) {
|
||||
auto out_shape = out.shape();
|
||||
auto out_strides = out.strides();
|
||||
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
|
||||
out_shape.pop_back();
|
||||
out_strides.pop_back();
|
||||
const cu::ColReduceArgs& args,
|
||||
int bn) {
|
||||
int gx, gy = 1;
|
||||
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
|
||||
size_t n_outer_blocks = out.size() / args.reduction_stride;
|
||||
size_t n_blocks = n_outer_blocks * n_inner_blocks;
|
||||
while (n_blocks / gy > INT32_MAX) {
|
||||
gy *= 2;
|
||||
}
|
||||
return get_2d_grid_dims(out_shape, out_strides);
|
||||
gx = cuda::ceil_div(n_blocks, gy);
|
||||
|
||||
return dim3(gx, gy, 1);
|
||||
}
|
||||
|
||||
void col_reduce_looped(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan,
|
||||
cu::ColReduceArgs args) {
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
|
||||
constexpr int N_READS = 4;
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
dim3 grid = output_grid_for_col_reduce(out, args, BN);
|
||||
int blocks = BM * BN / N_READS;
|
||||
auto kernel =
|
||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||
kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void col_reduce(
|
||||
@@ -237,42 +246,23 @@ void col_reduce(
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
// Current col reduce options
|
||||
//
|
||||
// - col_reduce_looped
|
||||
//
|
||||
// It is a general strided reduce. Each threadblock computes the output for
|
||||
// a subrow of the fast moving axis. For instance 32 elements.
|
||||
//
|
||||
// Notes: As in row reduce we opt to read as much in order as possible and
|
||||
// leave transpositions as they are (contrary to our Metal backend).
|
||||
//
|
||||
// Moreover we need different kernels for short rows and tuning
|
||||
|
||||
// Make the args struct to help route to the best kernel
|
||||
cu::ColReduceArgs args(in, plan, axes);
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
constexpr int N_READS = 4;
|
||||
dim3 block_dims;
|
||||
dim3 num_blocks = output_grid_for_col_reduce(out, args);
|
||||
num_blocks.z = num_blocks.y;
|
||||
num_blocks.y = num_blocks.x;
|
||||
auto kernel =
|
||||
cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
size_t total = args.non_col_reductions * args.reduction_size;
|
||||
if (total < 32) {
|
||||
size_t stride_blocks =
|
||||
cuda::ceil_div(args.reduction_stride, N_READS);
|
||||
block_dims.x = std::min(stride_blocks, 32ul);
|
||||
block_dims.y = std::min(total, 8ul);
|
||||
num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
|
||||
} else {
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
block_dims.x = BM * BN / N_READS;
|
||||
num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
|
||||
kernel = cu::
|
||||
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(), out.data<OutType>(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
// Fallback col reduce
|
||||
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
51
mlx/backend/cuda/reduce/init_reduce.cu
Normal file
51
mlx/backend/cuda/reduce/init_reduce.cu
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
__global__ void init_reduce(U* out, size_t size) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void init_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type) {
|
||||
// Allocate if needed
|
||||
if (out.data_shared_ptr() == nullptr) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::init_reduce<T, U, OP>;
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
||||
grid.x = (grid.x + 1023) / 1024;
|
||||
kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,5 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
@@ -9,51 +11,41 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Dispatch dynamic ndim to constexpr.
|
||||
// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file.
|
||||
#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \
|
||||
if (ndim == 1) { \
|
||||
constexpr uint32_t NDIM = 1; \
|
||||
__VA_ARGS__; \
|
||||
} else if (ndim == 2) { \
|
||||
constexpr uint32_t NDIM = 2; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr uint32_t NDIM = 5; \
|
||||
__VA_ARGS__; \
|
||||
template <typename F>
|
||||
void dispatch_reduce_ndim(int ndim, F&& f) {
|
||||
if (ndim == 1) {
|
||||
f(std::integral_constant<int, 1>{});
|
||||
} else if (ndim == 2) {
|
||||
f(std::integral_constant<int, 2>{});
|
||||
} else {
|
||||
f(std::integral_constant<int, 5>{});
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch reduce ops to constexpr.
|
||||
#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) \
|
||||
if (REDUCE == Reduce::ReduceType::And) { \
|
||||
using OP = cu::And; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Or) { \
|
||||
using OP = cu::Or; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Sum) { \
|
||||
using OP = cu::Sum; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Prod) { \
|
||||
using OP = cu::Prod; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Max) { \
|
||||
using OP = cu::Max; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Min) { \
|
||||
using OP = cu::Min; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
throw std::invalid_argument("Unknown reduce type."); \
|
||||
template <typename F>
|
||||
void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) {
|
||||
if (reduce_type == Reduce::ReduceType::And) {
|
||||
f(type_identity<cu::And>{});
|
||||
} else if (reduce_type == Reduce::ReduceType::Or) {
|
||||
f(type_identity<cu::Or>{});
|
||||
} else if (reduce_type == Reduce::ReduceType::Sum) {
|
||||
f(type_identity<cu::Sum>{});
|
||||
} else if (reduce_type == Reduce::ReduceType::Prod) {
|
||||
f(type_identity<cu::Prod>{});
|
||||
} else if (reduce_type == Reduce::ReduceType::Max) {
|
||||
f(type_identity<cu::Max>{});
|
||||
} else if (reduce_type == Reduce::ReduceType::Min) {
|
||||
f(type_identity<cu::Min>{});
|
||||
} else {
|
||||
throw std::invalid_argument("Unknown reduce type.");
|
||||
}
|
||||
}
|
||||
|
||||
void segmented_reduce(
|
||||
void all_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
Reduce::ReduceType reduce_type);
|
||||
|
||||
void row_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
@@ -71,4 +63,10 @@ void col_reduce(
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
|
||||
void init_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -3,48 +3,89 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce_utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Reduce ops.
|
||||
struct And {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
__device__ __forceinline__ bool operator()(bool a, bool b) {
|
||||
return a && b;
|
||||
}
|
||||
|
||||
__device__ void atomic_update(bool* x, bool y) {
|
||||
atomic_reduce<bool, And>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Or {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
__device__ __forceinline__ bool operator()(bool a, bool b) {
|
||||
return a || b;
|
||||
}
|
||||
|
||||
__device__ void atomic_update(bool* x, bool y) {
|
||||
atomic_reduce<bool, Or>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
__device__ __forceinline__ T operator()(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void atomic_update(T* x, T y) {
|
||||
atomic_reduce<T, Sum>(x, y);
|
||||
}
|
||||
|
||||
__device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
|
||||
atomicAdd(x, y);
|
||||
}
|
||||
|
||||
__device__ void atomic_update(int* x, int y) {
|
||||
atomicAdd(x, y);
|
||||
}
|
||||
|
||||
__device__ void atomic_update(float* x, float y) {
|
||||
atomicAdd(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Prod {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
__device__ __forceinline__ T operator()(T a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void atomic_update(T* x, T y) {
|
||||
atomic_reduce<T, Prod>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Min {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
__device__ __forceinline__ T operator()(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void atomic_update(T* x, T y) {
|
||||
atomic_reduce<T, Min>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Max {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
__device__ __forceinline__ T operator()(T a, T b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void atomic_update(T* x, T y) {
|
||||
atomic_reduce<T, Max>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
// Traits to get the result type of reduce op.
|
||||
@@ -120,7 +161,7 @@ template <typename T>
|
||||
struct ReduceInit<Prod, T> {
|
||||
static constexpr __host__ __device__ auto value() {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return T{1, 1};
|
||||
return T{1, 0};
|
||||
} else {
|
||||
return typename ReduceResult<Prod, T>::type{1};
|
||||
}
|
||||
|
||||
158
mlx/backend/cuda/reduce/reduce_utils.cuh
Normal file
158
mlx/backend/cuda/reduce/reduce_utils.cuh
Normal file
@@ -0,0 +1,158 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <size_t N>
|
||||
struct uint_by_size;
|
||||
template <>
|
||||
struct uint_by_size<2> {
|
||||
using type = uint16_t;
|
||||
};
|
||||
template <>
|
||||
struct uint_by_size<4> {
|
||||
using type = uint32_t;
|
||||
};
|
||||
template <>
|
||||
struct uint_by_size<8> {
|
||||
using type = unsigned long long int;
|
||||
};
|
||||
|
||||
template <typename T, typename Op>
|
||||
__device__ void atomic_reduce(T* x, T y) {
|
||||
if constexpr (sizeof(T) == 1) {
|
||||
using U = uint16_t;
|
||||
U* x_int = (U*)((char*)x - ((size_t)x % 2));
|
||||
int shift = ((char*)x - (char*)x_int) * 8;
|
||||
int mask = 0xff << shift;
|
||||
U old_val, new_val;
|
||||
do {
|
||||
old_val = *x_int;
|
||||
T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
|
||||
new_val = (old_val & ~mask) | (result << shift);
|
||||
} while (atomicCAS(x_int, old_val, new_val) != old_val);
|
||||
} else {
|
||||
using U = typename uint_by_size<sizeof(T)>::type;
|
||||
U* x_int = (U*)(x);
|
||||
U old_val, new_val;
|
||||
do {
|
||||
old_val = *x_int;
|
||||
T result = Op{}(*((T*)&old_val), y);
|
||||
new_val = *((U*)&result);
|
||||
} while (atomicCAS(x_int, old_val, new_val) != old_val);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Should make a custom complex type
|
||||
template <typename U, typename T>
|
||||
inline __device__ U __cast(T x) {
|
||||
return static_cast<U>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
|
||||
return x.x != 0 && x.y != 0;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
|
||||
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
|
||||
}
|
||||
|
||||
template <typename T, int N, typename Block, typename Warp, typename Op>
|
||||
inline __device__ void
|
||||
block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
|
||||
// First reduce in the current warp
|
||||
for (int i = 0; i < N; i++) {
|
||||
vals[i] = cg::reduce(warp, vals[i], op);
|
||||
}
|
||||
|
||||
// Reduce across warps
|
||||
if (warp.meta_group_size() > 1) {
|
||||
if (warp.thread_rank() == 0) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
smem[warp.meta_group_rank() * N + i] = vals[i];
|
||||
}
|
||||
}
|
||||
block.sync();
|
||||
if (warp.thread_rank() < warp.meta_group_size()) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
vals[i] = smem[warp.thread_rank() * N + i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; i++) {
|
||||
vals[i] = init;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < N; i++) {
|
||||
vals[i] = cg::reduce(warp, vals[i], op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
inline void allocate_same_layout(
|
||||
array& out,
|
||||
const array& in,
|
||||
const std::vector<int>& axes) {
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
return;
|
||||
}
|
||||
|
||||
if (out.ndim() < in.ndim()) {
|
||||
throw std::runtime_error(
|
||||
"Reduction without keepdims only supported for row-contiguous inputs");
|
||||
}
|
||||
|
||||
// Calculate the transpositions applied to in in order to apply them to out.
|
||||
std::vector<int> axis_order(in.ndim());
|
||||
std::iota(axis_order.begin(), axis_order.end(), 0);
|
||||
std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) {
|
||||
return in.strides(left) > in.strides(right);
|
||||
});
|
||||
|
||||
// Transpose the shape and calculate the strides
|
||||
Shape out_shape(in.ndim());
|
||||
Strides out_strides(in.ndim(), 1);
|
||||
for (int i = 0; i < in.ndim(); i++) {
|
||||
out_shape[i] = out.shape(axis_order[i]);
|
||||
}
|
||||
for (int i = in.ndim() - 2; i >= 0; i--) {
|
||||
out_strides[i] = out_shape[i + 1] * out_strides[i + 1];
|
||||
}
|
||||
|
||||
// Reverse the axis order to get the final strides
|
||||
Strides final_strides(in.ndim());
|
||||
for (int i = 0; i < in.ndim(); i++) {
|
||||
final_strides[axis_order[i]] = out_strides[i];
|
||||
}
|
||||
|
||||
// Calculate the resulting contiguity and do the memory allocation
|
||||
auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides);
|
||||
auto fl = in.flags();
|
||||
fl.row_contiguous = rc;
|
||||
fl.col_contiguous = cc;
|
||||
fl.contiguous = true;
|
||||
out.set_data(
|
||||
allocator::malloc(out.nbytes()),
|
||||
data_size,
|
||||
final_strides,
|
||||
fl,
|
||||
allocator::free);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,5 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
@@ -55,84 +57,108 @@ struct RowReduceArgs {
|
||||
non_row_reductions *= reduce_shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Convert shape and strides as if in was contiguous
|
||||
void sort_access_pattern(const array& in, const std::vector<int>& axes) {
|
||||
auto shape_vec = in.shape();
|
||||
auto strides_vec = in.strides();
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
shapes_without_reduction_axes(shape_vec, strides_vec, axes);
|
||||
std::vector<int> indices(shape_vec.size());
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
|
||||
return strides_vec[left] > strides_vec[right];
|
||||
});
|
||||
decltype(shape_vec) sorted_shape;
|
||||
decltype(strides_vec) sorted_strides;
|
||||
for (auto idx : indices) {
|
||||
sorted_shape.push_back(shape_vec[idx]);
|
||||
sorted_strides.push_back(strides_vec[idx]);
|
||||
}
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(sorted_shape, sorted_strides);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void row_reduce_small(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
size_t out_idx = cg::this_grid().thread_rank();
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void row_reduce_small_warp(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
|
||||
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
size_t out_idx = grid.thread_rank() / WARP_SIZE;
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
const U init = cu::ReduceInit<ReduceOp, T>::value();
|
||||
ReduceOp op;
|
||||
|
||||
T vals[M][N];
|
||||
U accs[M];
|
||||
for (int i = 0; i < M; i++) {
|
||||
accs[i] = init;
|
||||
}
|
||||
|
||||
Op op;
|
||||
const size_t start_row =
|
||||
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
|
||||
const size_t full_blocks = size / (block.size() * N);
|
||||
const size_t final_offset = full_blocks * (block.size() * N);
|
||||
in += start_row * size;
|
||||
out += start_row;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
|
||||
n += WARP_SIZE) {
|
||||
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
if (size % N == 0) {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlockedVectorized<T, N>(
|
||||
block.thread_rank(),
|
||||
in + k * size + r * (block.size() * N),
|
||||
vals[k]);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + k * size + r * (block.size() * N),
|
||||
vals[k]);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
total_val = cg::reduce(warp, total_val, op);
|
||||
if (final_offset < size) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + k * size + final_offset,
|
||||
vals[k],
|
||||
size,
|
||||
__cast<T, U>(init));
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (warp.thread_rank() == 0) {
|
||||
out[out_idx] = total_val;
|
||||
__shared__ U shared_accumulators[32 * M];
|
||||
block_reduce(block, warp, accs, shared_accumulators, op, init);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
if (grid.block_rank() * M + M <= n_rows) {
|
||||
for (int i = 0; i < M; i++) {
|
||||
out[i] = accs[i];
|
||||
}
|
||||
} else {
|
||||
short offset = grid.block_rank() * M + M - n_rows;
|
||||
for (int i = offset; i < M; i++) {
|
||||
out[i] = accs[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,55 +167,173 @@ template <
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BLOCK_DIM_X,
|
||||
int BLOCK_DIM,
|
||||
int N_READS = 4>
|
||||
__global__ void row_reduce_looped(
|
||||
const T* in,
|
||||
T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
size_t out_idx = grid.block_rank();
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
U total[1];
|
||||
U init = ReduceInit<Op, T>::value();
|
||||
total[0] = init;
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
|
||||
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
|
||||
r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM_X + block.thread_index().x,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlockedVectorized<T, N_READS>(
|
||||
block.thread_rank(),
|
||||
in + loop.location() + r * BLOCK_DIM * N_READS,
|
||||
vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], __cast<U, T>(vals[i]));
|
||||
}
|
||||
}
|
||||
if (final_offset < args.row_size) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + loop.location() + final_offset,
|
||||
vals,
|
||||
args.row_size - final_offset,
|
||||
__cast<T, U>(init));
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], __cast<U, T>(vals[i]));
|
||||
}
|
||||
}
|
||||
// TODO: Maybe block.sync() here?
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
total_val = BlockReduceT(temp).Reduce(total_val, op);
|
||||
__shared__ U shared_accumulators[32];
|
||||
block_reduce(block, warp, total, shared_accumulators, op, init);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[out_idx] = total_val;
|
||||
out[out_idx] = total[0];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void row_reduce_simple(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
// Allocate data for the output using in's layout to avoid elem_to_loc in the
|
||||
// kernel.
|
||||
allocate_same_layout(out, in, axes);
|
||||
|
||||
// TODO: If out.size() < 1024 which will be a common case then write this in
|
||||
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
|
||||
// Calculate the grid and block dims
|
||||
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Pick the kernel
|
||||
auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
|
||||
if (grid.x >= 1024) {
|
||||
grid.x = (grid.x + 1) / 2;
|
||||
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
|
||||
}
|
||||
|
||||
// Launch
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
indata, out.data<U>(), out.size(), plan.shape.back());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void row_reduce_looped(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan,
|
||||
cu::RowReduceArgs args) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
|
||||
// Calculate the grid and block dims
|
||||
args.sort_access_pattern(in, axes);
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Pick the kernel
|
||||
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
|
||||
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
||||
dispatch_block_dim(threads, [&](auto threads_constant) {
|
||||
kernel = cu::row_reduce_looped<
|
||||
T,
|
||||
U,
|
||||
OP,
|
||||
reduce_ndim(),
|
||||
threads_constant(),
|
||||
N_READS>;
|
||||
block.x = threads_constant();
|
||||
});
|
||||
});
|
||||
|
||||
// Launch
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
indata, out.data<U>(), out.size(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void row_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
@@ -197,54 +341,35 @@ void row_reduce(
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
// Current row reduction options
|
||||
//
|
||||
// - row_reduce_simple
|
||||
//
|
||||
// That means that we are simply reducing across the fastest moving axis.
|
||||
// We are reducing 1 or 2 rows per threadblock depending on the size of
|
||||
// output.
|
||||
//
|
||||
// - row_reduce_looped
|
||||
//
|
||||
// It is a general row reduction. We are computing 1 output per
|
||||
// threadblock. We read the fastest moving axis vectorized and loop over
|
||||
// the rest of the axes.
|
||||
//
|
||||
// Notes: We opt to read as much in order as possible and leave
|
||||
// transpositions as they are (contrary to our Metal backend).
|
||||
|
||||
// Simple row reduce means that we have 1 axis that we are reducing over and
|
||||
// it has stride 1.
|
||||
if (plan.shape.size() == 1) {
|
||||
row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
|
||||
return;
|
||||
}
|
||||
|
||||
// Make the args struct to help route to the best kernel
|
||||
cu::RowReduceArgs args(in, plan, axes);
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
constexpr size_t N_READS = 4;
|
||||
dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block_dims, num_blocks;
|
||||
auto kernel =
|
||||
cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
if (args.row_size <= 64) {
|
||||
if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
|
||||
(args.non_row_reductions <= 8)) {
|
||||
block_dims.x = std::min(out_dims.x, 1024u);
|
||||
num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
|
||||
num_blocks.y = out_dims.y;
|
||||
} else {
|
||||
block_dims.x = WARP_SIZE;
|
||||
num_blocks.y = out_dims.x;
|
||||
num_blocks.z = out_dims.y;
|
||||
kernel =
|
||||
cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
|
||||
}
|
||||
} else {
|
||||
size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
|
||||
num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
|
||||
MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
|
||||
num_blocks.y = out_dims.x;
|
||||
num_blocks.z = out_dims.y;
|
||||
block_dims.x = BLOCK_DIM_X;
|
||||
kernel = cu::row_reduce_looped<
|
||||
InType,
|
||||
OutType,
|
||||
OP,
|
||||
NDIM,
|
||||
BLOCK_DIM_X,
|
||||
N_READS>;
|
||||
});
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(), out.data<OutType>(), out.size(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
// Fallback row reduce
|
||||
row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <cub/device/device_reduce.cuh>
|
||||
#include <cub/device/device_segmented_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename... Args>
|
||||
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(
|
||||
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
struct MultiplyOp {
|
||||
int factor;
|
||||
__device__ int operator()(int i) {
|
||||
return i * factor;
|
||||
}
|
||||
};
|
||||
|
||||
void segmented_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
auto in_iter = cu::make_cast_iterator<OutType>(
|
||||
thrust::device_pointer_cast(in.data<InType>()));
|
||||
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||
auto init = cu::ReduceInit<OP, InType>::value();
|
||||
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
cub_all_reduce(
|
||||
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
|
||||
} else if (plan.type == ContiguousReduce) {
|
||||
auto offsets = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
|
||||
cub_segmented_reduce(
|
||||
encoder,
|
||||
in_iter,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
offsets,
|
||||
offsets + 1,
|
||||
OP(),
|
||||
init,
|
||||
stream);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported plan in segmented_reduce.");
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -225,19 +225,20 @@ void RMSNorm::eval_gpu(
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::rms_norm<DataType, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -303,7 +304,6 @@ void RMSNormVJP::eval_gpu(
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
@@ -311,22 +311,28 @@ void RMSNormVJP::eval_gpu(
|
||||
encoder.set_output_array(gx);
|
||||
encoder.set_output_array(gw_temp);
|
||||
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BOOL(has_w, HAS_W, {
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::rms_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
|
||||
dispatch_bool(has_w, [&](auto has_w_constant) {
|
||||
constexpr int N_READS = 4;
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr int N_READS = 4;
|
||||
auto kernel = cu::rms_norm_vjp<
|
||||
DataType,
|
||||
has_w_constant(),
|
||||
block_dim(),
|
||||
N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -310,12 +310,12 @@ void RoPE::eval_gpu(
|
||||
encoder.set_input_array(offset);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
|
||||
MLX_SWITCH_BOOL(forward_, FORWARD, {
|
||||
dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) {
|
||||
dispatch_bool(traditional_, [&](auto traditional) {
|
||||
dispatch_bool(forward_, [&](auto forward) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
if (single && !with_freqs) {
|
||||
auto kernel = cu::rope_single<DataType, TRADITIONAL, FORWARD>;
|
||||
auto kernel = cu::rope_single<DataType, traditional(), forward()>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
@@ -327,7 +327,8 @@ void RoPE::eval_gpu(
|
||||
mat_size,
|
||||
dims);
|
||||
} else if (single) {
|
||||
auto kernel = cu::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
|
||||
auto kernel =
|
||||
cu::rope_single_freqs<DataType, traditional(), forward()>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
@@ -340,7 +341,7 @@ void RoPE::eval_gpu(
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else if (with_freqs) {
|
||||
auto kernel = cu::rope_freqs<DataType, TRADITIONAL, FORWARD>;
|
||||
auto kernel = cu::rope_freqs<DataType, traditional(), forward()>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
@@ -358,7 +359,7 @@ void RoPE::eval_gpu(
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else {
|
||||
auto kernel = cu::rope<DataType, TRADITIONAL, FORWARD>;
|
||||
auto kernel = cu::rope<DataType, traditional(), forward()>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
|
||||
@@ -51,7 +51,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
|
||||
make_cast_iterator<AccT>(in),
|
||||
vals,
|
||||
axis_size,
|
||||
Limits<AccT>::finite_min());
|
||||
Limits<AccT>::min());
|
||||
prevmax = maxval;
|
||||
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
||||
// Online normalizer calculation for softmax:
|
||||
@@ -79,7 +79,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
|
||||
block.sync();
|
||||
maxval = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_max[warp.thread_rank()]
|
||||
: Limits<AccT>::finite_min();
|
||||
: Limits<AccT>::min();
|
||||
maxval = cg::reduce(warp, maxval, max_op);
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
if (warp.thread_rank() == 0) {
|
||||
@@ -142,17 +142,18 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
|
||||
if (precise) {
|
||||
kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>;
|
||||
}
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>;
|
||||
if (precise) {
|
||||
kernel = cu::softmax<DataType, float, block_dim(), N_READS>;
|
||||
}
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -76,17 +76,21 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
struct OffsetTransform {
|
||||
int nsort;
|
||||
|
||||
int __device__ operator()(int i) {
|
||||
return i * nsort;
|
||||
}
|
||||
};
|
||||
|
||||
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
array out = out_;
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
if (axis < 0) {
|
||||
axis += in.ndim();
|
||||
}
|
||||
int nsort = in.shape(axis);
|
||||
int nsegments = in.data_size() / nsort;
|
||||
int last_dim = in.ndim() - 1;
|
||||
|
||||
// If we are not sorting the innermost dimension of a contiguous array,
|
||||
@@ -100,16 +104,22 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
encoder.add_temporary(out);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
using CTYPE = MLX_GET_TYPE(type_tag);
|
||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||
using Type = cuda_type_t<CTYPE>;
|
||||
auto offsets = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0),
|
||||
[nsort] __device__(int i) { return i * nsort; });
|
||||
thrust::make_counting_iterator(0), OffsetTransform{nsort});
|
||||
if (argsort) {
|
||||
// Indices in the sorted dimension.
|
||||
array indices(
|
||||
@@ -134,7 +144,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
indices.data<uint32_t>(),
|
||||
out.data<uint32_t>(),
|
||||
in.data_size(),
|
||||
nsegments,
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
@@ -144,7 +154,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.data<Type>(),
|
||||
out.data<Type>(),
|
||||
in.data_size(),
|
||||
nsegments,
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
@@ -177,4 +187,14 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||
}
|
||||
|
||||
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("ArgPartition::eval_gpu");
|
||||
gpu_sort(stream(), inputs[0], out, axis_, true);
|
||||
}
|
||||
|
||||
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Partition::eval_gpu");
|
||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -92,58 +92,63 @@ void ternary_op_gpu_inplace(
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
|
||||
using DType = cuda_type_t<CTYPE>;
|
||||
dispatch_all_types(out.dtype(), [&](auto type_tag) {
|
||||
using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
auto topt = get_ternary_op_type(a, b, c);
|
||||
if (topt == TernaryOpType::General) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
auto& c_strides = strides[2];
|
||||
bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides),
|
||||
const_param<NDIM>(c_strides));
|
||||
dispatch_bool(
|
||||
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
auto& c_strides = strides[2];
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel =
|
||||
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides),
|
||||
const_param<dims_constant()>(c_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::ternary_g<Op, DType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
const_param(c_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::ternary_g<Op, DType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
const_param(c_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel, out.data_size(), out.shape(), out.strides(), large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
|
||||
@@ -20,38 +20,35 @@ namespace cu {
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_unary_op() {
|
||||
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
|
||||
std::is_same_v<Op, Sign>) {
|
||||
std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {
|
||||
return std::is_same_v<In, Out>;
|
||||
}
|
||||
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
|
||||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
|
||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
||||
if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
|
||||
std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
|
||||
std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
|
||||
std::is_same_v<Op, Sigmoid>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||
!std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
|
||||
std::is_same_v<Op, Square>) {
|
||||
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
|
||||
return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
|
||||
}
|
||||
if (std::is_same_v<Op, Conjugate>) {
|
||||
return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
|
||||
}
|
||||
if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
|
||||
std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
|
||||
std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
|
||||
std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
|
||||
return std::is_same_v<In, Out> &&
|
||||
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
|
||||
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
|
||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
|
||||
std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
|
||||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
|
||||
std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
|
||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
|
||||
std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
|
||||
std::is_same_v<Op, Tanh>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
|
||||
return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
|
||||
@@ -79,8 +76,10 @@ void unary_op_gpu_inplace(
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
|
||||
@@ -25,22 +25,38 @@ void check_cuda_error(const char* name, cudaError_t err) {
|
||||
}
|
||||
|
||||
const char* dtype_to_cuda_type(const Dtype& dtype) {
|
||||
if (dtype == float16) {
|
||||
return "__half";
|
||||
switch (dtype) {
|
||||
case bool_:
|
||||
return "bool";
|
||||
case int8:
|
||||
return "int8_t";
|
||||
case int16:
|
||||
return "int16_t";
|
||||
case int32:
|
||||
return "int32_t";
|
||||
case int64:
|
||||
return "int64_t";
|
||||
case uint8:
|
||||
return "uint8_t";
|
||||
case uint16:
|
||||
return "uint16_t";
|
||||
case uint32:
|
||||
return "uint32_t";
|
||||
case uint64:
|
||||
return "uint64_t";
|
||||
case float16:
|
||||
return "__half";
|
||||
case bfloat16:
|
||||
return "__nv_bfloat16";
|
||||
case float32:
|
||||
return "float";
|
||||
case float64:
|
||||
return "double";
|
||||
case complex64:
|
||||
return "cuComplex";
|
||||
default:
|
||||
return "unknown";
|
||||
}
|
||||
if (dtype == bfloat16) {
|
||||
return "__nv_bfloat16";
|
||||
}
|
||||
if (dtype == complex64) {
|
||||
return "cuComplex";
|
||||
}
|
||||
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
|
||||
if (dtype == DTYPE) { \
|
||||
return #CPP_TYPE; \
|
||||
}
|
||||
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
|
||||
#undef SPECIALIZE_DtypeToString
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -80,7 +80,9 @@ void Worker::thread_fn() {
|
||||
}
|
||||
worker_tasks_.erase(worker_tasks_.begin(), end);
|
||||
}
|
||||
for (auto& task : tasks) {
|
||||
// Make sure tasks are cleared before the next wait
|
||||
for (int i = 0; i < tasks.size(); ++i) {
|
||||
auto task = std::move(tasks[i]);
|
||||
task();
|
||||
}
|
||||
worker_event_.wait(batch + 1);
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
# Filename rules in ROCm backend:
|
||||
#
|
||||
# * Use .hip/.hpp if code contains device code, and .cpp/.h if not.
|
||||
# * Device-only code should be put in device/ subdir.
|
||||
# * Files in device/ subdir should not include files outside.
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rocm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.hip
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
target_compile_definitions(mlx PRIVATE MLX_USE_ROCM)
|
||||
|
||||
# Embed kernel sources in binary for JIT compilation.
|
||||
file(
|
||||
GLOB MLX_JIT_SOURCES
|
||||
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.hpp")
|
||||
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
|
||||
add_custom_command(
|
||||
OUTPUT gen/rocm_jit_sources.h
|
||||
COMMAND
|
||||
${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
|
||||
-DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
|
||||
DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
|
||||
add_custom_target(rocm_jit_sources DEPENDS gen/rocm_jit_sources.h)
|
||||
add_dependencies(mlx rocm_jit_sources)
|
||||
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
|
||||
|
||||
# Find ROCm installation
|
||||
find_package(hip REQUIRED)
|
||||
find_package(rocblas REQUIRED)
|
||||
|
||||
# Link with ROCm libraries
|
||||
target_link_libraries(mlx PRIVATE hip::device roc::rocblas)
|
||||
|
||||
# Set GPU architectures for ROCm Common ROCm architectures: gfx900, gfx906,
|
||||
# gfx908, gfx90a, gfx1030, gfx1100
|
||||
set(MLX_ROCM_ARCHITECTURES
|
||||
"gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1100"
|
||||
CACHE STRING "ROCm GPU architectures")
|
||||
message(STATUS "ROCm GPU architectures: ${MLX_ROCM_ARCHITECTURES}")
|
||||
|
||||
# Set GPU targets for HIP compilation
|
||||
set_property(TARGET mlx PROPERTY HIP_ARCHITECTURES "${MLX_ROCM_ARCHITECTURES}")
|
||||
|
||||
# Enable HIP language support
|
||||
enable_language(HIP)
|
||||
|
||||
# Set HIP compiler flags
|
||||
target_compile_options(
|
||||
mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:HIP>:-fgpu-rdc>"
|
||||
"$<$<COMPILE_LANGUAGE:HIP>:-Xcompiler=-Wall>"
|
||||
"$<$<COMPILE_LANGUAGE:HIP>:-Xcompiler=-Wextra>")
|
||||
|
||||
# Add ROCm include directories
|
||||
target_include_directories(mlx PRIVATE ${hip_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${rocblas_INCLUDE_DIRS})
|
||||
@@ -1,206 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/allocator.h"
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
#include "mlx/backend/rocm/worker.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
RocmAllocator::RocmAllocator()
|
||||
: buffer_cache_(
|
||||
getpagesize(),
|
||||
[](RocmBuffer* buf) { return buf->size; },
|
||||
[this](RocmBuffer* buf) {
|
||||
rocm_free(buf->data);
|
||||
delete buf;
|
||||
}) {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_HIP_ERROR(hipMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.8;
|
||||
max_pool_size_ = memory_limit_;
|
||||
}
|
||||
|
||||
Buffer RocmAllocator::malloc(size_t size) {
|
||||
// Find available buffer from cache.
|
||||
std::unique_lock lock(mutex_);
|
||||
RocmBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
if (!buf) {
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
// try to reclaim memory from the cache.
|
||||
size_t mem_required = get_active_memory() + get_cache_memory() + size;
|
||||
if (mem_required >= memory_limit_) {
|
||||
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
|
||||
}
|
||||
|
||||
lock.unlock();
|
||||
buf = new RocmBuffer{nullptr, size};
|
||||
hipError_t err = hipMallocManaged(&buf->data, size);
|
||||
if (err != hipSuccess && err != hipErrorMemoryAllocation) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("hipMallocManaged failed: {}.", hipGetErrorString(err)));
|
||||
}
|
||||
lock.lock();
|
||||
}
|
||||
active_memory_ += size;
|
||||
peak_memory_ = std::max(active_memory_, peak_memory_);
|
||||
|
||||
// Maintain the cache below the requested limit.
|
||||
if (get_cache_memory() > max_pool_size_) {
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
return Buffer{buf};
|
||||
}
|
||||
|
||||
void RocmAllocator::free(Buffer buffer) {
|
||||
auto* buf = static_cast<RocmBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_lock lock(mutex_);
|
||||
active_memory_ -= buf->size;
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
lock.unlock();
|
||||
rocm_free(buf->data);
|
||||
delete buf;
|
||||
}
|
||||
}
|
||||
|
||||
size_t RocmAllocator::size(Buffer buffer) const {
|
||||
auto* buf = static_cast<RocmBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return 0;
|
||||
}
|
||||
return buf->size;
|
||||
}
|
||||
|
||||
void RocmAllocator::register_this_thread() {
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
allowed_threads_.insert(std::this_thread::get_id());
|
||||
}
|
||||
|
||||
void RocmAllocator::rocm_free(void* buf) {
|
||||
// If rocm_free() is called from a unregistered thread, reschedule the call to
|
||||
// worker.
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||
if (!worker_) {
|
||||
worker_.reset(new Worker);
|
||||
}
|
||||
worker_->add_task([this, buf]() { this->rocm_free(buf); });
|
||||
worker_->end_batch();
|
||||
worker_->commit();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
hipFree(buf);
|
||||
}
|
||||
|
||||
size_t RocmAllocator::get_active_memory() const {
|
||||
return active_memory_;
|
||||
}
|
||||
|
||||
size_t RocmAllocator::get_peak_memory() const {
|
||||
return peak_memory_;
|
||||
}
|
||||
|
||||
void RocmAllocator::reset_peak_memory() {
|
||||
std::lock_guard lock(mutex_);
|
||||
peak_memory_ = 0;
|
||||
}
|
||||
|
||||
size_t RocmAllocator::get_memory_limit() {
|
||||
return memory_limit_;
|
||||
}
|
||||
|
||||
size_t RocmAllocator::set_memory_limit(size_t limit) {
|
||||
std::lock_guard lock(mutex_);
|
||||
std::swap(limit, memory_limit_);
|
||||
return limit;
|
||||
}
|
||||
|
||||
size_t RocmAllocator::get_cache_memory() const {
|
||||
return buffer_cache_.cache_size();
|
||||
}
|
||||
|
||||
size_t RocmAllocator::set_cache_limit(size_t limit) {
|
||||
std::lock_guard lk(mutex_);
|
||||
std::swap(limit, max_pool_size_);
|
||||
return limit;
|
||||
}
|
||||
|
||||
void RocmAllocator::clear_cache() {
|
||||
std::lock_guard lk(mutex_);
|
||||
buffer_cache_.clear();
|
||||
}
|
||||
|
||||
RocmAllocator& allocator() {
|
||||
// By creating the |allocator_| on heap, the destructor of RocmAllocator
|
||||
// will not be called on exit and buffers in the cache will be leaked. This
|
||||
// can save some time at program exit.
|
||||
static RocmAllocator* allocator_ = new RocmAllocator;
|
||||
return *allocator_;
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
namespace allocator {
|
||||
|
||||
Allocator& allocator() {
|
||||
return rocm::allocator();
|
||||
}
|
||||
|
||||
void* Buffer::raw_ptr() {
|
||||
if (!ptr_) {
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<rocm::RocmBuffer*>(ptr_)->data;
|
||||
}
|
||||
|
||||
} // namespace allocator
|
||||
|
||||
size_t get_active_memory() {
|
||||
return rocm::allocator().get_active_memory();
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return rocm::allocator().get_peak_memory();
|
||||
}
|
||||
void reset_peak_memory() {
|
||||
return rocm::allocator().reset_peak_memory();
|
||||
}
|
||||
size_t set_memory_limit(size_t limit) {
|
||||
return rocm::allocator().set_memory_limit(limit);
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return rocm::allocator().get_memory_limit();
|
||||
}
|
||||
size_t get_cache_memory() {
|
||||
return rocm::allocator().get_cache_memory();
|
||||
}
|
||||
size_t set_cache_limit(size_t limit) {
|
||||
return rocm::allocator().set_cache_limit(limit);
|
||||
}
|
||||
void clear_cache() {
|
||||
rocm::allocator().clear_cache();
|
||||
}
|
||||
|
||||
// Not supported in ROCm.
|
||||
size_t set_wired_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,67 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/buffer_cache.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
class Worker;
|
||||
|
||||
using allocator::Buffer;
|
||||
|
||||
// Stores ROCm-managed unified memory.
|
||||
struct RocmBuffer {
|
||||
void* data;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
class RocmAllocator : public allocator::Allocator {
|
||||
public:
|
||||
Buffer malloc(size_t size) override;
|
||||
void free(Buffer buffer) override;
|
||||
size_t size(Buffer buffer) const override;
|
||||
|
||||
// Register current thread as safe to free buffers.
|
||||
// In ROCm freeing a buffer implicitly synchronizes stream, and for threads
|
||||
// that may be waited by gpu stream (for example cpu stream threads), freeing
|
||||
// buffers there would result in dead lock.
|
||||
void register_this_thread();
|
||||
|
||||
// Call hipFree in the safe thread.
|
||||
void rocm_free(void* buf);
|
||||
|
||||
size_t get_active_memory() const;
|
||||
size_t get_peak_memory() const;
|
||||
void reset_peak_memory();
|
||||
size_t get_memory_limit();
|
||||
size_t set_memory_limit(size_t limit);
|
||||
size_t get_cache_memory() const;
|
||||
size_t set_cache_limit(size_t limit);
|
||||
void clear_cache();
|
||||
|
||||
private:
|
||||
RocmAllocator();
|
||||
friend RocmAllocator& allocator();
|
||||
|
||||
std::mutex worker_mutex_;
|
||||
std::unique_ptr<Worker> worker_;
|
||||
std::set<std::thread::id> allowed_threads_;
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t memory_limit_;
|
||||
size_t max_pool_size_;
|
||||
BufferCache<RocmBuffer> buffer_cache_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
};
|
||||
|
||||
RocmAllocator& allocator();
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,28 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void argmax_kernel(float* input, int* output, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
// Simple argmax placeholder
|
||||
if (idx == 0) {
|
||||
int max_idx = 0;
|
||||
float max_val = input[0];
|
||||
for (int i = 1; i < n; i++) {
|
||||
if (input[i] > max_val) {
|
||||
max_val = input[i];
|
||||
max_idx = i;
|
||||
}
|
||||
}
|
||||
output[0] = max_idx;
|
||||
}
|
||||
}
|
||||
|
||||
void launch_argmax(float* input, int* output, int n, hipStream_t stream) {
|
||||
hipLaunchKernelGGL(argmax_kernel, dim3(1), dim3(1), 0, stream, input, output, n);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,47 +0,0 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
# Script to embed kernel source files as header for JIT compilation
|
||||
|
||||
set(MLX_OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/gen/rocm_jit_sources.h")
|
||||
set(MLX_KERNEL_HEADER
|
||||
"#pragma once\n\n#include <unordered_map>\n#include <string>\n\nnamespace mlx::core::rocm {\n\n"
|
||||
)
|
||||
set(MLX_KERNEL_FOOTER "\n} // namespace mlx::core::rocm\n")
|
||||
|
||||
# Create output directory
|
||||
get_filename_component(MLX_OUTPUT_DIR ${MLX_OUTPUT_FILE} DIRECTORY)
|
||||
file(MAKE_DIRECTORY ${MLX_OUTPUT_DIR})
|
||||
|
||||
# Write header
|
||||
file(WRITE ${MLX_OUTPUT_FILE} ${MLX_KERNEL_HEADER})
|
||||
|
||||
# Process JIT sources
|
||||
string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
|
||||
|
||||
set(MLX_SOURCE_MAP
|
||||
"const std::unordered_map<std::string, std::string> kernel_sources = {\n")
|
||||
|
||||
foreach(source IN LISTS MLX_JIT_SOURCES_LIST)
|
||||
set(source_file "${MLX_SOURCE_ROOT}/${source}")
|
||||
if(EXISTS ${source_file})
|
||||
# Read source file
|
||||
file(READ ${source_file} source_content)
|
||||
|
||||
# Escape content for C++ string literal
|
||||
string(REPLACE "\\" "\\\\" source_content "${source_content}")
|
||||
string(REPLACE "\"" "\\\"" source_content "${source_content}")
|
||||
string(REPLACE "\n" "\\n\"\n\"" source_content "${source_content}")
|
||||
|
||||
# Add to map
|
||||
set(MLX_SOURCE_MAP
|
||||
"${MLX_SOURCE_MAP} {\"${source}\", \"${source_content}\"},\n")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(MLX_SOURCE_MAP "${MLX_SOURCE_MAP}};\n")
|
||||
|
||||
# Write source map
|
||||
file(APPEND ${MLX_OUTPUT_FILE} ${MLX_SOURCE_MAP})
|
||||
|
||||
# Write footer
|
||||
file(APPEND ${MLX_OUTPUT_FILE} ${MLX_KERNEL_FOOTER})
|
||||
@@ -1,312 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/device/binary_ops.hpp"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[0], b[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[0], b[index]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[index], b[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[index], b[index]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void binary_g_nd(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const hip_array<int32_t, NDIM> shape,
|
||||
const hip_array<int64_t, NDIM> a_strides,
|
||||
const hip_array<int64_t, NDIM> b_strides) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), a_strides.data(), b_strides.data());
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_g(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const hip_array<int32_t, MAX_DIMS> shape,
|
||||
const hip_array<int64_t, MAX_DIMS> a_strides,
|
||||
const hip_array<int64_t, MAX_DIMS> b_strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
// Binary operation support checking
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_binary_op() {
|
||||
if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
|
||||
std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
|
||||
std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
|
||||
std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
|
||||
return std::is_same_v<In, Out>;
|
||||
}
|
||||
if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
|
||||
std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
|
||||
std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
|
||||
return std::is_same_v<Out, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
|
||||
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, NaNEqual>) {
|
||||
return std::is_same_v<Out, bool> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogAddExp>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, ArcTan2>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
|
||||
std::is_same_v<Op, BitwiseXor>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||
!std::is_same_v<In, bool>;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](hipStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
|
||||
if constexpr (rocm::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = hip_type_t<CTYPE_IN>;
|
||||
using OutType = hip_type_t<CTYPE_OUT>;
|
||||
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
bool large = a.data_size() > INT32_MAX ||
|
||||
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel =
|
||||
&rocm::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large);
|
||||
hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
make_hip_array<NDIM>(shape),
|
||||
make_hip_array<NDIM>(a_strides),
|
||||
make_hip_array<NDIM>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = rocm::binary_g<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large);
|
||||
hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
make_hip_array(shape),
|
||||
make_hip_array(a_strides),
|
||||
make_hip_array(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = rocm::binary_ss<Op, InType, OutType, IdxT>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = rocm::binary_sv<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
kernel = rocm::binary_vs<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = rocm::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size());
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||
op,
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
std::vector<array> outputs{out};
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
#define BINARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
auto& s = out.primitive().stream(); \
|
||||
binary_op_gpu<rocm::func>(inputs, out, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
#define BINARY_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
auto& s = outputs[0].primitive().stream(); \
|
||||
binary_op_gpu<rocm::func>(inputs, outputs, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
BINARY_GPU(Remainder)
|
||||
BINARY_GPU(Greater)
|
||||
BINARY_GPU(GreaterEqual)
|
||||
BINARY_GPU(Less)
|
||||
BINARY_GPU(LessEqual)
|
||||
BINARY_GPU(LogicalAnd)
|
||||
BINARY_GPU(LogicalOr)
|
||||
BINARY_GPU(LogAddExp)
|
||||
BINARY_GPU(Maximum)
|
||||
BINARY_GPU(Minimum)
|
||||
BINARY_GPU(Multiply)
|
||||
BINARY_GPU(NotEqual)
|
||||
BINARY_GPU(Power)
|
||||
BINARY_GPU(Subtract)
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = out.primitive().stream();
|
||||
auto op = get_primitive_string(this);
|
||||
if (equal_nan_) {
|
||||
binary_op_gpu<rocm::NaNEqual>(inputs, out, op, s);
|
||||
} else {
|
||||
binary_op_gpu<rocm::Equal>(inputs, out, op, s);
|
||||
}
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = out.primitive().stream();
|
||||
auto op = get_primitive_string(this);
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op_gpu<rocm::BitwiseAnd>(inputs, out, op, s);
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op_gpu<rocm::BitwiseOr>(inputs, out, op, s);
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op_gpu<rocm::BitwiseXor>(inputs, out, op, s);
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op_gpu<rocm::LeftShift>(inputs, out, op, s);
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op_gpu<rocm::RightShift>(inputs, out, op, s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,9 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
void compile() {
|
||||
// Placeholder for ROCm compilation
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,20 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void copy_kernel(float* src, float* dst, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) {
|
||||
dst[idx] = src[idx];
|
||||
}
|
||||
}
|
||||
|
||||
void launch_copy(float* src, float* dst, int n, hipStream_t stream) {
|
||||
int threads = 256;
|
||||
int blocks = (n + threads - 1) / threads;
|
||||
hipLaunchKernelGGL(copy_kernel, dim3(blocks), dim3(threads), 0, stream, src, dst, n);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,60 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstddef>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Copy function declarations
|
||||
void copy_contiguous(
|
||||
const void* src,
|
||||
void* dst,
|
||||
size_t size,
|
||||
hipStream_t stream);
|
||||
|
||||
void copy_general(
|
||||
const void* src,
|
||||
void* dst,
|
||||
const int* src_shape,
|
||||
const size_t* src_strides,
|
||||
const int* dst_shape,
|
||||
const size_t* dst_strides,
|
||||
int ndim,
|
||||
size_t size,
|
||||
size_t dtype_size,
|
||||
hipStream_t stream);
|
||||
|
||||
void copy_general_dynamic(
|
||||
const void* src,
|
||||
void* dst,
|
||||
const int* src_shape,
|
||||
const size_t* src_strides,
|
||||
const int* dst_shape,
|
||||
const size_t* dst_strides,
|
||||
int ndim,
|
||||
size_t size,
|
||||
size_t dtype_size,
|
||||
hipStream_t stream);
|
||||
|
||||
void copy_general_input(
|
||||
const void* src,
|
||||
void* dst,
|
||||
const int* src_shape,
|
||||
const size_t* src_strides,
|
||||
const int* dst_shape,
|
||||
const size_t* dst_strides,
|
||||
int ndim,
|
||||
size_t size,
|
||||
size_t dtype_size,
|
||||
hipStream_t stream);
|
||||
|
||||
// Utility functions for element location calculation
|
||||
__device__ size_t
|
||||
elem_to_loc(size_t elem, const int* shape, const size_t* strides, int ndim);
|
||||
|
||||
__device__ size_t
|
||||
loc_to_elem(size_t loc, const int* shape, const size_t* strides, int ndim);
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,38 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/copy/copy.hpp"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void copy_contiguous_kernel(
|
||||
const char* src,
|
||||
char* dst,
|
||||
size_t size) {
|
||||
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid < size) {
|
||||
dst[tid] = src[tid];
|
||||
}
|
||||
}
|
||||
|
||||
void copy_contiguous(
|
||||
const void* src,
|
||||
void* dst,
|
||||
size_t size,
|
||||
hipStream_t stream) {
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int threads_per_block = 256;
|
||||
const int blocks = (size + threads_per_block - 1) / threads_per_block;
|
||||
|
||||
copy_contiguous_kernel<<<blocks, threads_per_block, 0, stream>>>(
|
||||
static_cast<const char*>(src),
|
||||
static_cast<char*>(dst),
|
||||
size);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,130 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/rocm/worker.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
|
||||
|
||||
void DeviceStream::synchronize() {
|
||||
CHECK_HIP_ERROR(hipStreamSynchronize(stream_));
|
||||
}
|
||||
|
||||
hipStream_t DeviceStream::schedule_hip_stream() {
|
||||
// TODO: Return a stream that maximizes parallelism.
|
||||
return stream_;
|
||||
}
|
||||
|
||||
hipStream_t DeviceStream::last_hip_stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
CommandEncoder& DeviceStream::get_encoder() {
|
||||
if (!encoder_) {
|
||||
encoder_ = std::make_unique<CommandEncoder>(*this);
|
||||
}
|
||||
return *encoder_;
|
||||
}
|
||||
|
||||
Device::Device(int device) : device_(device) {
|
||||
CHECK_HIP_ERROR(hipDeviceGetAttribute(
|
||||
&compute_capability_major_,
|
||||
hipDeviceAttributeComputeCapabilityMajor,
|
||||
device_));
|
||||
CHECK_HIP_ERROR(hipDeviceGetAttribute(
|
||||
&compute_capability_minor_,
|
||||
hipDeviceAttributeComputeCapabilityMinor,
|
||||
device_));
|
||||
|
||||
// Validate device requirements
|
||||
int attr = 0;
|
||||
CHECK_HIP_ERROR(hipDeviceGetAttribute(
|
||||
&attr, hipDeviceAttributeConcurrentManagedAccess, device_));
|
||||
if (attr != 1) {
|
||||
// ROCm unified memory might not be available on all devices
|
||||
// This is a warning rather than an error for ROCm
|
||||
// TODO: Add proper ROCm unified memory checking
|
||||
}
|
||||
|
||||
// Create rocBLAS handle
|
||||
make_current();
|
||||
CHECK_HIP_ERROR(
|
||||
static_cast<hipError_t>(rocblas_create_handle(&rocblas_handle_)));
|
||||
}
|
||||
|
||||
Device::~Device() {
|
||||
if (rocblas_handle_) {
|
||||
rocblas_destroy_handle(rocblas_handle_);
|
||||
}
|
||||
}
|
||||
|
||||
void Device::make_current() {
|
||||
// Cache current device to reduce HIP API calls
|
||||
static int current = 0;
|
||||
if (current != device_) {
|
||||
CHECK_HIP_ERROR(hipSetDevice(device_));
|
||||
current = device_;
|
||||
}
|
||||
}
|
||||
|
||||
DeviceStream& Device::get_stream(Stream s) {
|
||||
auto it = streams_.find(s.index);
|
||||
if (it == streams_.end()) {
|
||||
it = streams_.try_emplace(s.index, *this).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
CommandEncoder::CommandEncoder(DeviceStream& s)
|
||||
: device_(s.device()), stream_(s) {}
|
||||
|
||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||
worker_.add_task(std::move(task));
|
||||
}
|
||||
|
||||
void CommandEncoder::end_encoding() {
|
||||
if (!temporaries_.empty()) {
|
||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||
}
|
||||
|
||||
// There is no kernel running, run completion handlers immediately.
|
||||
if (!has_gpu_work_) {
|
||||
worker_.consume_in_this_thread();
|
||||
return;
|
||||
}
|
||||
has_gpu_work_ = false;
|
||||
|
||||
// Commit tasks
|
||||
commit();
|
||||
}
|
||||
|
||||
void CommandEncoder::commit() {
|
||||
worker_.commit(stream_.last_hip_stream());
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device device) {
|
||||
static std::unordered_map<int, Device> devices;
|
||||
auto it = devices.find(device.index);
|
||||
if (it == devices.end()) {
|
||||
it = devices.try_emplace(device.index, device.index).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
DeviceStream& get_stream(Stream s) {
|
||||
return device(s.device).get_stream(s);
|
||||
}
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream s) {
|
||||
return get_stream(s).get_encoder();
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,146 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
#include "mlx/backend/rocm/worker.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <rocblas/rocblas.h>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
class Device;
|
||||
class CommandEncoder;
|
||||
|
||||
class DeviceStream {
|
||||
public:
|
||||
explicit DeviceStream(Device& device);
|
||||
|
||||
DeviceStream(const DeviceStream&) = delete;
|
||||
DeviceStream& operator=(const DeviceStream&) = delete;
|
||||
|
||||
// Wait until kernels in the stream complete.
|
||||
void synchronize();
|
||||
|
||||
// Return a HIP stream for launching kernels.
|
||||
hipStream_t schedule_hip_stream();
|
||||
|
||||
// Return the last HIP stream used.
|
||||
hipStream_t last_hip_stream();
|
||||
|
||||
CommandEncoder& get_encoder();
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
HipStream stream_;
|
||||
std::unique_ptr<CommandEncoder> encoder_;
|
||||
};
|
||||
|
||||
class Device {
|
||||
public:
|
||||
explicit Device(int device);
|
||||
~Device();
|
||||
|
||||
Device(const Device&) = delete;
|
||||
Device& operator=(const Device&) = delete;
|
||||
|
||||
// Make this device the current HIP device, required by some HIP calls.
|
||||
void make_current();
|
||||
|
||||
DeviceStream& get_stream(Stream s);
|
||||
|
||||
int hip_device() const {
|
||||
return device_;
|
||||
}
|
||||
int compute_capability_major() const {
|
||||
return compute_capability_major_;
|
||||
}
|
||||
int compute_capability_minor() const {
|
||||
return compute_capability_minor_;
|
||||
}
|
||||
rocblas_handle rocblas_handle() const {
|
||||
return rocblas_handle_;
|
||||
}
|
||||
|
||||
private:
|
||||
int device_;
|
||||
int compute_capability_major_;
|
||||
int compute_capability_minor_;
|
||||
rocblas_handle rocblas_handle_;
|
||||
std::unordered_map<int, DeviceStream> streams_;
|
||||
};
|
||||
|
||||
class CommandEncoder {
|
||||
public:
|
||||
explicit CommandEncoder(DeviceStream& stream);
|
||||
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
void set_input_array(const array& arr) {}
|
||||
void set_output_array(const array& arr) {}
|
||||
|
||||
void add_temporary(const array& arr) {
|
||||
temporaries_.push_back(arr.data_shared_ptr());
|
||||
}
|
||||
|
||||
void add_completed_handler(std::function<void()> task);
|
||||
void end_encoding();
|
||||
void commit();
|
||||
|
||||
// Schedule a HIP stream for |fun| to launch kernels, and check error
|
||||
// afterwards.
|
||||
template <typename F>
|
||||
void launch_kernel(F&& fun) {
|
||||
launch_kernel(stream_.schedule_hip_stream(), std::forward<F>(fun));
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void launch_kernel(hipStream_t stream, F&& fun) {
|
||||
device_.make_current();
|
||||
fun(stream);
|
||||
check_hip_error("kernel launch", hipGetLastError());
|
||||
has_gpu_work_ = true;
|
||||
}
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
DeviceStream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
bool has_gpu_work() const {
|
||||
return has_gpu_work_;
|
||||
}
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
DeviceStream& stream_;
|
||||
Worker worker_;
|
||||
bool has_gpu_work_{false};
|
||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device device);
|
||||
DeviceStream& get_stream(Stream s);
|
||||
CommandEncoder& get_command_encoder(Stream s);
|
||||
|
||||
// Utility function to check HIP errors
|
||||
void check_hip_error(const char* msg, hipError_t error);
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
template <typename T>
|
||||
__global__ void arange_kernel(T* out, T start, T step, size_t size) {
|
||||
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid < size) {
|
||||
out[tid] = start + static_cast<T>(tid) * step;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,36 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Atomic operations for HIP
|
||||
__device__ inline float atomicAddFloat(float* address, float val) {
|
||||
return atomicAdd(address, val);
|
||||
}
|
||||
|
||||
__device__ inline double atomicAddDouble(double* address, double val) {
|
||||
return atomicAdd(address, val);
|
||||
}
|
||||
|
||||
__device__ inline int atomicAddInt(int* address, int val) {
|
||||
return atomicAdd(address, val);
|
||||
}
|
||||
|
||||
__device__ inline unsigned int atomicAddUInt(
|
||||
unsigned int* address,
|
||||
unsigned int val) {
|
||||
return atomicAdd(address, val);
|
||||
}
|
||||
|
||||
__device__ inline float atomicMaxFloat(float* address, float val) {
|
||||
return atomicMax(address, val);
|
||||
}
|
||||
|
||||
__device__ inline float atomicMinFloat(float* address, float val) {
|
||||
return atomicMin(address, val);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,217 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_bfloat16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hipcomplex.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Arithmetic operations
|
||||
struct Add {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Subtract {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a - b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Divide {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a / b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return powf(a, b);
|
||||
}
|
||||
|
||||
__device__ double operator()(double a, double b) {
|
||||
return pow(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return fmodf(a, b);
|
||||
}
|
||||
|
||||
__device__ double operator()(double a, double b) {
|
||||
return fmod(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
// Comparison operations
|
||||
struct Equal {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T a, T b) {
|
||||
return a == b;
|
||||
}
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T a, T b) {
|
||||
return a != b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Greater {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T a, T b) {
|
||||
return a > b;
|
||||
}
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T a, T b) {
|
||||
return a >= b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T a, T b) {
|
||||
return a < b;
|
||||
}
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T a, T b) {
|
||||
return a <= b;
|
||||
}
|
||||
};
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T a, T b) {
|
||||
return (isnan(a) && isnan(b)) || (a == b);
|
||||
}
|
||||
};
|
||||
|
||||
// Logic operations
|
||||
struct LogicalAnd {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
return a && b;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
return a || b;
|
||||
}
|
||||
};
|
||||
|
||||
// Math operations
|
||||
struct Maximum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return fmaxf(a, b);
|
||||
}
|
||||
|
||||
__device__ double operator()(double a, double b) {
|
||||
return fmax(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return fminf(a, b);
|
||||
}
|
||||
|
||||
__device__ double operator()(double a, double b) {
|
||||
return fmin(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
T max_val = fmaxf(a, b);
|
||||
T min_val = fminf(a, b);
|
||||
if (isinf(max_val)) {
|
||||
return max_val;
|
||||
}
|
||||
return max_val + log1pf(expf(min_val - max_val));
|
||||
}
|
||||
|
||||
__device__ double operator()(double a, double b) {
|
||||
double max_val = fmax(a, b);
|
||||
double min_val = fmin(a, b);
|
||||
if (isinf(max_val)) {
|
||||
return max_val;
|
||||
}
|
||||
return max_val + log1p(exp(min_val - max_val));
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan2 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return atan2f(a, b);
|
||||
}
|
||||
|
||||
__device__ double operator()(double a, double b) {
|
||||
return atan2(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
// Bitwise operations
|
||||
struct BitwiseAnd {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a & b;
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a | b;
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a ^ b;
|
||||
}
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a << b;
|
||||
}
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a >> b;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,21 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
template <typename To, typename From>
|
||||
struct CastOp {
|
||||
__device__ To operator()(From x) const {
|
||||
return static_cast<To>(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename To, typename From>
|
||||
__device__ inline To cast_op(From x) {
|
||||
return static_cast<To>(x);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,14 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
// ROCm/HIP specific configuration
|
||||
#define ROCM_MAX_THREADS_PER_BLOCK 1024
|
||||
#define ROCM_WARP_SIZE 64
|
||||
#define ROCM_MAX_BLOCKS_PER_GRID 65535
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
constexpr int kMaxThreadsPerBlock = ROCM_MAX_THREADS_PER_BLOCK;
|
||||
constexpr int kWarpSize = ROCM_WARP_SIZE;
|
||||
constexpr int kMaxBlocksPerGrid = ROCM_MAX_BLOCKS_PER_GRID;
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,87 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// HIP/ROCm equivalents of CUDA half precision math functions
|
||||
inline __device__ __half2 h2sin(__half2 x) {
|
||||
return __half2{hsin(x.x), hsin(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2cos(__half2 x) {
|
||||
return __half2{hcos(x.x), hcos(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2exp(__half2 x) {
|
||||
return __half2{hexp(x.x), hexp(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2log(__half2 x) {
|
||||
return __half2{hlog(x.x), hlog(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2sqrt(__half2 x) {
|
||||
return __half2{hsqrt(x.x), hsqrt(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2rsqrt(__half2 x) {
|
||||
return __half2{hrsqrt(x.x), hrsqrt(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2ceil(__half2 x) {
|
||||
return __half2{hceil(x.x), hceil(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2floor(__half2 x) {
|
||||
return __half2{hfloor(x.x), hfloor(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2rint(__half2 x) {
|
||||
return __half2{hrint(x.x), hrint(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2trunc(__half2 x) {
|
||||
return __half2{htrunc(x.x), htrunc(x.y)};
|
||||
}
|
||||
|
||||
// Additional math functions for half precision
|
||||
inline __device__ __half habs(__half x) {
|
||||
return __half{fabsf(__half2float(x))};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2abs(__half2 x) {
|
||||
return __half2{habs(x.x), habs(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half hneg(__half x) {
|
||||
return __half{-__half2float(x)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2neg(__half2 x) {
|
||||
return __half2{hneg(x.x), hneg(x.y)};
|
||||
}
|
||||
|
||||
// BFloat16 support functions
|
||||
#ifdef __HIP_BFLOAT16__
|
||||
inline __device__ __hip_bfloat16 habs(__hip_bfloat16 x) {
|
||||
return __hip_bfloat16{fabsf(__bfloat162float(x))};
|
||||
}
|
||||
|
||||
inline __device__ __hip_bfloat162 h2abs(__hip_bfloat162 x) {
|
||||
return __hip_bfloat162{habs(x.x), habs(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __hip_bfloat16 hneg(__hip_bfloat16 x) {
|
||||
return __hip_bfloat16{-__bfloat162float(x)};
|
||||
}
|
||||
|
||||
inline __device__ __hip_bfloat162 h2neg(__hip_bfloat162 x) {
|
||||
return __hip_bfloat162{hneg(x.x), hneg(x.y)};
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,52 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_complex.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// HIP complex math functions
|
||||
__device__ inline hipFloatComplex hip_complex_add(
|
||||
hipFloatComplex a,
|
||||
hipFloatComplex b) {
|
||||
return make_hipFloatComplex(
|
||||
hipCrealf(a) + hipCrealf(b), hipCimagf(a) + hipCimagf(b));
|
||||
}
|
||||
|
||||
__device__ inline hipFloatComplex hip_complex_sub(
|
||||
hipFloatComplex a,
|
||||
hipFloatComplex b) {
|
||||
return make_hipFloatComplex(
|
||||
hipCrealf(a) - hipCrealf(b), hipCimagf(a) - hipCimagf(b));
|
||||
}
|
||||
|
||||
__device__ inline hipFloatComplex hip_complex_mul(
|
||||
hipFloatComplex a,
|
||||
hipFloatComplex b) {
|
||||
float real = hipCrealf(a) * hipCrealf(b) - hipCimagf(a) * hipCimagf(b);
|
||||
float imag = hipCrealf(a) * hipCimagf(b) + hipCimagf(a) * hipCrealf(b);
|
||||
return make_hipFloatComplex(real, imag);
|
||||
}
|
||||
|
||||
__device__ inline hipFloatComplex hip_complex_div(
|
||||
hipFloatComplex a,
|
||||
hipFloatComplex b) {
|
||||
float denom = hipCrealf(b) * hipCrealf(b) + hipCimagf(b) * hipCimagf(b);
|
||||
float real =
|
||||
(hipCrealf(a) * hipCrealf(b) + hipCimagf(a) * hipCimagf(b)) / denom;
|
||||
float imag =
|
||||
(hipCimagf(a) * hipCrealf(b) - hipCrealf(a) * hipCimagf(b)) / denom;
|
||||
return make_hipFloatComplex(real, imag);
|
||||
}
|
||||
|
||||
__device__ inline float hip_complex_abs(hipFloatComplex z) {
|
||||
return sqrtf(hipCrealf(z) * hipCrealf(z) + hipCimagf(z) * hipCimagf(z));
|
||||
}
|
||||
|
||||
__device__ inline hipFloatComplex hip_complex_conj(hipFloatComplex z) {
|
||||
return make_hipFloatComplex(hipCrealf(z), -hipCimagf(z));
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,16 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
struct Select {
|
||||
template <typename T>
|
||||
__device__ T operator()(bool condition, T a, T b) const {
|
||||
return condition ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,368 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/rocm/device/fp16_math.hpp"
|
||||
#include "mlx/backend/rocm/device/utils.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
struct Abs {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_unsigned_v<T>) {
|
||||
return x;
|
||||
} else if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return {
|
||||
sqrt(hipCrealf(x) * hipCrealf(x) + hipCimagf(x) * hipCimagf(x)), 0};
|
||||
} else {
|
||||
return abs(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return acos(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return acosh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return asin(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return asinh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return atan(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return atanh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseInvert {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return ~x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
return x;
|
||||
} else {
|
||||
return ceil(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
__device__ hipFloatComplex operator()(hipFloatComplex x) {
|
||||
return {hipCrealf(x), -hipCimagf(x)};
|
||||
}
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return {
|
||||
cos(hipCrealf(x)) * cosh(hipCimagf(x)),
|
||||
-sin(hipCrealf(x)) * sinh(hipCimagf(x))};
|
||||
} else {
|
||||
return cos(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return {
|
||||
cosh(hipCrealf(x)) * cos(hipCimagf(x)),
|
||||
sinh(hipCrealf(x)) * sin(hipCimagf(x))};
|
||||
} else {
|
||||
return cosh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, __half>) {
|
||||
return erf(__half2float(x));
|
||||
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
|
||||
return erf(__bfloat162float(x));
|
||||
} else {
|
||||
return erf(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, __half>) {
|
||||
return erfinv(__half2float(x));
|
||||
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
|
||||
return erfinv(__bfloat162float(x));
|
||||
} else {
|
||||
return erfinv(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
auto m = exp(hipCrealf(x));
|
||||
return {m * cos(hipCimagf(x)), m * sinh(hipCimagf(x))};
|
||||
} else {
|
||||
return exp(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Expm1 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, __half>) {
|
||||
return expm1(__half2float(x));
|
||||
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
|
||||
return expm1(__bfloat162float(x));
|
||||
} else {
|
||||
return expm1(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
return x;
|
||||
} else {
|
||||
return floor(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Imag {
|
||||
__device__ float operator()(hipFloatComplex x) {
|
||||
return hipCimagf(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
auto r = log(hipCrealf(Abs{}(x)));
|
||||
auto i = atan2f(hipCimagf(x), hipCrealf(x));
|
||||
return {r, i};
|
||||
} else {
|
||||
return log(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
auto y = Log{}(x);
|
||||
return {hipCrealf(y) / M_LN2, hipCimagf(y) / M_LN2};
|
||||
} else {
|
||||
return log2(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
auto y = Log{}(x);
|
||||
return {hipCrealf(y) / M_LN10, hipCimagf(y) / M_LN10};
|
||||
} else {
|
||||
return log10(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log1p(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
__device__ bool operator()(bool x) {
|
||||
return !x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return 0 - x;
|
||||
} else {
|
||||
return -x;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Real {
|
||||
__device__ float operator()(hipFloatComplex x) {
|
||||
return hipCrealf(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return {rint(hipCrealf(x)), rint(hipCimagf(x))};
|
||||
} else {
|
||||
return rint(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
T y = 1 / (1 + exp(-abs(x)));
|
||||
return (x < 0) ? 1 - y : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_unsigned_v<T>) {
|
||||
return x != 0;
|
||||
} else if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
if (hipCrealf(x) == 0 && hipCimagf(x) == 0) {
|
||||
return x;
|
||||
} else {
|
||||
return x / Abs()(x);
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
|
||||
return static_cast<float>((x > T(0.f)) - (x < T(0.f)));
|
||||
} else {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return {
|
||||
sin(hipCrealf(x)) * cosh(hipCimagf(x)),
|
||||
cos(hipCrealf(x)) * sinh(hipCimagf(x))};
|
||||
} else {
|
||||
return sin(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return {
|
||||
sinh(hipCrealf(x)) * cos(hipCimagf(x)),
|
||||
cosh(hipCrealf(x)) * sin(hipCimagf(x))};
|
||||
} else {
|
||||
return sinh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return x * x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
float tan_a = tan(hipCrealf(x));
|
||||
float tanh_b = tanh(hipCimagf(x));
|
||||
float t1 = tan_a * tanh_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
|
||||
} else {
|
||||
return tan(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
float tanh_a = tanh(hipCrealf(x));
|
||||
float tan_b = tan(hipCimagf(x));
|
||||
float t1 = tanh_a * tan_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
||||
} else {
|
||||
return tanh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,173 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_complex.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// HIP/ROCm type definitions
|
||||
using hip_complex = hipFloatComplex;
|
||||
|
||||
// Utility functions for HIP device code
|
||||
template <typename T>
|
||||
struct hip_type {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<bool> {
|
||||
using type = bool;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<int8_t> {
|
||||
using type = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<uint8_t> {
|
||||
using type = uint8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<int16_t> {
|
||||
using type = int16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<uint16_t> {
|
||||
using type = uint16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<int32_t> {
|
||||
using type = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<uint32_t> {
|
||||
using type = uint32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<int64_t> {
|
||||
using type = int64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<uint64_t> {
|
||||
using type = uint64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<float> {
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<double> {
|
||||
using type = double;
|
||||
};
|
||||
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
template <>
|
||||
struct hip_type<__half> {
|
||||
using type = __half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<__hip_bfloat16> {
|
||||
using type = __hip_bfloat16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
using hip_type_t = typename hip_type<T>::type;
|
||||
|
||||
// Element-wise operations support
|
||||
template <typename T>
|
||||
constexpr bool is_floating_point_v = std::is_floating_point_v<T>;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_integral_v = std::is_integral_v<T>;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_signed_v = std::is_signed_v<T>;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_unsigned_v = std::is_unsigned_v<T>;
|
||||
|
||||
// Complex number helper functions
|
||||
inline __device__ hipFloatComplex make_complex(float real, float imag) {
|
||||
return make_hipFloatComplex(real, imag);
|
||||
}
|
||||
|
||||
inline __device__ float hip_real(hipFloatComplex z) {
|
||||
return hipCrealf(z);
|
||||
}
|
||||
|
||||
inline __device__ float hip_imag(hipFloatComplex z) {
|
||||
return hipCimagf(z);
|
||||
}
|
||||
|
||||
inline __device__ hipFloatComplex hip_conj(hipFloatComplex z) {
|
||||
return make_hipFloatComplex(hipCrealf(z), -hipCimagf(z));
|
||||
}
|
||||
|
||||
inline __device__ float hip_abs(hipFloatComplex z) {
|
||||
return sqrtf(hipCrealf(z) * hipCrealf(z) + hipCimagf(z) * hipCimagf(z));
|
||||
}
|
||||
|
||||
// Memory access utilities
|
||||
template <typename T>
|
||||
inline __device__ T hip_load_global(const T* ptr) {
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void hip_store_global(T* ptr, T value) {
|
||||
*ptr = value;
|
||||
}
|
||||
|
||||
// Grid and block utilities
|
||||
inline __device__ int hip_thread_idx() {
|
||||
return threadIdx.x;
|
||||
}
|
||||
|
||||
inline __device__ int hip_block_idx() {
|
||||
return blockIdx.x;
|
||||
}
|
||||
|
||||
inline __device__ int hip_block_dim() {
|
||||
return blockDim.x;
|
||||
}
|
||||
|
||||
inline __device__ int hip_grid_dim() {
|
||||
return gridDim.x;
|
||||
}
|
||||
|
||||
inline __device__ int hip_global_thread_idx() {
|
||||
return blockIdx.x * blockDim.x + threadIdx.x;
|
||||
}
|
||||
|
||||
// Synchronization
|
||||
inline __device__ void hip_sync_threads() {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Math constants for HIP (equivalent to CUDA's math_constants.h)
|
||||
#ifndef M_PI
|
||||
#define M_PI 3.14159265358979323846
|
||||
#endif
|
||||
|
||||
#ifndef M_LN2
|
||||
#define M_LN2 0.693147180559945309417
|
||||
#endif
|
||||
|
||||
#ifndef M_LN10
|
||||
#define M_LN10 2.302585092994045684018
|
||||
#endif
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
void eval() {
|
||||
// Placeholder for ROCm evaluation
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,50 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/event.h"
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
HipEvent::HipEvent() {
|
||||
CHECK_HIP_ERROR(hipEventCreate(&event_));
|
||||
}
|
||||
|
||||
HipEvent::~HipEvent() {
|
||||
CHECK_HIP_ERROR(hipEventDestroy(event_));
|
||||
}
|
||||
|
||||
void HipEvent::record(hipStream_t stream) {
|
||||
CHECK_HIP_ERROR(hipEventRecord(event_, stream));
|
||||
}
|
||||
|
||||
void HipEvent::wait() {
|
||||
CHECK_HIP_ERROR(hipEventSynchronize(event_));
|
||||
}
|
||||
|
||||
bool HipEvent::query() const {
|
||||
hipError_t status = hipEventQuery(event_);
|
||||
if (status == hipSuccess) {
|
||||
return true;
|
||||
} else if (status == hipErrorNotReady) {
|
||||
return false;
|
||||
} else {
|
||||
CHECK_HIP_ERROR(status);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
SharedEvent::SharedEvent() = default;
|
||||
|
||||
void SharedEvent::notify() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
ready_ = true;
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
void SharedEvent::wait() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
cv_.wait(lock, [this] { return ready_; });
|
||||
ready_ = false;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,48 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <condition_variable>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// HIP event managed with RAII.
|
||||
class HipEvent {
|
||||
public:
|
||||
HipEvent();
|
||||
~HipEvent();
|
||||
|
||||
HipEvent(const HipEvent&) = delete;
|
||||
HipEvent& operator=(const HipEvent&) = delete;
|
||||
|
||||
void record(hipStream_t stream);
|
||||
void wait();
|
||||
bool query() const;
|
||||
|
||||
operator hipEvent_t() const {
|
||||
return event_;
|
||||
}
|
||||
|
||||
private:
|
||||
hipEvent_t event_;
|
||||
};
|
||||
|
||||
// Shared event for worker thread synchronization.
|
||||
class SharedEvent {
|
||||
public:
|
||||
SharedEvent();
|
||||
|
||||
void notify();
|
||||
void wait();
|
||||
|
||||
private:
|
||||
std::mutex mutex_;
|
||||
std::condition_variable cv_;
|
||||
bool ready_{false};
|
||||
};
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,32 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
class Event {
|
||||
public:
|
||||
Event() {
|
||||
check_hip_error("hipEventCreate", hipEventCreate(&event_));
|
||||
}
|
||||
|
||||
~Event() {
|
||||
hipEventDestroy(event_);
|
||||
}
|
||||
|
||||
void record(hipStream_t stream) {
|
||||
check_hip_error("hipEventRecord", hipEventRecord(event_, stream));
|
||||
}
|
||||
|
||||
void wait() {
|
||||
check_hip_error("hipEventSynchronize", hipEventSynchronize(event_));
|
||||
}
|
||||
|
||||
hipEvent_t event() const { return event_; }
|
||||
|
||||
private:
|
||||
hipEvent_t event_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,9 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
void fence() {
|
||||
// Placeholder for ROCm fence operation
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,9 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
void index() {
|
||||
// Placeholder for ROCm indexing operation
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,153 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstdint>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
template <typename IdxType>
|
||||
struct GeneralIterator {
|
||||
using difference_type = ptrdiff_t;
|
||||
using value_type = IdxType;
|
||||
using pointer = IdxType*;
|
||||
using reference = IdxType&;
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
|
||||
const IdxType* base_ptr;
|
||||
IdxType offset;
|
||||
const int* shape;
|
||||
const size_t* strides;
|
||||
int ndim;
|
||||
size_t size;
|
||||
|
||||
__device__ GeneralIterator(
|
||||
const IdxType* base_ptr,
|
||||
IdxType offset,
|
||||
const int* shape,
|
||||
const size_t* strides,
|
||||
int ndim,
|
||||
size_t size)
|
||||
: base_ptr(base_ptr),
|
||||
offset(offset),
|
||||
shape(shape),
|
||||
strides(strides),
|
||||
ndim(ndim),
|
||||
size(size) {}
|
||||
|
||||
__device__ GeneralIterator operator+(difference_type n) const {
|
||||
return GeneralIterator(base_ptr, offset + n, shape, strides, ndim, size);
|
||||
}
|
||||
|
||||
__device__ GeneralIterator operator-(difference_type n) const {
|
||||
return GeneralIterator(base_ptr, offset - n, shape, strides, ndim, size);
|
||||
}
|
||||
|
||||
__device__ difference_type operator-(const GeneralIterator& other) const {
|
||||
return offset - other.offset;
|
||||
}
|
||||
|
||||
__device__ GeneralIterator& operator+=(difference_type n) {
|
||||
offset += n;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ GeneralIterator& operator-=(difference_type n) {
|
||||
offset -= n;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ GeneralIterator& operator++() {
|
||||
++offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ GeneralIterator operator++(int) {
|
||||
GeneralIterator temp = *this;
|
||||
++offset;
|
||||
return temp;
|
||||
}
|
||||
|
||||
__device__ GeneralIterator& operator--() {
|
||||
--offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ GeneralIterator operator--(int) {
|
||||
GeneralIterator temp = *this;
|
||||
--offset;
|
||||
return temp;
|
||||
}
|
||||
|
||||
__device__ bool operator==(const GeneralIterator& other) const {
|
||||
return offset == other.offset;
|
||||
}
|
||||
|
||||
__device__ bool operator!=(const GeneralIterator& other) const {
|
||||
return offset != other.offset;
|
||||
}
|
||||
|
||||
__device__ bool operator<(const GeneralIterator& other) const {
|
||||
return offset < other.offset;
|
||||
}
|
||||
|
||||
__device__ bool operator>(const GeneralIterator& other) const {
|
||||
return offset > other.offset;
|
||||
}
|
||||
|
||||
__device__ bool operator<=(const GeneralIterator& other) const {
|
||||
return offset <= other.offset;
|
||||
}
|
||||
|
||||
__device__ bool operator>=(const GeneralIterator& other) const {
|
||||
return offset >= other.offset;
|
||||
}
|
||||
|
||||
__device__ IdxType operator*() const {
|
||||
return base_ptr[elem_to_loc(offset, shape, strides, ndim)];
|
||||
}
|
||||
|
||||
__device__ IdxType operator[](difference_type n) const {
|
||||
return base_ptr[elem_to_loc(offset + n, shape, strides, ndim)];
|
||||
}
|
||||
|
||||
private:
|
||||
__device__ size_t elem_to_loc(
|
||||
size_t elem,
|
||||
const int* shape,
|
||||
const size_t* strides,
|
||||
int ndim) const {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
auto q_and_r = div(elem, static_cast<size_t>(shape[i]));
|
||||
loc += q_and_r.rem * strides[i];
|
||||
elem = q_and_r.quot;
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
__device__ div_t div(size_t numer, size_t denom) const {
|
||||
div_t result;
|
||||
result.quot = numer / denom;
|
||||
result.rem = numer % denom;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename IdxType>
|
||||
__device__ std::pair<GeneralIterator<IdxType>, GeneralIterator<IdxType>>
|
||||
make_general_iterators(
|
||||
const IdxType* base_ptr,
|
||||
size_t size,
|
||||
const int* shape,
|
||||
const size_t* strides,
|
||||
int ndim) {
|
||||
auto begin =
|
||||
GeneralIterator<IdxType>(base_ptr, 0, shape, strides, ndim, size);
|
||||
auto end =
|
||||
GeneralIterator<IdxType>(base_ptr, size, shape, strides, ndim, size);
|
||||
return std::make_pair(begin, end);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,106 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstdint>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
template <typename T>
|
||||
struct StridedIterator {
|
||||
using difference_type = ptrdiff_t;
|
||||
using value_type = T;
|
||||
using pointer = T*;
|
||||
using reference = T&;
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
|
||||
T* ptr;
|
||||
size_t stride;
|
||||
|
||||
__device__ StridedIterator(T* ptr, size_t stride)
|
||||
: ptr(ptr), stride(stride) {}
|
||||
|
||||
__device__ StridedIterator operator+(difference_type n) const {
|
||||
return StridedIterator(ptr + n * stride, stride);
|
||||
}
|
||||
|
||||
__device__ StridedIterator operator-(difference_type n) const {
|
||||
return StridedIterator(ptr - n * stride, stride);
|
||||
}
|
||||
|
||||
__device__ difference_type operator-(const StridedIterator& other) const {
|
||||
return (ptr - other.ptr) / stride;
|
||||
}
|
||||
|
||||
__device__ StridedIterator& operator+=(difference_type n) {
|
||||
ptr += n * stride;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ StridedIterator& operator-=(difference_type n) {
|
||||
ptr -= n * stride;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ StridedIterator& operator++() {
|
||||
ptr += stride;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ StridedIterator operator++(int) {
|
||||
StridedIterator temp = *this;
|
||||
ptr += stride;
|
||||
return temp;
|
||||
}
|
||||
|
||||
__device__ StridedIterator& operator--() {
|
||||
ptr -= stride;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ StridedIterator operator--(int) {
|
||||
StridedIterator temp = *this;
|
||||
ptr -= stride;
|
||||
return temp;
|
||||
}
|
||||
|
||||
__device__ bool operator==(const StridedIterator& other) const {
|
||||
return ptr == other.ptr;
|
||||
}
|
||||
|
||||
__device__ bool operator!=(const StridedIterator& other) const {
|
||||
return ptr != other.ptr;
|
||||
}
|
||||
|
||||
__device__ bool operator<(const StridedIterator& other) const {
|
||||
return ptr < other.ptr;
|
||||
}
|
||||
|
||||
__device__ bool operator>(const StridedIterator& other) const {
|
||||
return ptr > other.ptr;
|
||||
}
|
||||
|
||||
__device__ bool operator<=(const StridedIterator& other) const {
|
||||
return ptr <= other.ptr;
|
||||
}
|
||||
|
||||
__device__ bool operator>=(const StridedIterator& other) const {
|
||||
return ptr >= other.ptr;
|
||||
}
|
||||
|
||||
__device__ T& operator*() const {
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
__device__ T& operator[](difference_type n) const {
|
||||
return *(ptr + n * stride);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ StridedIterator<T> make_strided_iterator(T* ptr, size_t stride) {
|
||||
return StridedIterator<T>(ptr, stride);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,167 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/jit_module.h"
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <mutex>
|
||||
#include <sstream>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
JitModule::JitModule(
|
||||
const std::string& kernel_name,
|
||||
const std::string& kernel_source,
|
||||
const std::vector<std::string>& template_args,
|
||||
const std::vector<std::string>& compiler_flags,
|
||||
bool verbose) {
|
||||
compile(kernel_name, kernel_source, template_args, compiler_flags, verbose);
|
||||
}
|
||||
|
||||
JitModule::~JitModule() {
|
||||
if (kernel_) {
|
||||
// No hipFunctionDestroy equivalent in HIP
|
||||
}
|
||||
if (module_) {
|
||||
CHECK_HIP_ERROR(hipModuleUnload(module_));
|
||||
}
|
||||
if (program_) {
|
||||
hiprtcDestroyProgram(&program_);
|
||||
}
|
||||
}
|
||||
|
||||
void JitModule::compile(
|
||||
const std::string& kernel_name,
|
||||
const std::string& kernel_source,
|
||||
const std::vector<std::string>& template_args,
|
||||
const std::vector<std::string>& compiler_flags,
|
||||
bool verbose) {
|
||||
// Create HIPRTC program
|
||||
CHECK_HIP_ERROR(hiprtcCreateProgram(
|
||||
&program_,
|
||||
kernel_source.c_str(),
|
||||
kernel_name.c_str(),
|
||||
0,
|
||||
nullptr,
|
||||
nullptr));
|
||||
|
||||
// Build compiler options
|
||||
std::vector<const char*> options;
|
||||
std::vector<std::string> option_strings;
|
||||
|
||||
// Add default options
|
||||
option_strings.push_back("--std=c++17");
|
||||
option_strings.push_back("-O3");
|
||||
option_strings.push_back("-DMLX_USE_ROCM");
|
||||
|
||||
// Add user-provided flags
|
||||
for (const auto& flag : compiler_flags) {
|
||||
option_strings.push_back(flag);
|
||||
}
|
||||
|
||||
// Add template arguments
|
||||
for (const auto& arg : template_args) {
|
||||
option_strings.push_back("-D" + arg);
|
||||
}
|
||||
|
||||
// Convert to char* array
|
||||
for (const auto& option : option_strings) {
|
||||
options.push_back(option.c_str());
|
||||
}
|
||||
|
||||
// Compile the program
|
||||
hiprtcResult compile_result =
|
||||
hiprtcCompileProgram(program_, options.size(), options.data());
|
||||
|
||||
// Get compilation log
|
||||
size_t log_size;
|
||||
CHECK_HIP_ERROR(hiprtcGetProgramLogSize(program_, &log_size));
|
||||
|
||||
if (log_size > 1) {
|
||||
std::vector<char> log(log_size);
|
||||
CHECK_HIP_ERROR(hiprtcGetProgramLog(program_, log.data()));
|
||||
|
||||
if (verbose || compile_result != HIPRTC_SUCCESS) {
|
||||
fmt::print(
|
||||
"HIPRTC compilation log for {}:\n{}\n", kernel_name, log.data());
|
||||
}
|
||||
}
|
||||
|
||||
if (compile_result != HIPRTC_SUCCESS) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("HIPRTC compilation failed for kernel {}", kernel_name));
|
||||
}
|
||||
|
||||
// Get compiled code
|
||||
size_t code_size;
|
||||
CHECK_HIP_ERROR(hiprtcGetCodeSize(program_, &code_size));
|
||||
|
||||
std::vector<char> code(code_size);
|
||||
CHECK_HIP_ERROR(hiprtcGetCode(program_, code.data()));
|
||||
|
||||
// Load module
|
||||
CHECK_HIP_ERROR(hipModuleLoadData(&module_, code.data()));
|
||||
|
||||
// Get kernel function
|
||||
CHECK_HIP_ERROR(hipModuleGetFunction(&kernel_, module_, kernel_name.c_str()));
|
||||
}
|
||||
|
||||
JitCache& JitCache::instance() {
|
||||
static JitCache cache;
|
||||
return cache;
|
||||
}
|
||||
|
||||
std::shared_ptr<JitModule> JitCache::get_or_create(
|
||||
const std::string& kernel_name,
|
||||
const std::string& kernel_source,
|
||||
const std::vector<std::string>& template_args,
|
||||
const std::vector<std::string>& compiler_flags) {
|
||||
std::string key =
|
||||
make_key(kernel_name, kernel_source, template_args, compiler_flags);
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
auto it = cache_.find(key);
|
||||
if (it != cache_.end()) {
|
||||
if (auto module = it->second.lock()) {
|
||||
return module;
|
||||
} else {
|
||||
cache_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
auto module = std::make_shared<JitModule>(
|
||||
kernel_name, kernel_source, template_args, compiler_flags);
|
||||
cache_[key] = module;
|
||||
return module;
|
||||
}
|
||||
|
||||
std::string JitCache::make_key(
|
||||
const std::string& kernel_name,
|
||||
const std::string& kernel_source,
|
||||
const std::vector<std::string>& template_args,
|
||||
const std::vector<std::string>& compiler_flags) const {
|
||||
std::ostringstream oss;
|
||||
oss << kernel_name << "|" << kernel_source;
|
||||
|
||||
for (const auto& arg : template_args) {
|
||||
oss << "|" << arg;
|
||||
}
|
||||
|
||||
for (const auto& flag : compiler_flags) {
|
||||
oss << "|" << flag;
|
||||
}
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
std::shared_ptr<JitModule> make_jit_kernel(
|
||||
const std::string& kernel_name,
|
||||
const std::string& kernel_source,
|
||||
const std::vector<std::string>& template_args,
|
||||
const std::vector<std::string>& compiler_flags) {
|
||||
return JitCache::instance().get_or_create(
|
||||
kernel_name, kernel_source, template_args, compiler_flags);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,100 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hiprtc.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// JIT compilation module for ROCm
|
||||
class JitModule {
|
||||
public:
|
||||
JitModule(
|
||||
const std::string& kernel_name,
|
||||
const std::string& kernel_source,
|
||||
const std::vector<std::string>& template_args = {},
|
||||
const std::vector<std::string>& compiler_flags = {},
|
||||
bool verbose = false);
|
||||
|
||||
~JitModule();
|
||||
|
||||
JitModule(const JitModule&) = delete;
|
||||
JitModule& operator=(const JitModule&) = delete;
|
||||
|
||||
// Get the compiled kernel function
|
||||
hipFunction_t get_kernel() const {
|
||||
return kernel_;
|
||||
}
|
||||
|
||||
// Launch the kernel with given arguments
|
||||
template <typename... Args>
|
||||
void launch(
|
||||
dim3 grid_dims,
|
||||
dim3 block_dims,
|
||||
size_t shared_memory,
|
||||
hipStream_t stream,
|
||||
Args&&... args) {
|
||||
void* kernel_args[] = {(void*)&args...};
|
||||
CHECK_HIP_ERROR(hipModuleLaunchKernel(
|
||||
kernel_,
|
||||
grid_dims.x,
|
||||
grid_dims.y,
|
||||
grid_dims.z,
|
||||
block_dims.x,
|
||||
block_dims.y,
|
||||
block_dims.z,
|
||||
shared_memory,
|
||||
stream,
|
||||
kernel_args,
|
||||
nullptr));
|
||||
}
|
||||
|
||||
private:
|
||||
void compile(
|
||||
const std::string& kernel_name,
|
||||
const std::string& kernel_source,
|
||||
const std::vector<std::string>& template_args,
|
||||
const std::vector<std::string>& compiler_flags,
|
||||
bool verbose);
|
||||
|
||||
hiprtcProgram program_{nullptr};
|
||||
hipModule_t module_{nullptr};
|
||||
hipFunction_t kernel_{nullptr};
|
||||
};
|
||||
|
||||
// JIT cache for compiled modules
|
||||
class JitCache {
|
||||
public:
|
||||
static JitCache& instance();
|
||||
|
||||
std::shared_ptr<JitModule> get_or_create(
|
||||
const std::string& kernel_name,
|
||||
const std::string& kernel_source,
|
||||
const std::vector<std::string>& template_args = {},
|
||||
const std::vector<std::string>& compiler_flags = {});
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::weak_ptr<JitModule>> cache_;
|
||||
std::mutex mutex_;
|
||||
|
||||
std::string make_key(
|
||||
const std::string& kernel_name,
|
||||
const std::string& kernel_source,
|
||||
const std::vector<std::string>& template_args,
|
||||
const std::vector<std::string>& compiler_flags) const;
|
||||
};
|
||||
|
||||
// Helper function to create and cache JIT modules
|
||||
std::shared_ptr<JitModule> make_jit_kernel(
|
||||
const std::string& kernel_name,
|
||||
const std::string& kernel_source,
|
||||
const std::vector<std::string>& template_args = {},
|
||||
const std::vector<std::string>& compiler_flags = {});
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,29 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Utility functions for HIP kernels
|
||||
|
||||
__device__ inline int get_global_id() {
|
||||
return blockIdx.x * blockDim.x + threadIdx.x;
|
||||
}
|
||||
|
||||
__device__ inline int get_local_id() {
|
||||
return threadIdx.x;
|
||||
}
|
||||
|
||||
__device__ inline int get_group_id() {
|
||||
return blockIdx.x;
|
||||
}
|
||||
|
||||
__device__ inline int get_local_size() {
|
||||
return blockDim.x;
|
||||
}
|
||||
|
||||
__device__ inline int get_num_groups() {
|
||||
return gridDim.x;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,135 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <array>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Constants
|
||||
constexpr int MAX_DIMS = 8;
|
||||
|
||||
// HIP array type for passing arrays to kernels
|
||||
template <typename T, int N>
|
||||
using hip_array = std::array<T, N>;
|
||||
|
||||
// Helper to create hip_array from vector
|
||||
template <int N, typename T>
|
||||
__host__ hip_array<T, N> make_hip_array(const std::vector<T>& vec) {
|
||||
hip_array<T, N> arr;
|
||||
for (int i = 0; i < N && i < vec.size(); ++i) {
|
||||
arr[i] = vec[i];
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ hip_array<T, MAX_DIMS> make_hip_array(const std::vector<T>& vec) {
|
||||
return make_hip_array<MAX_DIMS>(vec);
|
||||
}
|
||||
|
||||
// Type mapping from MLX types to HIP types
|
||||
template <typename T>
|
||||
using hip_type_t = T;
|
||||
|
||||
template <>
|
||||
using hip_type_t<float16> = __half;
|
||||
|
||||
template <>
|
||||
using hip_type_t<bfloat16> = __hip_bfloat16;
|
||||
|
||||
template <>
|
||||
using hip_type_t<complex64> = hipFloatComplex;
|
||||
|
||||
// Element to location mapping for general broadcasting
|
||||
template <int NDIM>
|
||||
__device__ std::pair<int64_t, int64_t> elem_to_loc_nd(
|
||||
int64_t elem,
|
||||
const int32_t* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides) {
|
||||
int64_t a_idx = 0;
|
||||
int64_t b_idx = 0;
|
||||
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int64_t pos_in_dim = elem % shape[i];
|
||||
elem /= shape[i];
|
||||
a_idx += pos_in_dim * a_strides[i];
|
||||
b_idx += pos_in_dim * b_strides[i];
|
||||
}
|
||||
|
||||
return {a_idx, b_idx};
|
||||
}
|
||||
|
||||
// 4D specialization for performance
|
||||
__device__ inline std::pair<int64_t, int64_t> elem_to_loc_4d(
|
||||
int64_t elem,
|
||||
const int32_t* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
int ndim) {
|
||||
int64_t a_idx = 0;
|
||||
int64_t b_idx = 0;
|
||||
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int64_t pos_in_dim = elem % shape[i];
|
||||
elem /= shape[i];
|
||||
a_idx += pos_in_dim * a_strides[i];
|
||||
b_idx += pos_in_dim * b_strides[i];
|
||||
}
|
||||
|
||||
return {a_idx, b_idx};
|
||||
}
|
||||
|
||||
// Launch configuration calculation
|
||||
template <typename Kernel>
|
||||
std::pair<dim3, dim3>
|
||||
get_launch_args(Kernel kernel, const array& out, bool large = false) {
|
||||
int threads_per_block = 256;
|
||||
int64_t total_threads = out.size();
|
||||
|
||||
if (large) {
|
||||
// For large arrays, use more blocks
|
||||
int64_t blocks =
|
||||
(total_threads + threads_per_block - 1) / threads_per_block;
|
||||
return {dim3(blocks), dim3(threads_per_block)};
|
||||
} else {
|
||||
int blocks = (total_threads + threads_per_block - 1) / threads_per_block;
|
||||
return {dim3(blocks), dim3(threads_per_block)};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Kernel>
|
||||
std::pair<dim3, dim3> get_launch_args(
|
||||
Kernel kernel,
|
||||
int64_t size,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides,
|
||||
bool large = false) {
|
||||
int threads_per_block = 256;
|
||||
|
||||
if (large) {
|
||||
int64_t blocks = (size + threads_per_block - 1) / threads_per_block;
|
||||
return {dim3(blocks), dim3(threads_per_block)};
|
||||
} else {
|
||||
int blocks = (size + threads_per_block - 1) / threads_per_block;
|
||||
return {dim3(blocks), dim3(threads_per_block)};
|
||||
}
|
||||
}
|
||||
|
||||
// Cooperative groups thread rank equivalent
|
||||
namespace cooperative_groups {
|
||||
class grid_group {
|
||||
public:
|
||||
__device__ int64_t thread_rank() const {
|
||||
return blockIdx.x * blockDim.x + threadIdx.x;
|
||||
}
|
||||
};
|
||||
|
||||
__device__ grid_group this_grid() {
|
||||
return grid_group{};
|
||||
}
|
||||
} // namespace cooperative_groups
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,437 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/iterators/strided_iterator.hpp"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
#include "mlx/backend/rocm/reduce/reduce.hpp"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#include <rocprim/block/block_load.hpp>
|
||||
#include <rocprim/block/block_reduce.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
inline __device__ float3 plus_f3(const float3& a, const float3& b) {
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
||||
}
|
||||
|
||||
// Similar to rocprim::BlockReduce, but result is broadcasted to every thread.
|
||||
template <typename T, int BLOCK_DIM>
|
||||
struct BlockBroadcastReduce {
|
||||
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
|
||||
static_assert(BLOCK_DIM % WARP_SIZE == 0);
|
||||
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
|
||||
|
||||
cg::thread_block& block;
|
||||
TempStorage& temp;
|
||||
|
||||
template <typename Op>
|
||||
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
T x = cg::reduce(warp, input, op);
|
||||
if (warp.thread_rank() == 0) {
|
||||
temp[warp.meta_group_rank()] = x;
|
||||
}
|
||||
block.sync();
|
||||
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||
: init_value;
|
||||
return cg::reduce(warp, x, op);
|
||||
}
|
||||
|
||||
__device__ T Sum(const T& input) {
|
||||
return Reduce(input, hip_plus<T>{}, T{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void layer_norm(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* b,
|
||||
T* out,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride,
|
||||
int64_t b_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum.
|
||||
float sum = 0;
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
sum += static_cast<float>(rocprim::thread_reduce(xn, hip_plus<T>{}));
|
||||
}
|
||||
sum = BlockReduceT{block, temp}.Sum(sum);
|
||||
|
||||
// Mean.
|
||||
float mean = sum / axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float normalizer = 0;
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size, mean);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
normalizer += t * t;
|
||||
}
|
||||
}
|
||||
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
|
||||
normalizer = rsqrt(normalizer / axis_size + eps);
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T bn[N_READS];
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(b, b_stride), bn, axis_size);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
|
||||
}
|
||||
rocprim::block_store_direct_blocked(index, out, xn, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void layer_norm_vjp(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* g,
|
||||
T* gx,
|
||||
T* gw,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
using BlockReduceF3 = BlockBroadcastReduce<float3, BLOCK_DIM>;
|
||||
__shared__ union {
|
||||
typename BlockReduceF::TempStorage f;
|
||||
typename BlockReduceF3::TempStorage f3;
|
||||
} temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
g += grid.block_rank() * axis_size;
|
||||
gx += grid.block_rank() * axis_size;
|
||||
gw += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum.
|
||||
float sum = 0;
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
sum += static_cast<float>(rocprim::thread_reduce(xn, hip_plus<T>{}));
|
||||
}
|
||||
sum = BlockReduceF{block, temp.f}.Sum(sum);
|
||||
|
||||
// Mean.
|
||||
float mean = sum / axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float3 factors = {};
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T xn[N_READS];
|
||||
T wn[N_READS] = {};
|
||||
T gn[N_READS] = {};
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size, mean);
|
||||
rocprim::block_load_direct_blocked(index, g, gn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float wg = wi * gi;
|
||||
factors = plus_f3(factors, {wg, wg * t, t * t});
|
||||
}
|
||||
}
|
||||
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
|
||||
float meanwg = factors.x / axis_size;
|
||||
float meanwgxc = factors.y / axis_size;
|
||||
float normalizer2 = 1 / (factors.z / axis_size + eps);
|
||||
float normalizer = sqrt(normalizer2);
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T gn[N_READS];
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, g, gn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2;
|
||||
if constexpr (HAS_W) {
|
||||
wn[i] = gi * xi;
|
||||
}
|
||||
}
|
||||
rocprim::block_store_direct_blocked(index, gx, xn, axis_size);
|
||||
if constexpr (HAS_W) {
|
||||
rocprim::block_store_direct_blocked(index, gw, wn, axis_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Utility functions
|
||||
template <typename T>
|
||||
struct hip_plus {
|
||||
__device__ T operator()(const T& a, const T& b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
inline __device__ int hip_ceil_div(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline auto strided_iterator(const T* ptr, int64_t stride) {
|
||||
return ptr + stride; // Simplified strided iterator
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool LayerNorm::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
// TODO: There are duplicate code with backend/metal/normalization.cpp
|
||||
void LayerNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
const array x = set_output(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
|
||||
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](hipStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
|
||||
using DataType = hip_type_t<CTYPE>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = rocm::layer_norm<DataType, BLOCK_DIM, N_READS>;
|
||||
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride,
|
||||
b_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void LayerNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||
if (x.flags().row_contiguous) {
|
||||
return {x, false};
|
||||
}
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return {x_copy, true};
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[3].is_donatable();
|
||||
auto [x, copied] = check_input(inputs[0]);
|
||||
donate_x |= copied;
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
auto [g, g_copied] = check_input(inputs[3]);
|
||||
donate_g |= g_copied;
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
array& gb = outputs[2];
|
||||
|
||||
// Check whether we had a weight.
|
||||
bool has_w = w.ndim() != 0;
|
||||
|
||||
// Allocate space for the outputs.
|
||||
bool g_in_gx = false;
|
||||
if (donate_x) {
|
||||
gx.copy_shared_buffer(x);
|
||||
} else if (donate_g) {
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
encoder.add_temporary(g);
|
||||
}
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and allocate the output
|
||||
// gradient accumulators.
|
||||
array gw_temp =
|
||||
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||
if (has_w) {
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc(gb.nbytes()));
|
||||
|
||||
// Finish with the gradient for b in case we had a b.
|
||||
if (gb.ndim() == 1 && gb.size() == axis_size) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(g);
|
||||
encoder.set_output_array(gx);
|
||||
encoder.set_output_array(gw_temp);
|
||||
encoder.launch_kernel([&, x = x, g = g](hipStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
|
||||
using DataType = hip_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BOOL(has_w, HAS_W, {
|
||||
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = rocm::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
|
||||
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if (has_w) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void layer_norm_kernel(
|
||||
float* input,
|
||||
float* output,
|
||||
float* gamma,
|
||||
float* beta,
|
||||
int n,
|
||||
float eps) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (idx < n) {
|
||||
// Simplified layer norm placeholder
|
||||
// Real implementation would compute mean and variance
|
||||
output[idx] = gamma[idx] * input[idx] + beta[idx];
|
||||
}
|
||||
}
|
||||
|
||||
void launch_layer_norm(
|
||||
float* input,
|
||||
float* output,
|
||||
float* gamma,
|
||||
float* beta,
|
||||
int n,
|
||||
float eps,
|
||||
hipStream_t stream) {
|
||||
int threads = 256;
|
||||
int blocks = (n + threads - 1) / threads;
|
||||
hipLaunchKernelGGL(layer_norm_kernel, dim3(blocks), dim3(threads), 0, stream,
|
||||
input, output, gamma, beta, n, eps);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,13 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void logsumexp_kernel(float* input, float* output, int n) {
|
||||
// Placeholder implementation
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
(void)input; (void)output; (void)n; (void)idx;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,30 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
void matmul_hip(
|
||||
float* a,
|
||||
float* b,
|
||||
float* c,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
hipStream_t stream) {
|
||||
// This is a placeholder - in a real implementation, this would use rocBLAS
|
||||
// auto& device = get_current_device();
|
||||
// rocblas_sgemm(device.rocblas_handle(), ...);
|
||||
|
||||
// For now, just a placeholder
|
||||
(void)a;
|
||||
(void)b;
|
||||
(void)c;
|
||||
(void)m;
|
||||
(void)n;
|
||||
(void)k;
|
||||
(void)stream;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/rocm.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,21 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
#include "mlx/backend/common/primitives.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Basic kernel implementations will go here
|
||||
// This is a placeholder for ROCm-specific primitive operations
|
||||
|
||||
void add_hip() {
|
||||
// Placeholder for HIP add operation
|
||||
}
|
||||
|
||||
void multiply_hip() {
|
||||
// Placeholder for HIP multiply operation
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,23 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void random_uniform_kernel(float* output, int n, unsigned int seed) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) {
|
||||
// Simple LCG placeholder - real implementation would use rocRAND
|
||||
unsigned int state = seed + idx;
|
||||
state = state * 1103515245 + 12345;
|
||||
output[idx] = (float)(state & 0x7FFFFFFF) / (float)0x7FFFFFFF;
|
||||
}
|
||||
}
|
||||
|
||||
void launch_random_uniform(float* output, int n, unsigned int seed, hipStream_t stream) {
|
||||
int threads = 256;
|
||||
int blocks = (n + threads - 1) / threads;
|
||||
hipLaunchKernelGGL(random_uniform_kernel, dim3(blocks), dim3(threads), 0, stream, output, n, seed);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,24 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void sum_reduce_kernel(float* input, float* output, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
// Simple reduction placeholder
|
||||
if (idx == 0) {
|
||||
float sum = 0.0f;
|
||||
for (int i = 0; i < n; i++) {
|
||||
sum += input[i];
|
||||
}
|
||||
output[0] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
void launch_sum_reduce(float* input, float* output, int n, hipStream_t stream) {
|
||||
hipLaunchKernelGGL(sum_reduce_kernel, dim3(1), dim3(1), 0, stream, input, output, n);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,311 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/device/cast_op.hpp"
|
||||
#include "mlx/backend/rocm/reduce/reduce.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#include <rocprim/block/block_load.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
struct ColReduceArgs {
|
||||
// The size of the contiguous column reduction.
|
||||
size_t reduction_size;
|
||||
int64_t reduction_stride;
|
||||
|
||||
// Input shape and strides excluding the reduction axes.
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
int ndim;
|
||||
|
||||
// Input shape and strides of the reduction axes (including last dimension).
|
||||
Shape reduce_shape;
|
||||
Strides reduce_strides;
|
||||
int reduce_ndim;
|
||||
|
||||
// The number of column we are reducing. Namely prod(reduce_shape).
|
||||
size_t non_col_reductions;
|
||||
|
||||
ColReduceArgs(
|
||||
const array& in,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes) {
|
||||
assert(!plan.shape.empty());
|
||||
reduction_size = plan.shape.back();
|
||||
reduction_stride = plan.strides.back();
|
||||
|
||||
int64_t stride_back = 1;
|
||||
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||
while (!shape_vec.empty() && stride_back < reduction_stride) {
|
||||
stride_back *= shape_vec.back();
|
||||
shape_vec.pop_back();
|
||||
strides_vec.pop_back();
|
||||
}
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
|
||||
reduce_shape = const_param(plan.shape);
|
||||
reduce_strides = const_param(plan.strides);
|
||||
reduce_ndim = plan.shape.size();
|
||||
|
||||
non_col_reductions = 1;
|
||||
for (int i = 0; i < reduce_ndim - 1; i++) {
|
||||
non_col_reductions *= reduce_shape[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void col_reduce_small(
|
||||
const T* in,
|
||||
U* out,
|
||||
const ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
int column =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
if (column * N_READS >= args.reduction_stride) {
|
||||
return;
|
||||
}
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(
|
||||
block.thread_index().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
for (size_t r = block.thread_index().y;
|
||||
r < args.non_col_reductions * args.reduction_size;
|
||||
r += block.dim_threads().y) {
|
||||
U vals[N_READS];
|
||||
rocprim::block_load_direct_blocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.reduction_stride,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
loop.next(
|
||||
block.dim_threads().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do block reduce when each column has more than 1 element to reduce.
|
||||
if (block.dim_threads().y > 1) {
|
||||
__shared__ U shared_vals[32 * 8 * N_READS];
|
||||
size_t col =
|
||||
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col * N_READS + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
if (block.thread_index().y == 0) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
|
||||
}
|
||||
for (int j = 1; j < block.dim_threads().y; j++) {
|
||||
col = j * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (block.thread_index().y == 0) {
|
||||
rocprim::block_store_direct_blocked(
|
||||
column,
|
||||
out + out_idx * args.reduction_stride,
|
||||
totals,
|
||||
args.reduction_stride);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BM,
|
||||
int BN,
|
||||
int N_READS = 4>
|
||||
__global__ void col_reduce_looped(
|
||||
const T* in,
|
||||
U* out,
|
||||
const ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
constexpr int n_warps = BN / N_READS;
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
int r = block.thread_rank() / n_warps;
|
||||
int column = block.thread_rank() % n_warps;
|
||||
int in_offset = grid.block_index().x * BN;
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
|
||||
U vals[N_READS];
|
||||
rocprim::block_load_direct_blocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location() + in_offset),
|
||||
vals,
|
||||
args.reduction_stride - in_offset,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do warp reduce for each output.
|
||||
constexpr int n_outputs = BN / n_warps;
|
||||
static_assert(BM == 32 && n_outputs == N_READS);
|
||||
__shared__ U shared_vals[BM * BN];
|
||||
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
totals[i] = cg::reduce(warp, shared_vals[col + i], op);
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (warp.thread_rank() == 0) {
|
||||
size_t out_offset = grid.block_index().x * BN;
|
||||
rocprim::block_store_direct_blocked(
|
||||
warp.meta_group_rank(),
|
||||
out + out_idx * args.reduction_stride + out_offset,
|
||||
totals,
|
||||
args.reduction_stride - out_offset);
|
||||
}
|
||||
}
|
||||
|
||||
// Utility functions and templates
|
||||
template <int NDIM, bool USE_FAST_PATH>
|
||||
struct LoopedElemToLoc {
|
||||
size_t location;
|
||||
|
||||
__device__ LoopedElemToLoc(int reduce_ndim) : location(0) {}
|
||||
|
||||
__device__ void next(size_t step, const int* shape, const size_t* strides) {
|
||||
// Simplified implementation - actual would handle multi-dimensional indexing
|
||||
location += step;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T* make_cast_iterator(const T* ptr) {
|
||||
return const_cast<T*>(ptr);
|
||||
}
|
||||
|
||||
__device__ inline size_t elem_to_loc(
|
||||
size_t elem,
|
||||
const int* shape,
|
||||
const size_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
size_t q = elem / shape[i];
|
||||
size_t r = elem % shape[i];
|
||||
loc += r * strides[i];
|
||||
elem = q;
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
inline auto output_grid_for_col_reduce(
|
||||
const array& out,
|
||||
const rocm::ColReduceArgs& args) {
|
||||
auto out_shape = out.shape();
|
||||
auto out_strides = out.strides();
|
||||
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
|
||||
out_shape.pop_back();
|
||||
out_strides.pop_back();
|
||||
}
|
||||
return get_2d_grid_dims(out_shape, out_strides);
|
||||
}
|
||||
|
||||
void col_reduce(
|
||||
rocm::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
rocm::ColReduceArgs args(in, plan, axes);
|
||||
|
||||
encoder.launch_kernel([&](hipStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
using InType = hip_type_t<CTYPE>;
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using OutType = rocm::ReduceResult<OP, InType>::type;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
constexpr int N_READS = 4;
|
||||
dim3 block_dims;
|
||||
dim3 num_blocks = output_grid_for_col_reduce(out, args);
|
||||
num_blocks.z = num_blocks.y;
|
||||
num_blocks.y = num_blocks.x;
|
||||
auto kernel =
|
||||
rocm::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
size_t total = args.non_col_reductions * args.reduction_size;
|
||||
if (total < 32) {
|
||||
size_t stride_blocks =
|
||||
hip_ceil_div(args.reduction_stride, N_READS);
|
||||
block_dims.x = std::min(stride_blocks, 32ul);
|
||||
block_dims.y = std::min(total, 8ul);
|
||||
num_blocks.x = hip_ceil_div(stride_blocks, block_dims.x);
|
||||
} else {
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
block_dims.x = BM * BN / N_READS;
|
||||
num_blocks.x = hip_ceil_div(args.reduction_stride, BN);
|
||||
kernel = rocm::
|
||||
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
|
||||
}
|
||||
hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
|
||||
in.data<InType>(), out.data<OutType>(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,119 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstddef>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Reduction operation types
|
||||
template <typename Op, typename T>
|
||||
struct ReduceInit {
|
||||
static constexpr T value();
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<struct Sum, T> {
|
||||
static constexpr T value() {
|
||||
return T(0);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<struct Max, T> {
|
||||
static constexpr T value() {
|
||||
return -std::numeric_limits<T>::infinity();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<struct Min, T> {
|
||||
static constexpr T value() {
|
||||
return std::numeric_limits<T>::infinity();
|
||||
}
|
||||
};
|
||||
|
||||
// Reduction operations
|
||||
struct Sum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Max {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) const {
|
||||
return fmax(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
struct Min {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) const {
|
||||
return fmin(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
struct Prod {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) const {
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
// Utility functions for reductions
|
||||
template <typename T>
|
||||
__device__ T warp_reduce(T val, T (*op)(T, T)) {
|
||||
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
|
||||
val = op(val, __shfl_down(val, offset));
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ T block_reduce(T val, T (*op)(T, T)) {
|
||||
static __shared__ T shared[32];
|
||||
int lane = threadIdx.x % warpSize;
|
||||
int wid = threadIdx.x / warpSize;
|
||||
|
||||
val = warp_reduce(val, op);
|
||||
|
||||
if (lane == 0)
|
||||
shared[wid] = val;
|
||||
__syncthreads();
|
||||
|
||||
val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;
|
||||
if (wid == 0)
|
||||
val = warp_reduce(val, op);
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
// Column reduction arguments
|
||||
struct ColReduceArgs {
|
||||
size_t reduction_size;
|
||||
int64_t reduction_stride;
|
||||
int* shape;
|
||||
size_t* strides;
|
||||
int ndim;
|
||||
int* reduce_shape;
|
||||
size_t* reduce_strides;
|
||||
int reduce_ndim;
|
||||
size_t non_col_reductions;
|
||||
};
|
||||
|
||||
// Row reduction arguments
|
||||
struct RowReduceArgs {
|
||||
size_t reduction_size;
|
||||
int64_t reduction_stride;
|
||||
int* shape;
|
||||
size_t* strides;
|
||||
int ndim;
|
||||
int* reduce_shape;
|
||||
size_t* reduce_strides;
|
||||
int reduce_ndim;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,375 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/iterators/strided_iterator.hpp"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
#include "mlx/backend/rocm/reduce/reduce.hpp"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#include <rocprim/block/block_load.hpp>
|
||||
#include <rocprim/block/block_reduce.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
// Similar to rocprim::BlockReduce, but result is broadcasted to every thread.
|
||||
template <typename T, int BLOCK_DIM>
|
||||
struct BlockBroadcastReduce {
|
||||
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
|
||||
static_assert(BLOCK_DIM % WARP_SIZE == 0);
|
||||
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
|
||||
|
||||
cg::thread_block& block;
|
||||
TempStorage& temp;
|
||||
|
||||
template <typename Op>
|
||||
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
T x = cg::reduce(warp, input, op);
|
||||
if (warp.thread_rank() == 0) {
|
||||
temp[warp.meta_group_rank()] = x;
|
||||
}
|
||||
block.sync();
|
||||
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||
: init_value;
|
||||
return cg::reduce(warp, x, op);
|
||||
}
|
||||
|
||||
__device__ T Sum(const T& input) {
|
||||
return Reduce(input, hip_plus<T>{}, T{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void rms_norm(
|
||||
const T* x,
|
||||
const T* w,
|
||||
T* out,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum of squares.
|
||||
float sum_sq = 0;
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float val = static_cast<float>(xn[i]);
|
||||
sum_sq += val * val;
|
||||
}
|
||||
}
|
||||
sum_sq = BlockReduceT{block, temp}.Sum(sum_sq);
|
||||
|
||||
// RMS normalizer.
|
||||
float rms_normalizer = rsqrt(sum_sq / axis_size + eps);
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float norm = static_cast<float>(xn[i]) * rms_normalizer;
|
||||
xn[i] = wn[i] * static_cast<T>(norm);
|
||||
}
|
||||
rocprim::block_store_direct_blocked(index, out, xn, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void rms_norm_vjp(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* g,
|
||||
T* gx,
|
||||
T* gw,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
|
||||
__shared__ union {
|
||||
typename BlockReduceF::TempStorage f;
|
||||
typename BlockReduceF2::TempStorage f2;
|
||||
} temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
g += grid.block_rank() * axis_size;
|
||||
gx += grid.block_rank() * axis_size;
|
||||
gw += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum of squares.
|
||||
float sum_sq = 0;
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float val = static_cast<float>(xn[i]);
|
||||
sum_sq += val * val;
|
||||
}
|
||||
}
|
||||
sum_sq = BlockReduceF{block, temp.f}.Sum(sum_sq);
|
||||
|
||||
// RMS normalizer.
|
||||
float rms_normalizer = rsqrt(sum_sq / axis_size + eps);
|
||||
|
||||
// Compute gradient terms.
|
||||
float2 factors = {};
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T xn[N_READS];
|
||||
T wn[N_READS] = {};
|
||||
T gn[N_READS] = {};
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, g, gn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = static_cast<float>(xn[i]);
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float wg = wi * gi;
|
||||
factors.x += wg;
|
||||
factors.y += wg * xi;
|
||||
}
|
||||
}
|
||||
auto plus_f2 = [] __device__ (const float2& a, const float2& b) -> float2 {
|
||||
return {a.x + b.x, a.y + b.y};
|
||||
};
|
||||
factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
|
||||
float mean_wg = factors.x / axis_size;
|
||||
float mean_wgx = factors.y / axis_size;
|
||||
float rms3 = rms_normalizer * rms_normalizer * rms_normalizer;
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T gn[N_READS];
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, g, gn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = static_cast<float>(xn[i]);
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float norm = xi * rms_normalizer;
|
||||
xn[i] = rms_normalizer * (wi * gi - mean_wg) - norm * mean_wgx * rms3;
|
||||
if constexpr (HAS_W) {
|
||||
wn[i] = gi * norm;
|
||||
}
|
||||
}
|
||||
rocprim::block_store_direct_blocked(index, gx, xn, axis_size);
|
||||
if constexpr (HAS_W) {
|
||||
rocprim::block_store_direct_blocked(index, gw, wn, axis_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Utility functions
|
||||
template <typename T>
|
||||
struct hip_plus {
|
||||
__device__ T operator()(const T& a, const T& b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
inline __device__ int hip_ceil_div(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline auto strided_iterator(const T* ptr, int64_t stride) {
|
||||
return ptr + stride; // Simplified strided iterator
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool RMSNorm::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
void RMSNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
const array x = set_output(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](hipStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rmsnorm", CTYPE, {
|
||||
using DataType = hip_type_t<CTYPE>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = rocm::rms_norm<DataType, BLOCK_DIM, N_READS>;
|
||||
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void RMSNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||
if (x.flags().row_contiguous) {
|
||||
return {x, false};
|
||||
}
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return {x_copy, true};
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[2].is_donatable();
|
||||
auto [x, copied] = check_input(inputs[0]);
|
||||
donate_x |= copied;
|
||||
const array& w = inputs[1];
|
||||
auto [g, g_copied] = check_input(inputs[2]);
|
||||
donate_g |= g_copied;
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
|
||||
// Check whether we had a weight.
|
||||
bool has_w = w.ndim() != 0;
|
||||
|
||||
// Allocate space for the outputs.
|
||||
bool g_in_gx = false;
|
||||
if (donate_x) {
|
||||
gx.copy_shared_buffer(x);
|
||||
} else if (donate_g) {
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
encoder.add_temporary(g);
|
||||
}
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and allocate the output
|
||||
// gradient accumulators.
|
||||
array gw_temp =
|
||||
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||
if (has_w) {
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(g);
|
||||
encoder.set_output_array(gx);
|
||||
encoder.set_output_array(gw_temp);
|
||||
encoder.launch_kernel([&, x = x, g = g](hipStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rmsnorm_vjp", CTYPE, {
|
||||
using DataType = hip_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BOOL(has_w, HAS_W, {
|
||||
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = rocm::rms_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
|
||||
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if (has_w) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/rocm.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,10 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
/* Check if the ROCm backend is available. */
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,383 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__device__ void rope_single_impl(
|
||||
const T* in,
|
||||
T* out,
|
||||
int32_t offset,
|
||||
float inv_freq,
|
||||
float scale,
|
||||
int64_t stride,
|
||||
uint2 pos,
|
||||
uint2 dims) {
|
||||
float L = scale * static_cast<float>(offset);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
|
||||
float costheta = cos(theta);
|
||||
float sintheta = sin(theta);
|
||||
|
||||
// Compute the input and output indices
|
||||
uint index_1, index_2;
|
||||
if (traditional) {
|
||||
index_1 = 2 * pos.x + pos.y * stride;
|
||||
index_2 = index_1 + 1;
|
||||
} else {
|
||||
index_1 = pos.x + pos.y * stride;
|
||||
index_2 = index_1 + dims.x;
|
||||
}
|
||||
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[index_1]);
|
||||
float x2 = static_cast<float>(in[index_2]);
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
rx1 = x1 * costheta - x2 * sintheta;
|
||||
rx2 = x1 * sintheta + x2 * costheta;
|
||||
} else {
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[index_1] = static_cast<T>(rx1);
|
||||
out[index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope_single(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
float scale,
|
||||
float base,
|
||||
int64_t stride,
|
||||
uint2 dims) {
|
||||
uint2 pos = make_uint2(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
|
||||
float inv_freq = exp2(-d * base);
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, *offset, inv_freq, scale, stride, pos, dims);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope_single_freqs(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
const float* freqs,
|
||||
float scale,
|
||||
int64_t stride,
|
||||
uint2 dims,
|
||||
int64_t freq_stride) {
|
||||
uint2 pos = make_uint2(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, *offset, inv_freq, scale, stride, pos, dims);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
__device__ void rope_impl(
|
||||
const T* in,
|
||||
T* out,
|
||||
int offset,
|
||||
float inv_freq,
|
||||
float scale,
|
||||
const hip_array<int64_t, 3> strides,
|
||||
const hip_array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
uint3 pos,
|
||||
uint3 dims) {
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
|
||||
float costheta = cos(theta);
|
||||
float sintheta = sin(theta);
|
||||
|
||||
// Compute the input and output indices
|
||||
size_t in_index_1, in_index_2;
|
||||
size_t out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + dims.x * out_strides[2];
|
||||
in_index_1 =
|
||||
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + dims.x * strides[2];
|
||||
}
|
||||
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
rx1 = x1 * costheta - x2 * sintheta;
|
||||
rx2 = x1 * sintheta + x2 * costheta;
|
||||
} else {
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
in_index_1 += strides[0];
|
||||
in_index_2 += strides[0];
|
||||
out_index_1 += out_strides[0];
|
||||
out_index_2 += out_strides[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
float scale,
|
||||
float base,
|
||||
const hip_array<int64_t, 3> strides,
|
||||
const hip_array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
uint3 dims) {
|
||||
uint3 pos = make_uint3(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y,
|
||||
blockIdx.z * blockDim.z + threadIdx.z);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
|
||||
return;
|
||||
}
|
||||
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
|
||||
float inv_freq = exp2(-d * base);
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope_freqs(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
const float* freqs,
|
||||
float scale,
|
||||
float base,
|
||||
const hip_array<int64_t, 3> strides,
|
||||
const hip_array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
uint3 dims,
|
||||
int64_t freq_stride) {
|
||||
uint3 pos = make_uint3(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y,
|
||||
blockIdx.z * blockDim.z + threadIdx.z);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
|
||||
return;
|
||||
}
|
||||
|
||||
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool RoPE::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
void RoPE::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& in = inputs[0];
|
||||
auto& offset = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
if (in.ndim() < 3) {
|
||||
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
|
||||
}
|
||||
|
||||
hip_array<int64_t, 3> strides;
|
||||
hip_array<int64_t, 3> out_strides;
|
||||
bool donated = false;
|
||||
int ndim = in.ndim();
|
||||
int dispatch_ndim = in.ndim();
|
||||
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
|
||||
dispatch_ndim--;
|
||||
}
|
||||
size_t mat_size = in.shape(-2) * in.shape(-1);
|
||||
|
||||
// We apply rope to less that the whole vector so copy to output and then
|
||||
// apply in-place.
|
||||
if (dims_ < in.shape(-1)) {
|
||||
donated = true;
|
||||
auto ctype =
|
||||
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
|
||||
copy_gpu(in, out, ctype, s);
|
||||
strides[0] = mat_size;
|
||||
strides[1] = out.strides()[ndim - 2];
|
||||
strides[2] = out.strides()[ndim - 1];
|
||||
}
|
||||
|
||||
// Either copy or apply in-place
|
||||
else if (in.flags().row_contiguous) {
|
||||
if (in.is_donatable()) {
|
||||
donated = true;
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
strides[0] = mat_size;
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
} else if (dispatch_ndim == 3) {
|
||||
// Handle non-contiguous 3D inputs
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
strides[0] = in.strides()[ndim - 3];
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
} else {
|
||||
// Copy non-contiguous > 3D inputs into the output and treat
|
||||
// input as donated
|
||||
donated = true;
|
||||
copy_gpu(in, out, CopyType::General, s);
|
||||
strides[0] = mat_size;
|
||||
strides[1] = out.strides()[ndim - 2];
|
||||
strides[2] = out.strides()[ndim - 1];
|
||||
}
|
||||
out_strides[0] = mat_size;
|
||||
out_strides[1] = out.strides()[ndim - 2];
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
// Some flags to help us dispatch below
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
encoder.set_input_array(donated ? out : in);
|
||||
encoder.set_input_array(offset);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](hipStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
|
||||
using DataType = hip_type_t<CTYPE>;
|
||||
MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
|
||||
MLX_SWITCH_BOOL(forward_, FORWARD, {
|
||||
if (single && !with_freqs) {
|
||||
auto kernel = rocm::rope_single<DataType, TRADITIONAL, FORWARD>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
hipLaunchKernelGGL(kernel, grid, block, 0, stream,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
mat_size,
|
||||
dims);
|
||||
} else if (single) {
|
||||
auto kernel = rocm::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
hipLaunchKernelGGL(kernel, grid, block, 0, stream,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
scale_,
|
||||
mat_size,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else if (with_freqs) {
|
||||
auto kernel = rocm::rope_freqs<DataType, TRADITIONAL, FORWARD>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
hipLaunchKernelGGL(kernel, grid, block, 0, stream,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else {
|
||||
auto kernel = rocm::rope<DataType, TRADITIONAL, FORWARD>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
hipLaunchKernelGGL(kernel, grid, block, 0, stream,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
dims);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,9 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
void slice() {
|
||||
// Placeholder for ROCm slicing operation
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
@@ -1,179 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/device/cast_op.hpp"
|
||||
#include "mlx/backend/rocm/device/fp16_math.hpp"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#include <rocprim/block/block_load.hpp>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T softmax_exp(T x) {
|
||||
// Softmax doesn't need high precision exponential cause x is gonna be in
|
||||
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
|
||||
return __expf(x);
|
||||
}
|
||||
|
||||
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void softmax(const T* in, T* out, int axis_size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
in += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
// Thread reduce.
|
||||
AccT prevmax;
|
||||
AccT maxval = -INFINITY;
|
||||
AccT normalizer = 0;
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
AccT vals[N_READS];
|
||||
rocprim::block_load_direct_blocked(
|
||||
r * BLOCK_DIM + block.thread_rank(),
|
||||
make_cast_iterator<AccT>(in),
|
||||
vals,
|
||||
axis_size,
|
||||
-INFINITY);
|
||||
prevmax = maxval;
|
||||
maxval = fmax(maxval, rocprim::thread_reduce(vals, hip_max<AccT>()));
|
||||
// Online normalizer calculation for softmax:
|
||||
// https://github.com/NVIDIA/online-softmax
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer = normalizer + softmax_exp(vals[i] - maxval);
|
||||
}
|
||||
}
|
||||
|
||||
// First warp reduce.
|
||||
prevmax = maxval;
|
||||
maxval = cg::reduce(warp, maxval, hip_max<AccT>());
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
normalizer = cg::reduce(warp, normalizer, hip_plus<AccT>());
|
||||
|
||||
__shared__ AccT local_max[WARP_SIZE];
|
||||
__shared__ AccT local_normalizer[WARP_SIZE];
|
||||
|
||||
// Write to shared memory and do second warp reduce.
|
||||
prevmax = maxval;
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_max[warp.meta_group_rank()] = maxval;
|
||||
}
|
||||
block.sync();
|
||||
maxval = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_max[warp.thread_rank()]
|
||||
: -INFINITY;
|
||||
maxval = cg::reduce(warp, maxval, hip_max<AccT>());
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_normalizer[warp.meta_group_rank()] = normalizer;
|
||||
}
|
||||
block.sync();
|
||||
normalizer = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_normalizer[warp.thread_rank()]
|
||||
: AccT{};
|
||||
normalizer = cg::reduce(warp, normalizer, hip_plus<AccT>());
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Write output.
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T vals[N_READS];
|
||||
rocprim::block_load_direct_blocked(index, in, vals, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
|
||||
}
|
||||
rocprim::block_store_direct_blocked(index, out, vals, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
// Utility functions for ROCm
|
||||
template <typename T>
|
||||
struct hip_max {
|
||||
__device__ T operator()(const T& a, const T& b) const {
|
||||
return fmax(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct hip_plus {
|
||||
__device__ T operator()(const T& a, const T& b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
inline __device__ int hip_ceil_div(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T* make_cast_iterator(const T* ptr) {
|
||||
return const_cast<T*>(ptr);
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& s = stream();
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
array in = set_output(inputs[0]);
|
||||
bool precise = in.dtype() != float32 && precise_;
|
||||
|
||||
int axis_size = in.shape().back();
|
||||
int n_rows = in.data_size() / axis_size;
|
||||
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](hipStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
|
||||
using DataType = hip_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = rocm::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
|
||||
if (precise) {
|
||||
kernel = rocm::softmax<DataType, float, BLOCK_DIM, N_READS>;
|
||||
}
|
||||
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user