mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Compare commits
8 Commits
23eb58b37c
...
4ce48a3996
Author | SHA1 | Date | |
---|---|---|---|
![]() |
4ce48a3996 | ||
![]() |
c9a9180584 | ||
![]() |
76831ed83d | ||
![]() |
b3d7b85376 | ||
![]() |
cad5c0241c | ||
![]() |
b8022c578a | ||
![]() |
bc53f8293f | ||
![]() |
c552ff2451 |
@ -16,6 +16,9 @@ parameters:
|
|||||||
linux_release:
|
linux_release:
|
||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
|
cuda_release:
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build_documentation:
|
build_documentation:
|
||||||
@ -104,7 +107,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
echo "stubs"
|
echo "stubs"
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
@ -162,7 +165,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
@ -223,7 +226,6 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
|
||||||
python -m venv env
|
python -m venv env
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
@ -283,7 +285,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
- run:
|
- run:
|
||||||
name: Build Python package
|
name: Build Python package
|
||||||
command: |
|
command: |
|
||||||
@ -342,7 +344,7 @@ jobs:
|
|||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
pip install . -v
|
pip install . -v
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
<< parameters.extra_env >> \
|
<< parameters.extra_env >> \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
python -m build --wheel
|
python -m build --wheel
|
||||||
@ -356,6 +358,48 @@ jobs:
|
|||||||
- store_artifacts:
|
- store_artifacts:
|
||||||
path: wheelhouse/
|
path: wheelhouse/
|
||||||
|
|
||||||
|
build_cuda_release:
|
||||||
|
parameters:
|
||||||
|
python_version:
|
||||||
|
type: string
|
||||||
|
default: "3.9"
|
||||||
|
extra_env:
|
||||||
|
type: string
|
||||||
|
default: "DEV_RELEASE=1"
|
||||||
|
machine:
|
||||||
|
image: linux-cuda-12:default
|
||||||
|
resource_class: gpu.nvidia.small.gen2
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run:
|
||||||
|
name: Build wheel
|
||||||
|
command: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
|
python -m venv env
|
||||||
|
source env/bin/activate
|
||||||
|
pip install auditwheel
|
||||||
|
pip install patchelf
|
||||||
|
pip install build
|
||||||
|
pip install twine
|
||||||
|
<< parameters.extra_env >> \
|
||||||
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
|
pip install ".[dev]" -v
|
||||||
|
python setup.py generate_stubs
|
||||||
|
<< parameters.extra_env >> \
|
||||||
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
|
python -m build --wheel
|
||||||
|
bash python/scripts/repair_cuda.sh
|
||||||
|
- run:
|
||||||
|
name: Upload package
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
twine upload wheelhouse/*.whl
|
||||||
|
- store_artifacts:
|
||||||
|
path: wheelhouse/
|
||||||
|
|
||||||
workflows:
|
workflows:
|
||||||
build_and_test:
|
build_and_test:
|
||||||
when:
|
when:
|
||||||
@ -625,3 +669,14 @@ workflows:
|
|||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
extra_env: ["PYPI_RELEASE=1"]
|
extra_env: ["PYPI_RELEASE=1"]
|
||||||
|
cuda_test_release:
|
||||||
|
when:
|
||||||
|
and:
|
||||||
|
- equal: [ main, << pipeline.git.branch >> ]
|
||||||
|
- << pipeline.parameters.cuda_release >>
|
||||||
|
jobs:
|
||||||
|
- build_cuda_release:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
|
extra_env: ["PYPI_RELEASE=1"]
|
||||||
|
@ -30,6 +30,16 @@ MLX is also available on conda-forge. To install MLX with conda do:
|
|||||||
|
|
||||||
conda install conda-forge::mlx
|
conda install conda-forge::mlx
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
|
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
|
||||||
|
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
pip install mlx-cuda
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
@ -65,6 +75,8 @@ Build Requirements
|
|||||||
Python API
|
Python API
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
.. _python install:
|
||||||
|
|
||||||
To build and install the MLX python library from source, first, clone MLX from
|
To build and install the MLX python library from source, first, clone MLX from
|
||||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||||
|
|
||||||
@ -107,6 +119,8 @@ IDE:
|
|||||||
C++ API
|
C++ API
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
|
.. _cpp install:
|
||||||
|
|
||||||
Currently, MLX must be built and installed from source.
|
Currently, MLX must be built and installed from source.
|
||||||
|
|
||||||
Similarly to the python library, to build and install the MLX C++ library start
|
Similarly to the python library, to build and install the MLX C++ library start
|
||||||
@ -185,6 +199,7 @@ should point to the path to the built metal library.
|
|||||||
|
|
||||||
xcrun -sdk macosx --show-sdk-version
|
xcrun -sdk macosx --show-sdk-version
|
||||||
|
|
||||||
|
|
||||||
Binary Size Minimization
|
Binary Size Minimization
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@ -213,6 +228,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
|
|||||||
application. Once a kernel is compiled, it will be cached by the system. The
|
application. Once a kernel is compiled, it will be cached by the system. The
|
||||||
Metal kernel cache persists across reboots.
|
Metal kernel cache persists across reboots.
|
||||||
|
|
||||||
|
Linux
|
||||||
|
^^^^^
|
||||||
|
|
||||||
|
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
||||||
|
For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
apt-get update -y
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||||
|
|
||||||
|
From here follow the instructions to install either the :ref:`Python <python
|
||||||
|
install>` or :ref:`C++ <cpp install>` APIs.
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
|
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
||||||
|
and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
||||||
|
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
|
apt-get update -y
|
||||||
|
apt-get -y install cuda-toolkit-12-9
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||||
|
|
||||||
|
|
||||||
|
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||||
|
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||||
|
|
||||||
|
To build the C++ package run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
mkdir -p build && cd build
|
||||||
|
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
@ -107,6 +107,16 @@ same array:
|
|||||||
>>> a
|
>>> a
|
||||||
array([1, 2, 0], dtype=int32)
|
array([1, 2, 0], dtype=int32)
|
||||||
|
|
||||||
|
|
||||||
|
Note, unlike NumPy, updates to the same location are nondeterministic:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> a = mx.array([1, 2, 3])
|
||||||
|
>>> a[[0, 0]] = mx.array([4, 5])
|
||||||
|
|
||||||
|
The first element of ``a`` could be ``4`` or ``5``.
|
||||||
|
|
||||||
Transformations of functions which use in-place updates are allowed and work as
|
Transformations of functions which use in-place updates are allowed and work as
|
||||||
expected. For example:
|
expected. For example:
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ target_sources(
|
|||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include "mlx/backend/cuda/allocator.h"
|
#include "mlx/backend/cuda/allocator.h"
|
||||||
#include "mlx/backend/cuda/utils.h"
|
#include "mlx/backend/cuda/utils.h"
|
||||||
#include "mlx/backend/cuda/worker.h"
|
#include "mlx/backend/cuda/worker.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
@ -14,9 +15,11 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace cu {
|
namespace cu {
|
||||||
|
|
||||||
|
constexpr int page_size = 16384;
|
||||||
|
|
||||||
CudaAllocator::CudaAllocator()
|
CudaAllocator::CudaAllocator()
|
||||||
: buffer_cache_(
|
: buffer_cache_(
|
||||||
getpagesize(),
|
page_size,
|
||||||
[](CudaBuffer* buf) { return buf->size; },
|
[](CudaBuffer* buf) { return buf->size; },
|
||||||
[this](CudaBuffer* buf) {
|
[this](CudaBuffer* buf) {
|
||||||
cuda_free(buf->data);
|
cuda_free(buf->data);
|
||||||
@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator()
|
|||||||
|
|
||||||
Buffer CudaAllocator::malloc(size_t size) {
|
Buffer CudaAllocator::malloc(size_t size) {
|
||||||
// Find available buffer from cache.
|
// Find available buffer from cache.
|
||||||
|
auto orig_size = size;
|
||||||
std::unique_lock lock(mutex_);
|
std::unique_lock lock(mutex_);
|
||||||
|
if (size < page_size) {
|
||||||
|
size = next_power_of_2(size);
|
||||||
|
} else {
|
||||||
|
size = page_size * ((size + page_size - 1) / page_size);
|
||||||
|
}
|
||||||
|
|
||||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||||
@ -106,7 +116,6 @@ void CudaAllocator::cuda_free(void* buf) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cudaFree(buf);
|
cudaFree(buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,10 +101,12 @@ constexpr bool supports_binary_op() {
|
|||||||
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
|
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
|
||||||
}
|
}
|
||||||
if (std::is_same_v<Op, NaNEqual>) {
|
if (std::is_same_v<Op, NaNEqual>) {
|
||||||
return std::is_same_v<Out, bool> &&
|
return std::is_same_v<Out, bool> && is_inexact_v<In>;
|
||||||
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
|
|
||||||
}
|
}
|
||||||
if (std::is_same_v<Op, LogAddExp> || std::is_same_v<Op, ArcTan2>) {
|
if (std::is_same_v<Op, LogAddExp>) {
|
||||||
|
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, ArcTan2>) {
|
||||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||||
}
|
}
|
||||||
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
|
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
|
||||||
@ -123,13 +125,12 @@ constexpr bool supports_binary_op() {
|
|||||||
template <typename Op>
|
template <typename Op>
|
||||||
void binary_op_gpu_inplace(
|
void binary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
array& out,
|
||||||
std::string_view op,
|
std::string_view op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
assert(inputs.size() > 1);
|
assert(inputs.size() > 1);
|
||||||
const auto& a = inputs[0];
|
const auto& a = inputs[0];
|
||||||
const auto& b = inputs[1];
|
const auto& b = inputs[1];
|
||||||
auto& out = outputs[0];
|
|
||||||
if (out.size() == 0) {
|
if (out.size() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -144,16 +145,15 @@ void binary_op_gpu_inplace(
|
|||||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||||
using InType = cuda_type_t<CTYPE_IN>;
|
using InType = cuda_type_t<CTYPE_IN>;
|
||||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||||
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
auto bopt = get_binary_op_type(a, b);
|
||||||
if (bopt == BinaryOpType::General) {
|
if (bopt == BinaryOpType::General) {
|
||||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||||
auto& a_strides = strides[0];
|
auto& a_strides = strides[0];
|
||||||
auto& b_strides = strides[1];
|
auto& b_strides = strides[1];
|
||||||
bool large = a.data_size() > UINT32_MAX ||
|
bool large = a.data_size() > INT32_MAX ||
|
||||||
b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
|
|||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param<NDIM>(shape),
|
const_param<NDIM>(shape),
|
||||||
const_param<NDIM>(a_strides),
|
const_param<NDIM>(a_strides),
|
||||||
const_param<NDIM>(b_strides));
|
const_param<NDIM>(b_strides));
|
||||||
@ -178,7 +178,7 @@ void binary_op_gpu_inplace(
|
|||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(a_strides),
|
const_param(a_strides),
|
||||||
const_param(b_strides),
|
const_param(b_strides),
|
||||||
@ -196,8 +196,8 @@ void binary_op_gpu_inplace(
|
|||||||
} else if (bopt == BinaryOpType::VectorVector) {
|
} else if (bopt == BinaryOpType::VectorVector) {
|
||||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||||
}
|
}
|
||||||
auto [num_blocks, block_dims] =
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
get_launch_args(kernel, out, LARGE);
|
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
@ -217,20 +217,6 @@ void binary_op_gpu_inplace(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void binary_op_gpu(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
std::vector<array>& outputs,
|
|
||||||
std::string_view op,
|
|
||||||
const Stream& s) {
|
|
||||||
auto& a = inputs[0];
|
|
||||||
auto& b = inputs[1];
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
|
||||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
|
||||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
@ -241,8 +227,7 @@ void binary_op_gpu(
|
|||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
auto bopt = get_binary_op_type(a, b);
|
auto bopt = get_binary_op_type(a, b);
|
||||||
set_binary_op_output_data(a, b, out, bopt);
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
std::vector<array> outputs{out};
|
binary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define BINARY_GPU(func) \
|
#define BINARY_GPU(func) \
|
||||||
@ -252,19 +237,10 @@ void binary_op_gpu(
|
|||||||
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define BINARY_GPU_MULTI(func) \
|
|
||||||
void func::eval_gpu( \
|
|
||||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
|
||||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
|
||||||
auto& s = outputs[0].primitive().stream(); \
|
|
||||||
binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
|
|
||||||
}
|
|
||||||
|
|
||||||
BINARY_GPU(Add)
|
BINARY_GPU(Add)
|
||||||
BINARY_GPU(ArcTan2)
|
BINARY_GPU(ArcTan2)
|
||||||
BINARY_GPU(Divide)
|
BINARY_GPU(Divide)
|
||||||
BINARY_GPU(Remainder)
|
BINARY_GPU(Remainder)
|
||||||
BINARY_GPU(Equal)
|
|
||||||
BINARY_GPU(Greater)
|
BINARY_GPU(Greater)
|
||||||
BINARY_GPU(GreaterEqual)
|
BINARY_GPU(GreaterEqual)
|
||||||
BINARY_GPU(Less)
|
BINARY_GPU(Less)
|
||||||
@ -279,6 +255,17 @@ BINARY_GPU(NotEqual)
|
|||||||
BINARY_GPU(Power)
|
BINARY_GPU(Power)
|
||||||
BINARY_GPU(Subtract)
|
BINARY_GPU(Subtract)
|
||||||
|
|
||||||
|
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||||
|
auto& s = out.primitive().stream();
|
||||||
|
auto op = get_primitive_string(this);
|
||||||
|
if (equal_nan_) {
|
||||||
|
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
|
||||||
|
} else {
|
||||||
|
binary_op_gpu<cu::Equal>(inputs, out, op, s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||||
auto& s = out.primitive().stream();
|
auto& s = out.primitive().stream();
|
||||||
|
248
mlx/backend/cuda/binary_two.cu
Normal file
248
mlx/backend/cuda/binary_two.cu
Normal file
@ -0,0 +1,248 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/binary.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT>
|
||||||
|
__global__ void
|
||||||
|
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto out = Op{}(a[0], b[0]);
|
||||||
|
out_a[0] = out[0];
|
||||||
|
out_b[0] = out[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT>
|
||||||
|
__global__ void
|
||||||
|
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto out = Op{}(a[0], b[index]);
|
||||||
|
out_a[index] = out[0];
|
||||||
|
out_b[index] = out[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT>
|
||||||
|
__global__ void
|
||||||
|
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto out = Op{}(a[index], b[0]);
|
||||||
|
out_a[index] = out[0];
|
||||||
|
out_b[index] = out[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT>
|
||||||
|
__global__ void
|
||||||
|
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto out = Op{}(a[index], b[index]);
|
||||||
|
out_a[index] = out[0];
|
||||||
|
out_b[index] = out[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||||
|
__global__ void binary_g_nd(
|
||||||
|
const In* a,
|
||||||
|
const In* b,
|
||||||
|
Out* out_a,
|
||||||
|
Out* out_b,
|
||||||
|
IdxT size,
|
||||||
|
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||||
|
index, shape.data(), a_strides.data(), b_strides.data());
|
||||||
|
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||||
|
out_a[index] = out[0];
|
||||||
|
out_b[index] = out[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT>
|
||||||
|
__global__ void binary_g(
|
||||||
|
const In* a,
|
||||||
|
const In* b,
|
||||||
|
Out* out_a,
|
||||||
|
Out* out_b,
|
||||||
|
IdxT size,
|
||||||
|
const __grid_constant__ Shape shape,
|
||||||
|
const __grid_constant__ Strides a_strides,
|
||||||
|
const __grid_constant__ Strides b_strides,
|
||||||
|
int ndim) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto [a_idx, b_idx] = elem_to_loc_4d(
|
||||||
|
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||||
|
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||||
|
out_a[index] = out[0];
|
||||||
|
out_b[index] = out[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out>
|
||||||
|
constexpr bool supports_binary_op() {
|
||||||
|
if (std::is_same_v<Op, DivMod>) {
|
||||||
|
return std::is_same_v<In, Out> &&
|
||||||
|
(std::is_integral_v<Out> || is_floating_v<Out>);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void binary_op_gpu_inplace(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs,
|
||||||
|
std::string_view op,
|
||||||
|
const Stream& s) {
|
||||||
|
assert(inputs.size() > 1);
|
||||||
|
const auto& a = inputs[0];
|
||||||
|
const auto& b = inputs[1];
|
||||||
|
auto& out_a = outputs[0];
|
||||||
|
auto& out_b = outputs[1];
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out_a, bopt);
|
||||||
|
set_binary_op_output_data(a, b, out_b, bopt);
|
||||||
|
|
||||||
|
if (out_a.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out_a);
|
||||||
|
encoder.set_output_array(out_b);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
|
||||||
|
MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, {
|
||||||
|
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||||
|
using InType = cuda_type_t<CTYPE_IN>;
|
||||||
|
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||||
|
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
if (bopt == BinaryOpType::General) {
|
||||||
|
auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
|
||||||
|
auto& a_strides = strides[0];
|
||||||
|
auto& b_strides = strides[1];
|
||||||
|
bool large = a.data_size() > INT32_MAX ||
|
||||||
|
b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX;
|
||||||
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
|
int ndim = shape.size();
|
||||||
|
if (ndim <= 3) {
|
||||||
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
|
auto kernel =
|
||||||
|
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
|
||||||
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(kernel, out_a, large);
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
a.data<InType>(),
|
||||||
|
b.data<InType>(),
|
||||||
|
out_a.data<OutType>(),
|
||||||
|
out_b.data<OutType>(),
|
||||||
|
out_a.size(),
|
||||||
|
const_param<NDIM>(shape),
|
||||||
|
const_param<NDIM>(a_strides),
|
||||||
|
const_param<NDIM>(b_strides));
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||||
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(kernel, out_a, large);
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
a.data<InType>(),
|
||||||
|
b.data<InType>(),
|
||||||
|
out_a.data<OutType>(),
|
||||||
|
out_b.data<OutType>(),
|
||||||
|
out_a.size(),
|
||||||
|
const_param(shape),
|
||||||
|
const_param(a_strides),
|
||||||
|
const_param(b_strides),
|
||||||
|
ndim);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, {
|
||||||
|
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||||
|
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
||||||
|
if (bopt == BinaryOpType::ScalarVector) {
|
||||||
|
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
||||||
|
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||||
|
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
||||||
|
} else if (bopt == BinaryOpType::VectorVector) {
|
||||||
|
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||||
|
}
|
||||||
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
|
kernel,
|
||||||
|
out_a.data_size(),
|
||||||
|
out_a.shape(),
|
||||||
|
out_a.strides(),
|
||||||
|
LARGE);
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
a.data<InType>(),
|
||||||
|
b.data<InType>(),
|
||||||
|
out_a.data<OutType>(),
|
||||||
|
out_b.data<OutType>(),
|
||||||
|
out_a.data_size());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(fmt::format(
|
||||||
|
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||||
|
op,
|
||||||
|
dtype_to_string(a.dtype()),
|
||||||
|
dtype_to_string(out_a.dtype())));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void binary_op_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs,
|
||||||
|
std::string_view op,
|
||||||
|
const Stream& s) {
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||||
|
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||||
|
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DivMod::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
nvtx3::scoped_range r("DivMod::eval_gpu");
|
||||||
|
auto& s = outputs[0].primitive().stream();
|
||||||
|
binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -130,11 +130,13 @@ struct FusedKernelBuilder {
|
|||||||
|
|
||||||
constexpr const char* g_jit_includes = R"(
|
constexpr const char* g_jit_includes = R"(
|
||||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
#define inf cuda::std::numeric_limits<float>::infinity()
|
||||||
)";
|
)";
|
||||||
|
|
||||||
void Compiled::eval_gpu(
|
void Compiled::eval_gpu(
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
void copy_gpu_inplace(
|
void copy_gpu_inplace(
|
||||||
const array& in_,
|
const array& in,
|
||||||
array& out,
|
array& out,
|
||||||
const Shape& shape,
|
const Shape& shape,
|
||||||
const Strides& strides_in,
|
const Strides& strides_in,
|
||||||
@ -20,12 +20,10 @@ void copy_gpu_inplace(
|
|||||||
if (out.size() == 0) {
|
if (out.size() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const array& in = in_.data_shared_ptr() ? in_ : out;
|
|
||||||
|
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
encoder.set_input_array(in);
|
encoder.set_input_array(in);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
|
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
|
||||||
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
|
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
|
||||||
return;
|
return;
|
||||||
|
@ -10,20 +10,13 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
||||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
||||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||||
using InType = cuda_type_t<CTYPE_IN>; \
|
using InType = cuda_type_t<CTYPE_IN>; \
|
||||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||||
if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
|
__VA_ARGS__; \
|
||||||
__VA_ARGS__; \
|
}); \
|
||||||
} else { \
|
|
||||||
throw std::runtime_error(fmt::format( \
|
|
||||||
"Can not copy data from dtype {} to {}.", \
|
|
||||||
dtype_to_string(out.dtype()), \
|
|
||||||
dtype_to_string(in.dtype()))); \
|
|
||||||
} \
|
|
||||||
}); \
|
|
||||||
})
|
})
|
||||||
|
|
||||||
void copy_contiguous(
|
void copy_contiguous(
|
||||||
|
@ -43,7 +43,8 @@ void copy_contiguous(
|
|||||||
if (ctype == CopyType::Vector) {
|
if (ctype == CopyType::Vector) {
|
||||||
kernel = cu::copy_v<InType, OutType, IdxT>;
|
kernel = cu::copy_v<InType, OutType, IdxT>;
|
||||||
}
|
}
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
|
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in.data<InType>() + in_offset,
|
in.data<InType>() + in_offset,
|
||||||
out.data<OutType>() + out_offset,
|
out.data<OutType>() + out_offset,
|
||||||
|
@ -59,29 +59,34 @@ void copy_general(
|
|||||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
|
size_t data_size = 1;
|
||||||
|
for (auto& s : shape)
|
||||||
|
data_size *= s;
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
|
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(kernel, data_size, shape, out.strides(), large);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.data_size(),
|
data_size,
|
||||||
const_param<NDIM>(shape),
|
const_param<NDIM>(shape),
|
||||||
const_param<NDIM>(strides_in),
|
const_param<NDIM>(strides_in),
|
||||||
const_param<NDIM>(strides_out));
|
const_param<NDIM>(strides_out));
|
||||||
});
|
});
|
||||||
} else { // ndim >= 4
|
} else { // ndim >= 4
|
||||||
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(kernel, data_size, shape, out.strides(), large);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.data_size(),
|
data_size,
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(strides_in),
|
const_param(strides_in),
|
||||||
const_param(strides_out),
|
const_param(strides_out),
|
||||||
|
@ -65,9 +65,9 @@ void copy_general_dynamic(
|
|||||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
@ -76,7 +76,7 @@ void copy_general_dynamic(
|
|||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param<NDIM>(shape),
|
const_param<NDIM>(shape),
|
||||||
const_param<NDIM>(strides_in),
|
const_param<NDIM>(strides_in),
|
||||||
const_param<NDIM>(strides_out),
|
const_param<NDIM>(strides_out),
|
||||||
@ -89,7 +89,7 @@ void copy_general_dynamic(
|
|||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(strides_in),
|
const_param(strides_in),
|
||||||
const_param(strides_out),
|
const_param(strides_out),
|
||||||
|
@ -54,9 +54,9 @@ void copy_general_input(
|
|||||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
@ -65,7 +65,7 @@ void copy_general_input(
|
|||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param<NDIM>(shape),
|
const_param<NDIM>(shape),
|
||||||
const_param<NDIM>(strides_in));
|
const_param<NDIM>(strides_in));
|
||||||
});
|
});
|
||||||
@ -75,7 +75,7 @@ void copy_general_input(
|
|||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(strides_in),
|
const_param(strides_in),
|
||||||
ndim);
|
ndim);
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <nvtx3/nvtx3.hpp>
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
#include <future>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@ -107,6 +108,16 @@ void CommandEncoder::commit() {
|
|||||||
worker_.commit(stream_.last_cuda_stream());
|
worker_.commit(stream_.last_cuda_stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::synchronize() {
|
||||||
|
stream().synchronize();
|
||||||
|
auto p = std::make_shared<std::promise<void>>();
|
||||||
|
std::future<void> f = p->get_future();
|
||||||
|
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||||
|
worker_.end_batch();
|
||||||
|
commit();
|
||||||
|
f.wait();
|
||||||
|
}
|
||||||
|
|
||||||
Device& device(mlx::core::Device device) {
|
Device& device(mlx::core::Device device) {
|
||||||
static std::unordered_map<int, Device> devices;
|
static std::unordered_map<int, Device> devices;
|
||||||
auto it = devices.find(device.index);
|
auto it = devices.find(device.index);
|
||||||
|
@ -123,6 +123,9 @@ class CommandEncoder {
|
|||||||
return has_gpu_work_;
|
return has_gpu_work_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wait until kernels and completion handlers are finished
|
||||||
|
void synchronize();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Device& device_;
|
Device& device_;
|
||||||
DeviceStream& stream_;
|
DeviceStream& stream_;
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
#include <cuComplex.h>
|
#include <cuComplex.h>
|
||||||
#include <cuda/std/array>
|
#include <cuda/std/array>
|
||||||
@ -20,7 +22,7 @@ struct FloorDivide {
|
|||||||
if constexpr (cuda::std::is_integral_v<T>) {
|
if constexpr (cuda::std::is_integral_v<T>) {
|
||||||
return x / y;
|
return x / y;
|
||||||
} else {
|
} else {
|
||||||
return trunc(x / y);
|
return truncf(x / y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -122,6 +124,26 @@ struct LogAddExp {
|
|||||||
? maxval
|
? maxval
|
||||||
: T(float(maxval) + log1p(expf(minval - maxval)));
|
: T(float(maxval) + log1p(expf(minval - maxval)));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
|
||||||
|
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
||||||
|
isnan(cuCimagf(y))) {
|
||||||
|
return {
|
||||||
|
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||||
|
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||||
|
}
|
||||||
|
float inf = cuda::std::numeric_limits<float>::infinity();
|
||||||
|
auto maxval = x > y ? x : y;
|
||||||
|
auto minval = x < y ? x : y;
|
||||||
|
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
|
||||||
|
return maxval;
|
||||||
|
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
|
||||||
|
cuComplex dexp{
|
||||||
|
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
|
||||||
|
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
|
||||||
|
};
|
||||||
|
return maxval + log1p(dexp);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Maximum {
|
struct Maximum {
|
||||||
|
@ -45,6 +45,18 @@ struct CastOp<
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename SrcT, typename DstT>
|
||||||
|
struct CastOp<
|
||||||
|
SrcT,
|
||||||
|
DstT,
|
||||||
|
cuda::std::enable_if_t<cuda::std::is_same_v<SrcT, DstT>>> {
|
||||||
|
static constexpr bool is_castable = true;
|
||||||
|
|
||||||
|
__device__ SrcT operator()(SrcT x) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Return an iterator that cast the value to DstT using CastOp.
|
// Return an iterator that cast the value to DstT using CastOp.
|
||||||
template <typename DstT, typename Iterator>
|
template <typename DstT, typename Iterator>
|
||||||
__host__ __device__ auto make_cast_iterator(Iterator it) {
|
__host__ __device__ auto make_cast_iterator(Iterator it) {
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
// The maximum dimensions of shape/strides passed as kernel parameters.
|
// The maximum dimensions of shape/strides passed as kernel parameters.
|
||||||
#define MAX_NDIM 8
|
#define MAX_NDIM 10
|
||||||
|
|
||||||
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
||||||
// warpSize variable exists, using it would prevent compile-time optimizations.
|
// warpSize variable exists, using it would prevent compile-time optimizations.
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
#pragma once
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
|
#include <math_constants.h>
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
struct Abs {
|
struct Abs {
|
||||||
@ -183,21 +185,38 @@ struct Imag {
|
|||||||
struct Log {
|
struct Log {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return log(x);
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
|
auto r = log(cuCrealf(Abs{}(x)));
|
||||||
|
auto i = atan2f(cuCimagf(x), cuCrealf(x));
|
||||||
|
return {r, i};
|
||||||
|
} else {
|
||||||
|
return log(x);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log2 {
|
struct Log2 {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return log2(x);
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
|
auto y = Log{}(x);
|
||||||
|
return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
|
||||||
|
} else {
|
||||||
|
return log2(x);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log10 {
|
struct Log10 {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return log10(x);
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
|
auto y = Log{}(x);
|
||||||
|
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
|
||||||
|
return y;
|
||||||
|
} else {
|
||||||
|
return log10(x);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = NDIM - 1; i >= 0; --i) {
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * a_strides[i];
|
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||||
b_loc += dim_idx * b_strides[i];
|
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc);
|
return cuda::std::make_tuple(a_loc, b_loc);
|
||||||
@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = NDIM - 1; i >= 0; --i) {
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * a_strides[i];
|
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||||
b_loc += dim_idx * b_strides[i];
|
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||||
c_loc += dim_idx * c_strides[i];
|
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
@ -187,8 +187,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
|||||||
template <typename IdxT = int64_t>
|
template <typename IdxT = int64_t>
|
||||||
inline __host__ __device__ IdxT
|
inline __host__ __device__ IdxT
|
||||||
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||||
IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
|
IdxT loc = 0;
|
||||||
for (int i = ndim - 1; i >= 3; --i) {
|
for (int i = ndim - 1; i >= 0; --i) {
|
||||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
@ -202,11 +202,12 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
|||||||
const int64_t* a_strides,
|
const int64_t* a_strides,
|
||||||
const int64_t* b_strides,
|
const int64_t* b_strides,
|
||||||
int ndim) {
|
int ndim) {
|
||||||
auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
|
IdxT a_loc = 0;
|
||||||
for (int i = ndim - 1; i >= 3; --i) {
|
IdxT b_loc = 0;
|
||||||
|
for (int i = ndim - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * a_strides[i];
|
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||||
b_loc += dim_idx * b_strides[i];
|
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc);
|
return cuda::std::make_tuple(a_loc, b_loc);
|
||||||
@ -220,13 +221,14 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
|||||||
const int64_t* b_strides,
|
const int64_t* b_strides,
|
||||||
const int64_t* c_strides,
|
const int64_t* c_strides,
|
||||||
int ndim) {
|
int ndim) {
|
||||||
auto [a_loc, b_loc, c_loc] =
|
IdxT a_loc = 0;
|
||||||
elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides);
|
IdxT b_loc = 0;
|
||||||
for (int i = ndim - 1; i >= 3; --i) {
|
IdxT c_loc = 0;
|
||||||
|
for (int i = ndim - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * a_strides[i];
|
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||||
b_loc += dim_idx * b_strides[i];
|
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||||
c_loc += dim_idx * c_strides[i];
|
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
@ -336,4 +338,21 @@ struct LoopedElemToLoc<1, false, OffsetT> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline __device__ cuComplex log1p(cuComplex in) {
|
||||||
|
float x = cuCrealf(in);
|
||||||
|
float y = cuCimagf(in);
|
||||||
|
float zabs = sqrt(x * x + y * y);
|
||||||
|
float theta = atan2f(y, x + 1);
|
||||||
|
if (zabs < 0.5f) {
|
||||||
|
float r = x * (2 + x) + y * y;
|
||||||
|
if (r == 0) { // handle underflow
|
||||||
|
return {x, theta};
|
||||||
|
}
|
||||||
|
return {0.5f * log1pf(r), theta};
|
||||||
|
} else {
|
||||||
|
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
|
||||||
|
return {log(z0), theta};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
@ -62,7 +62,7 @@ void finalize(Stream s) {
|
|||||||
|
|
||||||
void synchronize(Stream s) {
|
void synchronize(Stream s) {
|
||||||
nvtx3::scoped_range r("gpu::synchronize");
|
nvtx3::scoped_range r("gpu::synchronize");
|
||||||
cu::get_stream(s).synchronize();
|
cu::get_command_encoder(s).synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::gpu
|
} // namespace mlx::core::gpu
|
||||||
|
@ -65,8 +65,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||||
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||||
|
|
||||||
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
|
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
|
||||||
(src.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
|
(src.size() > INT32_MAX) || (out.size() > INT32_MAX);
|
||||||
|
|
||||||
uint32_t slice_size = std::accumulate(
|
uint32_t slice_size = std::accumulate(
|
||||||
slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());
|
slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());
|
||||||
@ -88,7 +88,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
dtype_to_cuda_type(idx_dtype),
|
dtype_to_cuda_type(idx_dtype),
|
||||||
nidx,
|
nidx,
|
||||||
ndim,
|
ndim,
|
||||||
large ? "int64_t" : "uint32_t"));
|
large ? "int64_t" : "int32_t"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return std::make_pair(jit_source_gather, std::move(kernel_names));
|
return std::make_pair(jit_source_gather, std::move(kernel_names));
|
||||||
@ -99,7 +99,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (large) {
|
if (large) {
|
||||||
mod.append_arg<int64_t>(out.size());
|
mod.append_arg<int64_t>(out.size());
|
||||||
} else {
|
} else {
|
||||||
mod.append_arg<uint32_t>(out.size());
|
mod.append_arg<int32_t>(out.size());
|
||||||
}
|
}
|
||||||
mod.append_ndim_arg(src.shape());
|
mod.append_ndim_arg(src.shape());
|
||||||
mod.append_ndim_arg(src.strides());
|
mod.append_ndim_arg(src.strides());
|
||||||
@ -115,7 +115,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
dtype_to_cuda_type(idx_dtype),
|
dtype_to_cuda_type(idx_dtype),
|
||||||
nidx,
|
nidx,
|
||||||
idx_ndim,
|
idx_ndim,
|
||||||
large ? "int64_t" : "uint32_t");
|
large ? "int64_t" : "int32_t");
|
||||||
|
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
for (const auto& in : inputs) {
|
for (const auto& in : inputs) {
|
||||||
@ -152,14 +152,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||||
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||||
|
|
||||||
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
|
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
|
||||||
(upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
|
(upd.size() > INT32_MAX) || (out.size() > INT32_MAX);
|
||||||
|
|
||||||
uint32_t upd_post_idx_size = std::accumulate(
|
int32_t upd_post_idx_size = std::accumulate(
|
||||||
upd.shape().begin() + idx_ndim,
|
upd.shape().begin() + idx_ndim,
|
||||||
upd.shape().end(),
|
upd.shape().end(),
|
||||||
1,
|
1,
|
||||||
std::multiplies<uint32_t>());
|
std::multiplies<int32_t>());
|
||||||
|
|
||||||
const char* op = g_scatter_ops[reduce_type_];
|
const char* op = g_scatter_ops[reduce_type_];
|
||||||
std::string module_name = fmt::format(
|
std::string module_name = fmt::format(
|
||||||
@ -181,7 +181,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
op,
|
op,
|
||||||
nidx,
|
nidx,
|
||||||
ndim,
|
ndim,
|
||||||
large ? "int64_t" : "uint32_t"));
|
large ? "int64_t" : "int32_t"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return std::make_pair(jit_source_scatter, std::move(kernel_names));
|
return std::make_pair(jit_source_scatter, std::move(kernel_names));
|
||||||
@ -192,7 +192,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (large) {
|
if (large) {
|
||||||
mod.append_arg<int64_t>(upd.size());
|
mod.append_arg<int64_t>(upd.size());
|
||||||
} else {
|
} else {
|
||||||
mod.append_arg<uint32_t>(upd.size());
|
mod.append_arg<int32_t>(upd.size());
|
||||||
}
|
}
|
||||||
mod.append_ndim_arg(upd.shape());
|
mod.append_ndim_arg(upd.shape());
|
||||||
mod.append_ndim_arg(upd.strides());
|
mod.append_ndim_arg(upd.strides());
|
||||||
@ -200,7 +200,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (large) {
|
if (large) {
|
||||||
mod.append_arg<int64_t>(upd_post_idx_size);
|
mod.append_arg<int64_t>(upd_post_idx_size);
|
||||||
} else {
|
} else {
|
||||||
mod.append_arg<uint32_t>(upd_post_idx_size);
|
mod.append_arg<int32_t>(upd_post_idx_size);
|
||||||
}
|
}
|
||||||
mod.append_ndim_arg(out.shape());
|
mod.append_ndim_arg(out.shape());
|
||||||
mod.append_ndim_arg(out.strides());
|
mod.append_ndim_arg(out.strides());
|
||||||
@ -215,7 +215,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
op,
|
op,
|
||||||
nidx,
|
nidx,
|
||||||
idx_ndim,
|
idx_ndim,
|
||||||
large ? "int64_t" : "uint32_t");
|
large ? "int64_t" : "int32_t");
|
||||||
|
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
for (const auto& in : inputs) {
|
for (const auto& in : inputs) {
|
||||||
@ -238,7 +238,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
|
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||||
|
|
||||||
std::string module_name = fmt::format(
|
std::string module_name = fmt::format(
|
||||||
"gather_axis_{}_{}",
|
"gather_axis_{}_{}",
|
||||||
@ -258,7 +258,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
ndim,
|
ndim,
|
||||||
contiguous & 1 ? true : false,
|
contiguous & 1 ? true : false,
|
||||||
contiguous & 2 ? true : false,
|
contiguous & 2 ? true : false,
|
||||||
large ? "int64_t" : "uint32_t"));
|
large ? "int64_t" : "int32_t"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -283,9 +283,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
mod.append_arg<int64_t>(idx_size_axis);
|
mod.append_arg<int64_t>(idx_size_axis);
|
||||||
mod.append_arg<int64_t>(idx_size_post);
|
mod.append_arg<int64_t>(idx_size_post);
|
||||||
} else {
|
} else {
|
||||||
mod.append_arg<uint32_t>(idx_size_pre);
|
mod.append_arg<int32_t>(idx_size_pre);
|
||||||
mod.append_arg<uint32_t>(idx_size_axis);
|
mod.append_arg<int32_t>(idx_size_axis);
|
||||||
mod.append_arg<uint32_t>(idx_size_post);
|
mod.append_arg<int32_t>(idx_size_post);
|
||||||
}
|
}
|
||||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||||
mod.append_arg(remove_index(src.strides(), axis_));
|
mod.append_arg(remove_index(src.strides(), axis_));
|
||||||
@ -302,7 +302,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
src.ndim() - 1,
|
src.ndim() - 1,
|
||||||
src.flags().row_contiguous,
|
src.flags().row_contiguous,
|
||||||
idx.flags().row_contiguous,
|
idx.flags().row_contiguous,
|
||||||
large ? "int64_t" : "uint32_t");
|
large ? "int64_t" : "int32_t");
|
||||||
|
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
for (const auto& in : inputs) {
|
for (const auto& in : inputs) {
|
||||||
@ -337,7 +337,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
|
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||||
|
|
||||||
const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign";
|
const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign";
|
||||||
std::string module_name = fmt::format(
|
std::string module_name = fmt::format(
|
||||||
@ -360,7 +360,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
ndim,
|
ndim,
|
||||||
contiguous & 1 ? true : false,
|
contiguous & 1 ? true : false,
|
||||||
contiguous & 2 ? true : false,
|
contiguous & 2 ? true : false,
|
||||||
large ? "int64_t" : "uint32_t"));
|
large ? "int64_t" : "int32_t"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -385,9 +385,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
mod.append_arg<int64_t>(idx_size_axis);
|
mod.append_arg<int64_t>(idx_size_axis);
|
||||||
mod.append_arg<int64_t>(idx_size_post);
|
mod.append_arg<int64_t>(idx_size_post);
|
||||||
} else {
|
} else {
|
||||||
mod.append_arg<uint32_t>(idx_size_pre);
|
mod.append_arg<int32_t>(idx_size_pre);
|
||||||
mod.append_arg<uint32_t>(idx_size_axis);
|
mod.append_arg<int32_t>(idx_size_axis);
|
||||||
mod.append_arg<uint32_t>(idx_size_post);
|
mod.append_arg<int32_t>(idx_size_post);
|
||||||
}
|
}
|
||||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||||
mod.append_arg(remove_index(upd.strides(), axis_));
|
mod.append_arg(remove_index(upd.strides(), axis_));
|
||||||
@ -405,7 +405,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
idx.ndim() - 1,
|
idx.ndim() - 1,
|
||||||
upd.flags().row_contiguous,
|
upd.flags().row_contiguous,
|
||||||
idx.flags().row_contiguous,
|
idx.flags().row_contiguous,
|
||||||
large ? "int64_t" : "uint32_t");
|
large ? "int64_t" : "int32_t");
|
||||||
|
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
for (const auto& in : inputs) {
|
for (const auto& in : inputs) {
|
||||||
|
@ -37,36 +37,46 @@ void check_cu_error(const char* name, CUresult err) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the location of the CUDA toolkit.
|
// Return the location of the CUDA toolkit.
|
||||||
const char* cuda_home() {
|
const std::string& cuda_home() {
|
||||||
const char* home = std::getenv("CUDA_HOME");
|
static std::string home = []() -> std::string {
|
||||||
if (home) {
|
const char* home = std::getenv("CUDA_HOME");
|
||||||
return home;
|
if (home) {
|
||||||
}
|
return home;
|
||||||
home = std::getenv("CUDA_PATH");
|
}
|
||||||
if (home) {
|
home = std::getenv("CUDA_PATH");
|
||||||
return home;
|
if (home) {
|
||||||
}
|
return home;
|
||||||
|
}
|
||||||
#if defined(__linux__)
|
#if defined(__linux__)
|
||||||
home = "/usr/local/cuda";
|
home = "/usr/local/cuda";
|
||||||
if (std::filesystem::exists(home)) {
|
if (std::filesystem::exists(home)) {
|
||||||
return home;
|
return home;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"Environment variable CUDA_HOME or CUDA_PATH is not set.");
|
"Environment variable CUDA_HOME or CUDA_PATH is not set.");
|
||||||
|
}();
|
||||||
|
return home;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the cache directory for storing compiled results.
|
// Get the cache directory for storing compiled results.
|
||||||
bool get_ptx_cache_dir(std::filesystem::path* result) {
|
const std::filesystem::path& ptx_cache_dir() {
|
||||||
auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
|
static std::filesystem::path cache = []() -> std::filesystem::path {
|
||||||
if (!std::filesystem::is_directory(path)) {
|
std::filesystem::path cache;
|
||||||
std::error_code error;
|
if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
|
||||||
if (!std::filesystem::create_directories(path, error)) {
|
cache = c;
|
||||||
return false;
|
} else {
|
||||||
|
cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
|
||||||
}
|
}
|
||||||
}
|
if (!std::filesystem::exists(cache)) {
|
||||||
*result = path;
|
std::error_code error;
|
||||||
return true;
|
if (!std::filesystem::create_directories(cache, error)) {
|
||||||
|
return std::filesystem::path();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cache;
|
||||||
|
}();
|
||||||
|
return cache;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
|
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
|
||||||
@ -75,6 +85,10 @@ bool read_cached_ptx(
|
|||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
std::vector<char>* ptx,
|
std::vector<char>* ptx,
|
||||||
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
|
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
|
||||||
|
if (cache_dir.empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto ptx_path = cache_dir / (module_name + ".ptx");
|
auto ptx_path = cache_dir / (module_name + ".ptx");
|
||||||
std::error_code error;
|
std::error_code error;
|
||||||
auto ptx_size = std::filesystem::file_size(ptx_path, error);
|
auto ptx_size = std::filesystem::file_size(ptx_path, error);
|
||||||
@ -105,6 +119,10 @@ void write_cached_ptx(
|
|||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
const std::vector<char>& ptx,
|
const std::vector<char>& ptx,
|
||||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||||
|
if (cache_dir.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
|
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
|
||||||
if (!ptx.empty()) {
|
if (!ptx.empty()) {
|
||||||
ptx_file.write(&ptx.front(), ptx.size());
|
ptx_file.write(&ptx.front(), ptx.size());
|
||||||
@ -184,11 +202,9 @@ JitModule::JitModule(
|
|||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
const KernelBuilder& builder) {
|
const KernelBuilder& builder) {
|
||||||
// Check cache.
|
// Check cache.
|
||||||
std::filesystem::path cache_dir;
|
|
||||||
std::vector<char> ptx;
|
std::vector<char> ptx;
|
||||||
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
||||||
if (!get_ptx_cache_dir(&cache_dir) ||
|
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
|
||||||
!read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
|
|
||||||
// Create program.
|
// Create program.
|
||||||
auto [source_code, kernel_names] = builder();
|
auto [source_code, kernel_names] = builder();
|
||||||
nvrtcProgram prog;
|
nvrtcProgram prog;
|
||||||
@ -246,7 +262,7 @@ JitModule::JitModule(
|
|||||||
} else {
|
} else {
|
||||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||||
}
|
}
|
||||||
write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
|
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load module.
|
// Load module.
|
||||||
|
@ -102,6 +102,11 @@ inline constexpr bool is_floating_v =
|
|||||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
||||||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
||||||
|
|
||||||
|
// Type traits for detecting complex or real floating point numbers.
|
||||||
|
template <typename T>
|
||||||
|
inline constexpr bool is_inexact_v =
|
||||||
|
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
|
||||||
|
|
||||||
// Utility to copy data from vector to array in host.
|
// Utility to copy data from vector to array in host.
|
||||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||||
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
||||||
@ -136,17 +141,19 @@ inline uint max_occupancy_block_dim(T kernel) {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
inline std::tuple<dim3, uint> get_launch_args(
|
inline std::tuple<dim3, uint> get_launch_args(
|
||||||
T kernel,
|
T kernel,
|
||||||
const array& arr,
|
size_t size,
|
||||||
|
const Shape& shape,
|
||||||
|
const Strides& strides,
|
||||||
bool large,
|
bool large,
|
||||||
int work_per_thread = 1) {
|
int work_per_thread = 1) {
|
||||||
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
|
size_t nthreads = cuda::ceil_div(size, work_per_thread);
|
||||||
uint block_dim = max_occupancy_block_dim(kernel);
|
uint block_dim = max_occupancy_block_dim(kernel);
|
||||||
if (block_dim > nthreads) {
|
if (block_dim > nthreads) {
|
||||||
block_dim = nthreads;
|
block_dim = nthreads;
|
||||||
}
|
}
|
||||||
dim3 num_blocks;
|
dim3 num_blocks;
|
||||||
if (large) {
|
if (large) {
|
||||||
num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread);
|
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
|
||||||
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
|
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
|
||||||
} else {
|
} else {
|
||||||
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
|
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
|
||||||
@ -154,4 +161,14 @@ inline std::tuple<dim3, uint> get_launch_args(
|
|||||||
return std::make_tuple(num_blocks, block_dim);
|
return std::make_tuple(num_blocks, block_dim);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline std::tuple<dim3, uint> get_launch_args(
|
||||||
|
T kernel,
|
||||||
|
const array& arr,
|
||||||
|
bool large,
|
||||||
|
int work_per_thread = 1) {
|
||||||
|
return get_launch_args(
|
||||||
|
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -162,11 +162,15 @@ class MatMul {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array workspace(
|
void* workspace_ptr = nullptr;
|
||||||
allocator::malloc(heuristic_.workspaceSize),
|
if (heuristic_.workspaceSize > 0) {
|
||||||
{static_cast<int>(heuristic_.workspaceSize)},
|
array workspace(
|
||||||
int8);
|
allocator::malloc(heuristic_.workspaceSize),
|
||||||
encoder.add_temporary(workspace);
|
{static_cast<int>(heuristic_.workspaceSize)},
|
||||||
|
int8);
|
||||||
|
encoder.add_temporary(workspace);
|
||||||
|
workspace_ptr = workspace.data<void>();
|
||||||
|
}
|
||||||
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
||||||
@ -183,8 +187,8 @@ class MatMul {
|
|||||||
out,
|
out,
|
||||||
out_desc_,
|
out_desc_,
|
||||||
&heuristic_.algo,
|
&heuristic_.algo,
|
||||||
workspace.data<void>(),
|
workspace_ptr,
|
||||||
workspace.nbytes(),
|
heuristic_.workspaceSize,
|
||||||
stream));
|
stream));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -358,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
a_batch_strides.back(),
|
a_batch_strides.back(),
|
||||||
b_batch_strides.back());
|
b_batch_strides.back());
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
auto nbatch = batch_count / batch_shape.back();
|
||||||
|
if (nbatch == 1) {
|
||||||
|
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
for (size_t i = 0; i < nbatch; ++i) {
|
||||||
matmul.run(
|
matmul.run(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||||
@ -444,10 +457,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
b_batch_strides.back(),
|
b_batch_strides.back(),
|
||||||
c_batch_strides.back());
|
c_batch_strides.back());
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_input_array(c);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
auto nbatch = batch_count / batch_shape.back();
|
||||||
|
if (nbatch == 1) {
|
||||||
|
matmul.run(
|
||||||
|
encoder,
|
||||||
|
out.data<int8_t>(),
|
||||||
|
a.data<int8_t>(),
|
||||||
|
b.data<int8_t>(),
|
||||||
|
c.data<int8_t>(),
|
||||||
|
alpha_,
|
||||||
|
beta_);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
for (size_t i = 0; i < nbatch; ++i) {
|
||||||
matmul.run(
|
matmul.run(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||||
|
@ -93,10 +93,8 @@ void AllReduce::eval_gpu(
|
|||||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||||
}
|
}
|
||||||
|
|
||||||
NO_GPU(ArgPartition)
|
|
||||||
NO_GPU(BlockMaskedMM)
|
NO_GPU(BlockMaskedMM)
|
||||||
NO_GPU(Convolution)
|
NO_GPU(Convolution)
|
||||||
NO_GPU_MULTI(DivMod)
|
|
||||||
NO_GPU(DynamicSlice)
|
NO_GPU(DynamicSlice)
|
||||||
NO_GPU(DynamicSliceUpdate)
|
NO_GPU(DynamicSliceUpdate)
|
||||||
NO_GPU(FFT)
|
NO_GPU(FFT)
|
||||||
@ -105,7 +103,6 @@ NO_GPU(GatherQMM)
|
|||||||
NO_GPU(Hadamard)
|
NO_GPU(Hadamard)
|
||||||
NO_GPU(Load)
|
NO_GPU(Load)
|
||||||
NO_GPU_MULTI(LUF)
|
NO_GPU_MULTI(LUF)
|
||||||
NO_GPU(Partition)
|
|
||||||
NO_GPU_MULTI(QRF)
|
NO_GPU_MULTI(QRF)
|
||||||
NO_GPU(QuantizedMatmul)
|
NO_GPU(QuantizedMatmul)
|
||||||
NO_GPU(Scan)
|
NO_GPU(Scan)
|
||||||
|
@ -79,14 +79,10 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
|
|||||||
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||||
array out = out_;
|
array out = out_;
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
encoder.set_input_array(in);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
|
|
||||||
if (axis < 0) {
|
if (axis < 0) {
|
||||||
axis += in.ndim();
|
axis += in.ndim();
|
||||||
}
|
}
|
||||||
int nsort = in.shape(axis);
|
int nsort = in.shape(axis);
|
||||||
int nsegments = in.data_size() / nsort;
|
|
||||||
int last_dim = in.ndim() - 1;
|
int last_dim = in.ndim() - 1;
|
||||||
|
|
||||||
// If we are not sorting the innermost dimension of a contiguous array,
|
// If we are not sorting the innermost dimension of a contiguous array,
|
||||||
@ -100,9 +96,15 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||||
encoder.add_temporary(out);
|
encoder.add_temporary(out);
|
||||||
} else {
|
} else {
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(
|
||||||
|
allocator::malloc(in.data_size() * out.itemsize()),
|
||||||
|
in.data_size(),
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||||
@ -134,7 +136,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
indices.data<uint32_t>(),
|
indices.data<uint32_t>(),
|
||||||
out.data<uint32_t>(),
|
out.data<uint32_t>(),
|
||||||
in.data_size(),
|
in.data_size(),
|
||||||
nsegments,
|
in.data_size() / nsort,
|
||||||
offsets,
|
offsets,
|
||||||
offsets + 1,
|
offsets + 1,
|
||||||
stream);
|
stream);
|
||||||
@ -144,7 +146,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
in.data<Type>(),
|
in.data<Type>(),
|
||||||
out.data<Type>(),
|
out.data<Type>(),
|
||||||
in.data_size(),
|
in.data_size(),
|
||||||
nsegments,
|
in.data_size() / nsort,
|
||||||
offsets,
|
offsets,
|
||||||
offsets + 1,
|
offsets + 1,
|
||||||
stream);
|
stream);
|
||||||
@ -177,4 +179,14 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("ArgPartition::eval_gpu");
|
||||||
|
gpu_sort(stream(), inputs[0], out, axis_, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("Partition::eval_gpu");
|
||||||
|
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -101,10 +101,10 @@ void ternary_op_gpu_inplace(
|
|||||||
auto& a_strides = strides[0];
|
auto& a_strides = strides[0];
|
||||||
auto& b_strides = strides[1];
|
auto& b_strides = strides[1];
|
||||||
auto& c_strides = strides[2];
|
auto& c_strides = strides[2];
|
||||||
bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
|
bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||||
c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
@ -116,7 +116,7 @@ void ternary_op_gpu_inplace(
|
|||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
c.data<DType>(),
|
c.data<DType>(),
|
||||||
out.data<DType>(),
|
out.data<DType>(),
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param<NDIM>(shape),
|
const_param<NDIM>(shape),
|
||||||
const_param<NDIM>(a_strides),
|
const_param<NDIM>(a_strides),
|
||||||
const_param<NDIM>(b_strides),
|
const_param<NDIM>(b_strides),
|
||||||
@ -142,7 +142,8 @@ void ternary_op_gpu_inplace(
|
|||||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||||
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
|
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
a.data<bool>(),
|
a.data<bool>(),
|
||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
|
@ -27,12 +27,14 @@ constexpr bool supports_unary_op() {
|
|||||||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
||||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
|
||||||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
|
||||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
|
|
||||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
||||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||||
}
|
}
|
||||||
|
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||||
|
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
|
||||||
|
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||||
|
}
|
||||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||||
!std::is_same_v<In, bool>;
|
!std::is_same_v<In, bool>;
|
||||||
@ -91,7 +93,7 @@ void unary_op_gpu_inplace(
|
|||||||
} else {
|
} else {
|
||||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
auto [shape, strides] = collapse_contiguous_dims(in);
|
||||||
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
|
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
|
||||||
in_ptr, in.data_size(), shape, strides);
|
in_ptr, in.size(), shape, strides);
|
||||||
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
|
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -31,6 +31,9 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
|
|||||||
if (dtype == bfloat16) {
|
if (dtype == bfloat16) {
|
||||||
return "__nv_bfloat16";
|
return "__nv_bfloat16";
|
||||||
}
|
}
|
||||||
|
if (dtype == complex64) {
|
||||||
|
return "cuComplex";
|
||||||
|
}
|
||||||
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
|
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
|
||||||
if (dtype == DTYPE) { \
|
if (dtype == DTYPE) { \
|
||||||
return #CPP_TYPE; \
|
return #CPP_TYPE; \
|
||||||
|
@ -80,7 +80,9 @@ void Worker::thread_fn() {
|
|||||||
}
|
}
|
||||||
worker_tasks_.erase(worker_tasks_.begin(), end);
|
worker_tasks_.erase(worker_tasks_.begin(), end);
|
||||||
}
|
}
|
||||||
for (auto& task : tasks) {
|
// Make sure tasks are cleared before the next wait
|
||||||
|
for (int i = 0; i < tasks.size(); ++i) {
|
||||||
|
auto task = std::move(tasks[i]);
|
||||||
task();
|
task();
|
||||||
}
|
}
|
||||||
worker_event_.wait(batch + 1);
|
worker_event_.wait(batch + 1);
|
||||||
|
17
python/scripts/repair_cuda.sh
Normal file
17
python/scripts/repair_cuda.sh
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
auditwheel repair dist/* \
|
||||||
|
--plat manylinux_2_35_x86_64 \
|
||||||
|
--exclude libcublas* \
|
||||||
|
--exclude libnvrtc*
|
||||||
|
|
||||||
|
cd wheelhouse
|
||||||
|
repaired_wheel=$(find . -name "*.whl" -print -quit)
|
||||||
|
unzip -q "${repaired_wheel}"
|
||||||
|
core_so=$(find mlx -name "core*.so" -print -quit)
|
||||||
|
rpath=$(patchelf --print-rpath "${core_so}")
|
||||||
|
rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib
|
||||||
|
patchelf --force-rpath --set-rpath "$rpath" "$core_so"
|
||||||
|
|
||||||
|
# Re-zip the repaired wheel
|
||||||
|
zip -r -q "${repaired_wheel}" .
|
@ -1,25 +1,37 @@
|
|||||||
cuda_skip = {
|
cuda_skip = {
|
||||||
"TestArray.test_api",
|
"TestArray.test_api",
|
||||||
"TestArray.test_setitem",
|
|
||||||
"TestAutograd.test_cumprod_grad",
|
|
||||||
"TestAutograd.test_slice_grads",
|
|
||||||
"TestAutograd.test_split_against_slice",
|
|
||||||
"TestAutograd.test_stop_gradient",
|
|
||||||
"TestAutograd.test_topk_grad",
|
|
||||||
"TestAutograd.test_update_state",
|
|
||||||
"TestAutograd.test_vjp",
|
|
||||||
"TestBF16.test_arg_reduction_ops",
|
"TestBF16.test_arg_reduction_ops",
|
||||||
"TestBF16.test_binary_ops",
|
|
||||||
"TestBF16.test_reduction_ops",
|
"TestBF16.test_reduction_ops",
|
||||||
"TestBlas.test_block_masked_matmul",
|
|
||||||
"TestBlas.test_complex_gemm",
|
"TestBlas.test_complex_gemm",
|
||||||
|
"TestEinsum.test_ellipses",
|
||||||
|
"TestEinsum.test_opt_einsum_test_cases",
|
||||||
|
"TestLoad.test_load_f8_e4m3",
|
||||||
|
"TestLayers.test_group_norm",
|
||||||
|
"TestLayers.test_pooling",
|
||||||
|
"TestLayers.test_quantized_embedding",
|
||||||
|
"TestLayers.test_sin_pe",
|
||||||
|
"TestLayers.test_upsample",
|
||||||
|
"TestOps.test_complex_ops",
|
||||||
|
"TestOps.test_dynamic_slicing",
|
||||||
|
"TestOps.test_softmax",
|
||||||
|
"TestReduce.test_axis_permutation_sums",
|
||||||
|
"TestReduce.test_dtypes",
|
||||||
|
"TestReduce.test_expand_sums",
|
||||||
|
"TestReduce.test_many_reduction_axes",
|
||||||
|
"TestUpsample.test_torch_upsample",
|
||||||
|
# Block masked matmul NYI
|
||||||
|
"TestBlas.test_block_masked_matmul",
|
||||||
|
# Gather matmul NYI
|
||||||
"TestBlas.test_gather_matmul",
|
"TestBlas.test_gather_matmul",
|
||||||
"TestBlas.test_gather_matmul_grad",
|
"TestBlas.test_gather_matmul_grad",
|
||||||
"TestBlas.test_matmul_batched",
|
# Scan NYI
|
||||||
"TestBlas.test_matrix_vector_attn",
|
"TestAutograd.test_cumprod_grad",
|
||||||
"TestCompile.test_compile_dynamic_dims",
|
"TestOps.test_scans",
|
||||||
"TestCompile.test_compile_inf",
|
"TestOps.test_logcumsumexp",
|
||||||
"TestCompile.test_inf_constant",
|
# Hadamard NYI
|
||||||
|
"TestOps.test_hadamard",
|
||||||
|
"TestOps.test_hadamard_grad_vmap",
|
||||||
|
# Convolutions NYI
|
||||||
"TestConv.test_1d_conv_with_2d",
|
"TestConv.test_1d_conv_with_2d",
|
||||||
"TestConv.test_asymmetric_padding",
|
"TestConv.test_asymmetric_padding",
|
||||||
"TestConv.test_basic_grad_shapes",
|
"TestConv.test_basic_grad_shapes",
|
||||||
@ -46,12 +58,11 @@ cuda_skip = {
|
|||||||
"TestConvTranspose.test_torch_conv_transpose_3D",
|
"TestConvTranspose.test_torch_conv_transpose_3D",
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
||||||
"TestEinsum.test_attention",
|
|
||||||
"TestEinsum.test_ellipses",
|
|
||||||
"TestEinsum.test_opt_einsum_test_cases",
|
|
||||||
"TestEval.test_multi_output_eval_during_transform",
|
|
||||||
"TestExportImport.test_export_conv",
|
"TestExportImport.test_export_conv",
|
||||||
"TestFast.test_rope_grad",
|
"TestLayers.test_conv1d",
|
||||||
|
"TestLayers.test_conv2d",
|
||||||
|
"TestVmap.test_vmap_conv",
|
||||||
|
# FFTs NYI
|
||||||
"TestFFT.test_fft",
|
"TestFFT.test_fft",
|
||||||
"TestFFT.test_fft_big_powers_of_two",
|
"TestFFT.test_fft_big_powers_of_two",
|
||||||
"TestFFT.test_fft_contiguity",
|
"TestFFT.test_fft_contiguity",
|
||||||
@ -61,61 +72,22 @@ cuda_skip = {
|
|||||||
"TestFFT.test_fft_large_numbers",
|
"TestFFT.test_fft_large_numbers",
|
||||||
"TestFFT.test_fft_shared_mem",
|
"TestFFT.test_fft_shared_mem",
|
||||||
"TestFFT.test_fftn",
|
"TestFFT.test_fftn",
|
||||||
"TestInit.test_orthogonal",
|
# Lapack ops NYI
|
||||||
"TestLinalg.test_cholesky",
|
"TestLinalg.test_cholesky",
|
||||||
"TestLinalg.test_cholesky_inv",
|
"TestLinalg.test_cholesky_inv",
|
||||||
"TestLinalg.test_eig",
|
"TestLinalg.test_eig",
|
||||||
"TestLinalg.test_eigh",
|
"TestLinalg.test_eigh",
|
||||||
"TestLinalg.test_inverse",
|
"TestLinalg.test_inverse",
|
||||||
|
"TestVmap.test_vmap_inverse",
|
||||||
"TestLinalg.test_lu",
|
"TestLinalg.test_lu",
|
||||||
"TestLinalg.test_lu_factor",
|
"TestLinalg.test_lu_factor",
|
||||||
"TestLinalg.test_pseudo_inverse",
|
"TestLinalg.test_pseudo_inverse",
|
||||||
"TestLinalg.test_qr_factorization",
|
"TestLinalg.test_qr_factorization",
|
||||||
|
"TestInit.test_orthogonal",
|
||||||
"TestLinalg.test_svd_decomposition",
|
"TestLinalg.test_svd_decomposition",
|
||||||
|
"TestVmap.test_vmap_svd",
|
||||||
"TestLinalg.test_tri_inverse",
|
"TestLinalg.test_tri_inverse",
|
||||||
"TestLoad.test_load_f8_e4m3",
|
# Quantization NYI
|
||||||
"TestLosses.test_binary_cross_entropy",
|
|
||||||
"TestMemory.test_memory_info",
|
|
||||||
"TestLayers.test_conv1d",
|
|
||||||
"TestLayers.test_conv2d",
|
|
||||||
"TestLayers.test_elu",
|
|
||||||
"TestLayers.test_group_norm",
|
|
||||||
"TestLayers.test_hard_shrink",
|
|
||||||
"TestLayers.test_pooling",
|
|
||||||
"TestLayers.test_quantized_embedding",
|
|
||||||
"TestLayers.test_sin_pe",
|
|
||||||
"TestLayers.test_softshrink",
|
|
||||||
"TestLayers.test_upsample",
|
|
||||||
"TestOps.test_argpartition",
|
|
||||||
"TestOps.test_array_equal",
|
|
||||||
"TestOps.test_as_strided",
|
|
||||||
"TestOps.test_atleast_1d",
|
|
||||||
"TestOps.test_atleast_2d",
|
|
||||||
"TestOps.test_atleast_3d",
|
|
||||||
"TestOps.test_binary_ops",
|
|
||||||
"TestOps.test_bitwise_grad",
|
|
||||||
"TestOps.test_complex_ops",
|
|
||||||
"TestOps.test_divmod",
|
|
||||||
"TestOps.test_dynamic_slicing",
|
|
||||||
"TestOps.test_hadamard",
|
|
||||||
"TestOps.test_hadamard_grad_vmap",
|
|
||||||
"TestOps.test_irregular_binary_ops",
|
|
||||||
"TestOps.test_isfinite",
|
|
||||||
"TestOps.test_kron",
|
|
||||||
"TestOps.test_log",
|
|
||||||
"TestOps.test_log10",
|
|
||||||
"TestOps.test_log1p",
|
|
||||||
"TestOps.test_log2",
|
|
||||||
"TestOps.test_logaddexp",
|
|
||||||
"TestOps.test_logcumsumexp",
|
|
||||||
"TestOps.test_partition",
|
|
||||||
"TestOps.test_scans",
|
|
||||||
"TestOps.test_slice_update_reversed",
|
|
||||||
"TestOps.test_softmax",
|
|
||||||
"TestOps.test_sort",
|
|
||||||
"TestOps.test_tensordot",
|
|
||||||
"TestOps.test_tile",
|
|
||||||
"TestOps.test_view",
|
|
||||||
"TestQuantized.test_gather_matmul_grad",
|
"TestQuantized.test_gather_matmul_grad",
|
||||||
"TestQuantized.test_gather_qmm",
|
"TestQuantized.test_gather_qmm",
|
||||||
"TestQuantized.test_gather_qmm_sorted",
|
"TestQuantized.test_gather_qmm_sorted",
|
||||||
@ -131,13 +103,4 @@ cuda_skip = {
|
|||||||
"TestQuantized.test_small_matrix",
|
"TestQuantized.test_small_matrix",
|
||||||
"TestQuantized.test_throw",
|
"TestQuantized.test_throw",
|
||||||
"TestQuantized.test_vjp_scales_biases",
|
"TestQuantized.test_vjp_scales_biases",
|
||||||
"TestReduce.test_axis_permutation_sums",
|
|
||||||
"TestReduce.test_dtypes",
|
|
||||||
"TestReduce.test_expand_sums",
|
|
||||||
"TestReduce.test_many_reduction_axes",
|
|
||||||
"TestUpsample.test_torch_upsample",
|
|
||||||
"TestVmap.test_unary",
|
|
||||||
"TestVmap.test_vmap_conv",
|
|
||||||
"TestVmap.test_vmap_inverse",
|
|
||||||
"TestVmap.test_vmap_svd",
|
|
||||||
}
|
}
|
||||||
|
@ -1187,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
||||||
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
||||||
check_slices(
|
check_slices(
|
||||||
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1])
|
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1])
|
||||||
)
|
)
|
||||||
|
|
||||||
# Multiple slices
|
# Multiple slices
|
||||||
|
@ -83,14 +83,14 @@ class TestLosses(mlx_tests.MLXTestCase):
|
|||||||
logits, targets, reduction="mean"
|
logits, targets, reduction="mean"
|
||||||
)
|
)
|
||||||
expected_mean = mx.mean(expected_none)
|
expected_mean = mx.mean(expected_none)
|
||||||
self.assertEqual(losses_mean, expected_mean)
|
self.assertTrue(mx.allclose(losses_mean, expected_mean))
|
||||||
|
|
||||||
# Test with reduction 'sum'
|
# Test with reduction 'sum'
|
||||||
losses_sum = nn.losses.binary_cross_entropy(
|
losses_sum = nn.losses.binary_cross_entropy(
|
||||||
logits, targets, reduction="sum"
|
logits, targets, reduction="sum"
|
||||||
)
|
)
|
||||||
expected_sum = mx.sum(expected_none)
|
expected_sum = mx.sum(expected_none)
|
||||||
self.assertEqual(losses_sum, expected_sum)
|
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
||||||
|
|
||||||
# With weights, no label smoothing
|
# With weights, no label smoothing
|
||||||
weights = mx.array([1.0, 2.0, 1.0, 2.0])
|
weights = mx.array([1.0, 2.0, 1.0, 2.0])
|
||||||
|
@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqualArray(result, mx.array(expected))
|
self.assertEqualArray(result, mx.array(expected))
|
||||||
|
|
||||||
def test_atleast_1d(self):
|
def test_atleast_1d(self):
|
||||||
def compare_nested_lists(x, y):
|
|
||||||
if isinstance(x, list) and isinstance(y, list):
|
|
||||||
if len(x) != len(y):
|
|
||||||
return False
|
|
||||||
for i in range(len(x)):
|
|
||||||
if not compare_nested_lists(x[i], y[i]):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return x == y
|
|
||||||
|
|
||||||
# Test 1D input
|
# Test 1D input
|
||||||
arrays = [
|
arrays = [
|
||||||
[1],
|
[1],
|
||||||
@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
for i, array in enumerate(arrays):
|
for i, array in enumerate(arrays):
|
||||||
mx_res = mx.atleast_1d(mx.array(array))
|
mx_res = mx.atleast_1d(mx.array(array))
|
||||||
np_res = np.atleast_1d(np.array(array))
|
np_res = np.atleast_1d(np.array(array))
|
||||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
|
||||||
self.assertEqual(mx_res.shape, np_res.shape)
|
self.assertEqual(mx_res.shape, np_res.shape)
|
||||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||||
|
|
||||||
def test_atleast_2d(self):
|
def test_atleast_2d(self):
|
||||||
def compare_nested_lists(x, y):
|
|
||||||
if isinstance(x, list) and isinstance(y, list):
|
|
||||||
if len(x) != len(y):
|
|
||||||
return False
|
|
||||||
for i in range(len(x)):
|
|
||||||
if not compare_nested_lists(x[i], y[i]):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return x == y
|
|
||||||
|
|
||||||
# Test 1D input
|
# Test 1D input
|
||||||
arrays = [
|
arrays = [
|
||||||
[1],
|
[1],
|
||||||
@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
for i, array in enumerate(arrays):
|
for i, array in enumerate(arrays):
|
||||||
mx_res = mx.atleast_2d(mx.array(array))
|
mx_res = mx.atleast_2d(mx.array(array))
|
||||||
np_res = np.atleast_2d(np.array(array))
|
np_res = np.atleast_2d(np.array(array))
|
||||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
|
||||||
self.assertEqual(mx_res.shape, np_res.shape)
|
self.assertEqual(mx_res.shape, np_res.shape)
|
||||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||||
|
|
||||||
def test_atleast_3d(self):
|
def test_atleast_3d(self):
|
||||||
def compare_nested_lists(x, y):
|
|
||||||
if isinstance(x, list) and isinstance(y, list):
|
|
||||||
if len(x) != len(y):
|
|
||||||
return False
|
|
||||||
for i in range(len(x)):
|
|
||||||
if not compare_nested_lists(x[i], y[i]):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return x == y
|
|
||||||
|
|
||||||
# Test 1D input
|
# Test 1D input
|
||||||
arrays = [
|
arrays = [
|
||||||
[1],
|
[1],
|
||||||
@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
for i, array in enumerate(arrays):
|
for i, array in enumerate(arrays):
|
||||||
mx_res = mx.atleast_3d(mx.array(array))
|
mx_res = mx.atleast_3d(mx.array(array))
|
||||||
np_res = np.atleast_3d(np.array(array))
|
np_res = np.atleast_3d(np.array(array))
|
||||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
|
||||||
self.assertEqual(mx_res.shape, np_res.shape)
|
self.assertEqual(mx_res.shape, np_res.shape)
|
||||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||||
|
|
||||||
def test_issubdtype(self):
|
def test_issubdtype(self):
|
||||||
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
||||||
|
8
setup.py
8
setup.py
@ -174,20 +174,26 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
package_dir = {"": "python"}
|
package_dir = {"": "python"}
|
||||||
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
|
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
|
||||||
|
install_requires = []
|
||||||
|
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
|
||||||
|
if build_cuda:
|
||||||
|
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="mlx",
|
name="mlx-cuda" if build_cuda else "mlx",
|
||||||
version=get_version(),
|
version=get_version(),
|
||||||
author="MLX Contributors",
|
author="MLX Contributors",
|
||||||
author_email="mlx@group.apple.com",
|
author_email="mlx@group.apple.com",
|
||||||
description="A framework for machine learning on Apple silicon.",
|
description="A framework for machine learning on Apple silicon.",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
license="MIT",
|
||||||
url="https://github.com/ml-explore/mlx",
|
url="https://github.com/ml-explore/mlx",
|
||||||
packages=packages,
|
packages=packages,
|
||||||
package_dir=package_dir,
|
package_dir=package_dir,
|
||||||
package_data=package_data,
|
package_data=package_data,
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
|
install_requires=install_requires,
|
||||||
extras_require={
|
extras_require={
|
||||||
"dev": [
|
"dev": [
|
||||||
"nanobind==2.4.0",
|
"nanobind==2.4.0",
|
||||||
|
Loading…
Reference in New Issue
Block a user