mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
15 Commits
v0.26.5
...
70dc336785
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
70dc336785 | ||
|
|
4e504039f5 | ||
|
|
d1f4d291e8 | ||
|
|
e1840853ce | ||
|
|
0f5ce173da | ||
|
|
588854195f | ||
|
|
28d068bce6 | ||
|
|
d107d8d495 | ||
|
|
1e496ddb82 | ||
|
|
74eccbf3fa | ||
|
|
08638223ca | ||
|
|
56cc858af9 | ||
|
|
f55c4ed1d6 | ||
|
|
93d70419e7 | ||
|
|
63f663d9c6 |
@@ -7,6 +7,9 @@ parameters:
|
|||||||
nightly_build:
|
nightly_build:
|
||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
|
test_release:
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build_documentation:
|
build_documentation:
|
||||||
@@ -200,8 +203,12 @@ jobs:
|
|||||||
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
||||||
|
|
||||||
cuda_build_and_test:
|
cuda_build_and_test:
|
||||||
|
parameters:
|
||||||
|
image_date:
|
||||||
|
type: string
|
||||||
|
default: "2023.11.1"
|
||||||
machine:
|
machine:
|
||||||
image: linux-cuda-12:2023.11.1
|
image: "linux-cuda-12:<< parameters.image_date >>"
|
||||||
resource_class: gpu.nvidia.small.gen2
|
resource_class: gpu.nvidia.small.gen2
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
@@ -366,22 +373,27 @@ jobs:
|
|||||||
type: string
|
type: string
|
||||||
default: ""
|
default: ""
|
||||||
machine:
|
machine:
|
||||||
image: linux-cuda-12:default
|
image: ubuntu-2204:current
|
||||||
resource_class: gpu.nvidia.small.gen2
|
resource_class: large
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
name: Build wheel
|
name: Build wheel
|
||||||
command: |
|
command: |
|
||||||
|
export DEBIAN_FRONTEND=noninteractive
|
||||||
|
export NEEDRESTART_MODE=a
|
||||||
|
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
|
||||||
|
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
|
sudo apt install cuda-toolkit-12-9
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
sudo apt-get install zip
|
sudo apt-get install zip
|
||||||
python -m venv env
|
|
||||||
source env/bin/activate
|
|
||||||
pip install auditwheel
|
pip install auditwheel
|
||||||
pip install patchelf
|
pip install patchelf
|
||||||
pip install build
|
pip install build
|
||||||
pip install twine
|
pip install twine
|
||||||
|
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
|
||||||
|
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
|
||||||
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
|
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
|
||||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
python -m build -w
|
python -m build -w
|
||||||
@@ -392,7 +404,6 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Upload package
|
name: Upload package
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
|
||||||
twine upload wheelhouse/*.whl
|
twine upload wheelhouse/*.whl
|
||||||
- store_artifacts:
|
- store_artifacts:
|
||||||
path: wheelhouse/
|
path: wheelhouse/
|
||||||
@@ -405,19 +416,24 @@ workflows:
|
|||||||
pattern: "^(?!pull/)[-\\w]+$"
|
pattern: "^(?!pull/)[-\\w]+$"
|
||||||
value: << pipeline.git.branch >>
|
value: << pipeline.git.branch >>
|
||||||
- not: << pipeline.parameters.nightly_build >>
|
- not: << pipeline.parameters.nightly_build >>
|
||||||
|
- not: << pipeline.parameters.test_release >>
|
||||||
jobs:
|
jobs:
|
||||||
- mac_build_and_test:
|
- mac_build_and_test:
|
||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
macosx_deployment_target: ["13.5", "14.0"]
|
macosx_deployment_target: ["13.5", "14.0"]
|
||||||
- linux_build_and_test
|
- linux_build_and_test
|
||||||
- cuda_build_and_test
|
- cuda_build_and_test:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
image_date: ["2023.11.1", "2025.05.1"]
|
||||||
- build_documentation
|
- build_documentation
|
||||||
|
|
||||||
build_pypi_release:
|
build_pypi_release:
|
||||||
when:
|
when:
|
||||||
and:
|
and:
|
||||||
- not: << pipeline.parameters.nightly_build >>
|
- not: << pipeline.parameters.nightly_build >>
|
||||||
|
- not: << pipeline.parameters.test_release >>
|
||||||
jobs:
|
jobs:
|
||||||
- build_release:
|
- build_release:
|
||||||
filters:
|
filters:
|
||||||
@@ -601,3 +617,87 @@ workflows:
|
|||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
- build_cuda_release
|
- build_cuda_release
|
||||||
|
|
||||||
|
build_dev_release:
|
||||||
|
when:
|
||||||
|
and:
|
||||||
|
- equal: [ main, << pipeline.git.branch >> ]
|
||||||
|
- << pipeline.parameters.test_release >>
|
||||||
|
jobs:
|
||||||
|
- build_release:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
|
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||||
|
build_env: ["DEV_RELEASE=1"]
|
||||||
|
xcode_version: ["16.2.0", "15.0.0"]
|
||||||
|
exclude:
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.9"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.10"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.11"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.12"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.13"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.9"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.10"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.11"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.12"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.13"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.9"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.10"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.11"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.12"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.13"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- build_linux_release:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
|
build_env: ["DEV_RELEASE=1"]
|
||||||
|
- build_cuda_release:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
build_env: ["DEV_RELEASE=1"]
|
||||||
|
|||||||
@@ -377,4 +377,10 @@ void copy_cpu_inplace(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array contiguous_copy_cpu(const array& arr, Stream stream) {
|
||||||
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
|
copy_cpu(arr, arr_copy, CopyType::General, stream);
|
||||||
|
return arr_copy;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -30,4 +30,7 @@ void copy_cpu_inplace(
|
|||||||
const std::optional<array>& dynamic_i_offset = std::nullopt,
|
const std::optional<array>& dynamic_i_offset = std::nullopt,
|
||||||
const std::optional<array>& dynamic_o_offset = std::nullopt);
|
const std::optional<array>& dynamic_o_offset = std::nullopt);
|
||||||
|
|
||||||
|
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||||
|
array contiguous_copy_cpu(const array& arr, Stream stream);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -13,9 +13,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
|
|||||||
if (arr.flags().row_contiguous) {
|
if (arr.flags().row_contiguous) {
|
||||||
return {arr, false};
|
return {arr, false};
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
return {contiguous_copy_cpu(arr, stream), true};
|
||||||
copy_cpu(arr, arr_copy, CopyType::General, stream);
|
|
||||||
return {arr_copy, true};
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -34,8 +32,7 @@ void AllReduce::eval_cpu(
|
|||||||
}
|
}
|
||||||
return in;
|
return in;
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
array arr_copy = contiguous_copy_cpu(in, s);
|
||||||
copy_cpu(in, arr_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(arr_copy);
|
out.copy_shared_buffer(arr_copy);
|
||||||
return arr_copy;
|
return arr_copy;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -87,8 +87,7 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_cpu(x, s);
|
||||||
copy_cpu(x, x_copy, CopyType::General, s);
|
|
||||||
encoder.add_temporary(x_copy);
|
encoder.add_temporary(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -136,9 +136,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
return std::make_tuple(true, sty, arr, false);
|
return std::make_tuple(true, sty, arr, false);
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
|
||||||
copy_cpu(arr, arr_copy, CopyType::General, s);
|
|
||||||
int64_t stx = arr.shape(-1);
|
int64_t stx = arr.shape(-1);
|
||||||
|
array arr_copy = contiguous_copy_cpu(arr, s);
|
||||||
return std::make_tuple(false, stx, arr_copy, true);
|
return std::make_tuple(false, stx, arr_copy, true);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -712,9 +712,7 @@ void fast::AffineQuantize::eval_cpu(
|
|||||||
if (arr.flags().row_contiguous) {
|
if (arr.flags().row_contiguous) {
|
||||||
return std::make_pair(arr, false);
|
return std::make_pair(arr, false);
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
return std::make_pair(contiguous_copy_cpu(arr, s), true);
|
||||||
copy_cpu(arr, arr_copy, CopyType::General, s);
|
|
||||||
return std::make_pair(arr_copy, true);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -250,10 +250,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// Ensure contiguity
|
// Ensure contiguity
|
||||||
auto in = inputs[0];
|
auto in = inputs[0];
|
||||||
if (!in.flags().row_contiguous) {
|
if (!in.flags().row_contiguous) {
|
||||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
in = contiguous_copy_cpu(in, stream());
|
||||||
copy_cpu(in, arr_copy, CopyType::General, stream());
|
encoder.add_temporary(in);
|
||||||
in = arr_copy;
|
|
||||||
encoder.add_temporary(arr_copy);
|
|
||||||
}
|
}
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
|||||||
@@ -131,8 +131,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_cpu(x, s);
|
||||||
copy_cpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/gemv.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||||
@@ -87,6 +88,13 @@ endif()
|
|||||||
target_compile_options(
|
target_compile_options(
|
||||||
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
|
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
|
||||||
|
|
||||||
|
# Use stronger binaries compression. This feature was introduced in CUDA 12.8
|
||||||
|
# and requires drivers released after CUDA 12.4.
|
||||||
|
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
|
||||||
|
target_compile_options(
|
||||||
|
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
|
||||||
|
endif()
|
||||||
|
|
||||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
# Compute capability 7 is required for synchronization between CPU/GPU with
|
||||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
# managed memory. TODO: Add more architectures for potential performance gain.
|
||||||
set(MLX_CUDA_ARCHITECTURES
|
set(MLX_CUDA_ARCHITECTURES
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
#include "mlx/backend/cuda/allocator.h"
|
#include "mlx/backend/cuda/allocator.h"
|
||||||
#include "mlx/backend/cuda/utils.h"
|
#include "mlx/backend/cuda/utils.h"
|
||||||
#include "mlx/backend/cuda/worker.h"
|
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
@@ -17,14 +16,66 @@ namespace cu {
|
|||||||
|
|
||||||
constexpr int page_size = 16384;
|
constexpr int page_size = 16384;
|
||||||
|
|
||||||
|
// Any allocations smaller than this will try to use the small pool
|
||||||
|
constexpr int small_block_size = 8;
|
||||||
|
|
||||||
|
// The small pool size in bytes. This should be a multiple of the host page
|
||||||
|
// size and small_block_size.
|
||||||
|
constexpr int small_pool_size = 4 * page_size;
|
||||||
|
|
||||||
|
SmallSizePool::SmallSizePool() {
|
||||||
|
auto num_blocks = small_pool_size / small_block_size;
|
||||||
|
buffer_ = new Block[num_blocks];
|
||||||
|
|
||||||
|
next_free_ = buffer_;
|
||||||
|
|
||||||
|
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
|
||||||
|
CHECK_CUDA_ERROR(
|
||||||
|
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
|
||||||
|
|
||||||
|
auto curr = next_free_;
|
||||||
|
for (size_t i = 1; i < num_blocks; ++i) {
|
||||||
|
curr->next = buffer_ + i;
|
||||||
|
curr = curr->next;
|
||||||
|
}
|
||||||
|
curr->next = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallSizePool::~SmallSizePool() {
|
||||||
|
CHECK_CUDA_ERROR(cudaFree(data_));
|
||||||
|
delete[] buffer_;
|
||||||
|
}
|
||||||
|
|
||||||
|
CudaBuffer* SmallSizePool::malloc() {
|
||||||
|
if (next_free_ == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
Block* b = next_free_;
|
||||||
|
uint64_t i = next_free_ - buffer_;
|
||||||
|
next_free_ = next_free_->next;
|
||||||
|
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
|
||||||
|
b->buf.size = small_block_size;
|
||||||
|
return &b->buf;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SmallSizePool::free(CudaBuffer* buf) {
|
||||||
|
auto b = reinterpret_cast<Block*>(buf);
|
||||||
|
b->next = next_free_;
|
||||||
|
next_free_ = b;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SmallSizePool::in_pool(CudaBuffer* buf) {
|
||||||
|
constexpr int num_blocks = (small_pool_size / small_block_size);
|
||||||
|
auto b = reinterpret_cast<Block*>(buf);
|
||||||
|
int64_t block_num = b - buffer_;
|
||||||
|
return block_num >= 0 && block_num < num_blocks;
|
||||||
|
}
|
||||||
|
|
||||||
CudaAllocator::CudaAllocator()
|
CudaAllocator::CudaAllocator()
|
||||||
: buffer_cache_(
|
: buffer_cache_(
|
||||||
page_size,
|
page_size,
|
||||||
[](CudaBuffer* buf) { return buf->size; },
|
[](CudaBuffer* buf) { return buf->size; },
|
||||||
[this](CudaBuffer* buf) {
|
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||||
cuda_free(buf->data);
|
|
||||||
delete buf;
|
|
||||||
}) {
|
|
||||||
// TODO: Set memory limit for multi-device.
|
// TODO: Set memory limit for multi-device.
|
||||||
size_t free, total;
|
size_t free, total;
|
||||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||||
@@ -36,7 +87,9 @@ Buffer CudaAllocator::malloc(size_t size) {
|
|||||||
// Find available buffer from cache.
|
// Find available buffer from cache.
|
||||||
auto orig_size = size;
|
auto orig_size = size;
|
||||||
std::unique_lock lock(mutex_);
|
std::unique_lock lock(mutex_);
|
||||||
if (size < page_size) {
|
if (size <= small_block_size) {
|
||||||
|
size = 8;
|
||||||
|
} else if (size < page_size) {
|
||||||
size = next_power_of_2(size);
|
size = next_power_of_2(size);
|
||||||
} else {
|
} else {
|
||||||
size = page_size * ((size + page_size - 1) / page_size);
|
size = page_size * ((size + page_size - 1) / page_size);
|
||||||
@@ -44,19 +97,25 @@ Buffer CudaAllocator::malloc(size_t size) {
|
|||||||
|
|
||||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
// If we have a lot of memory pressure try to reclaim memory from the cache.
|
||||||
// try to reclaim memory from the cache.
|
int64_t mem_to_free =
|
||||||
size_t mem_required = get_active_memory() + get_cache_memory() + size;
|
get_active_memory() + get_cache_memory() + size - memory_limit_;
|
||||||
if (mem_required >= memory_limit_) {
|
if (mem_to_free > 0) {
|
||||||
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
|
buffer_cache_.release_cached_buffers(mem_to_free);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try the scalar pool first
|
||||||
|
if (size <= small_block_size) {
|
||||||
|
buf = scalar_pool_.malloc();
|
||||||
|
}
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
buf = new CudaBuffer{nullptr, size};
|
if (!buf) {
|
||||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
buf = new CudaBuffer{nullptr, size};
|
||||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||||
throw std::runtime_error(fmt::format(
|
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
throw std::runtime_error(fmt::format(
|
||||||
|
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
lock.lock();
|
lock.lock();
|
||||||
}
|
}
|
||||||
@@ -67,7 +126,6 @@ Buffer CudaAllocator::malloc(size_t size) {
|
|||||||
if (get_cache_memory() > max_pool_size_) {
|
if (get_cache_memory() > max_pool_size_) {
|
||||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||||
}
|
}
|
||||||
|
|
||||||
return Buffer{buf};
|
return Buffer{buf};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,9 +140,7 @@ void CudaAllocator::free(Buffer buffer) {
|
|||||||
if (get_cache_memory() < max_pool_size_) {
|
if (get_cache_memory() < max_pool_size_) {
|
||||||
buffer_cache_.recycle_to_cache(buf);
|
buffer_cache_.recycle_to_cache(buf);
|
||||||
} else {
|
} else {
|
||||||
lock.unlock();
|
cuda_free(buf);
|
||||||
cuda_free(buf->data);
|
|
||||||
delete buf;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,27 +152,14 @@ size_t CudaAllocator::size(Buffer buffer) const {
|
|||||||
return buf->size;
|
return buf->size;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudaAllocator::register_this_thread() {
|
// This must be called with mutex_ aquired
|
||||||
std::lock_guard lock(worker_mutex_);
|
void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
||||||
allowed_threads_.insert(std::this_thread::get_id());
|
if (scalar_pool_.in_pool(buf)) {
|
||||||
}
|
scalar_pool_.free(buf);
|
||||||
|
} else {
|
||||||
void CudaAllocator::cuda_free(void* buf) {
|
cudaFree(buf->data);
|
||||||
// If cuda_free() is called from a unregistered thread, reschedule the call to
|
delete buf;
|
||||||
// worker.
|
|
||||||
{
|
|
||||||
std::lock_guard lock(worker_mutex_);
|
|
||||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
|
||||||
if (!worker_) {
|
|
||||||
worker_.reset(new Worker);
|
|
||||||
}
|
|
||||||
worker_->add_task([this, buf]() { this->cuda_free(buf); });
|
|
||||||
worker_->end_batch();
|
|
||||||
worker_->commit();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
cudaFree(buf);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t CudaAllocator::get_active_memory() const {
|
size_t CudaAllocator::get_active_memory() const {
|
||||||
|
|||||||
@@ -7,13 +7,10 @@
|
|||||||
|
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <thread>
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
class Worker;
|
|
||||||
|
|
||||||
using allocator::Buffer;
|
using allocator::Buffer;
|
||||||
|
|
||||||
// Stores cuda-managed unified memory.
|
// Stores cuda-managed unified memory.
|
||||||
@@ -22,21 +19,35 @@ struct CudaBuffer {
|
|||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class SmallSizePool {
|
||||||
|
private:
|
||||||
|
union Block {
|
||||||
|
Block* next;
|
||||||
|
CudaBuffer buf;
|
||||||
|
};
|
||||||
|
|
||||||
|
Block* buffer_{nullptr};
|
||||||
|
void* data_{nullptr};
|
||||||
|
Block* next_free_{nullptr};
|
||||||
|
|
||||||
|
public:
|
||||||
|
SmallSizePool();
|
||||||
|
~SmallSizePool();
|
||||||
|
|
||||||
|
SmallSizePool(const SmallSizePool&) = delete;
|
||||||
|
SmallSizePool& operator=(const SmallSizePool&) = delete;
|
||||||
|
|
||||||
|
CudaBuffer* malloc();
|
||||||
|
void free(CudaBuffer* buf);
|
||||||
|
bool in_pool(CudaBuffer* buf);
|
||||||
|
};
|
||||||
|
|
||||||
class CudaAllocator : public allocator::Allocator {
|
class CudaAllocator : public allocator::Allocator {
|
||||||
public:
|
public:
|
||||||
Buffer malloc(size_t size) override;
|
Buffer malloc(size_t size) override;
|
||||||
void free(Buffer buffer) override;
|
void free(Buffer buffer) override;
|
||||||
size_t size(Buffer buffer) const override;
|
size_t size(Buffer buffer) const override;
|
||||||
|
|
||||||
// Register current thread as safe to free buffers.
|
|
||||||
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
|
|
||||||
// that may be waited by gpu stream (for example cpu stream threads), freeing
|
|
||||||
// buffers there would result in dead lock.
|
|
||||||
void register_this_thread();
|
|
||||||
|
|
||||||
// Call cudaFree in the safe thread.
|
|
||||||
void cuda_free(void* buf);
|
|
||||||
|
|
||||||
size_t get_active_memory() const;
|
size_t get_active_memory() const;
|
||||||
size_t get_peak_memory() const;
|
size_t get_peak_memory() const;
|
||||||
void reset_peak_memory();
|
void reset_peak_memory();
|
||||||
@@ -47,19 +58,18 @@ class CudaAllocator : public allocator::Allocator {
|
|||||||
void clear_cache();
|
void clear_cache();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void cuda_free(CudaBuffer* buf);
|
||||||
|
|
||||||
CudaAllocator();
|
CudaAllocator();
|
||||||
friend CudaAllocator& allocator();
|
friend CudaAllocator& allocator();
|
||||||
|
|
||||||
std::mutex worker_mutex_;
|
|
||||||
std::unique_ptr<Worker> worker_;
|
|
||||||
std::set<std::thread::id> allowed_threads_;
|
|
||||||
|
|
||||||
std::mutex mutex_;
|
std::mutex mutex_;
|
||||||
size_t memory_limit_;
|
size_t memory_limit_;
|
||||||
size_t max_pool_size_;
|
size_t max_pool_size_;
|
||||||
BufferCache<CudaBuffer> buffer_cache_;
|
BufferCache<CudaBuffer> buffer_cache_;
|
||||||
size_t active_memory_{0};
|
size_t active_memory_{0};
|
||||||
size_t peak_memory_{0};
|
size_t peak_memory_{0};
|
||||||
|
SmallSizePool scalar_pool_;
|
||||||
};
|
};
|
||||||
|
|
||||||
CudaAllocator& allocator();
|
CudaAllocator& allocator();
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@@ -115,7 +115,7 @@ __global__ void arg_reduce_general(
|
|||||||
T vals[N_READS];
|
T vals[N_READS];
|
||||||
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
||||||
cub::LoadDirectBlocked(
|
cub::LoadDirectBlocked(
|
||||||
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
|
tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
|
||||||
best = op.reduce_many(best, vals, tid * N_READS);
|
best = op.reduce_many(best, vals, tid * N_READS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ __global__ void binary_g(
|
|||||||
int ndim) {
|
int ndim) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
if (index < size) {
|
||||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
auto [a_idx, b_idx] = elem_to_loc(
|
||||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ __global__ void binary_two_g(
|
|||||||
int ndim) {
|
int ndim) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
if (index < size) {
|
||||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
auto [a_idx, b_idx] = elem_to_loc(
|
||||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||||
out_a[index] = out[0];
|
out_a[index] = out[0];
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ __global__ void copy_gg(
|
|||||||
int ndim) {
|
int ndim) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
if (index < size) {
|
||||||
auto [idx_in, idx_out] = elem_to_loc_4d(
|
auto [idx_in, idx_out] = elem_to_loc(
|
||||||
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
||||||
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
|
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ __global__ void copy_gg_dynamic(
|
|||||||
const int64_t* offset_out) {
|
const int64_t* offset_out) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
if (index < size) {
|
||||||
auto [idx_in, idx_out] = elem_to_loc_4d(
|
auto [idx_in, idx_out] = elem_to_loc(
|
||||||
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
||||||
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
|
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ __global__ void copy_g(
|
|||||||
int ndim) {
|
int ndim) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
if (index < size) {
|
||||||
IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
|
IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
|
||||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -306,7 +306,6 @@ void CommandEncoder::commit() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Put completion handlers in a batch.
|
// Put completion handlers in a batch.
|
||||||
worker_.end_batch();
|
|
||||||
worker_.commit(stream_);
|
worker_.commit(stream_);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,7 +314,6 @@ void CommandEncoder::synchronize() {
|
|||||||
auto p = std::make_shared<std::promise<void>>();
|
auto p = std::make_shared<std::promise<void>>();
|
||||||
std::future<void> f = p->get_future();
|
std::future<void> f = p->get_future();
|
||||||
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||||
worker_.end_batch();
|
|
||||||
commit();
|
commit();
|
||||||
f.wait();
|
f.wait();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,6 +49,20 @@ store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
|||||||
to[offset] = vec;
|
to[offset] = vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper for accessing strided data.
|
||||||
|
template <typename T>
|
||||||
|
struct StridedIterator {
|
||||||
|
T it;
|
||||||
|
int64_t stride;
|
||||||
|
|
||||||
|
__host__ __device__ StridedIterator(T it, int64_t stride)
|
||||||
|
: it(it), stride(stride) {}
|
||||||
|
|
||||||
|
__host__ __device__ auto operator[](int i) const {
|
||||||
|
return it[i * stride];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Type limits utils
|
// Type limits utils
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@@ -204,20 +218,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
|||||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optimized version when ndim is larger than 4.
|
|
||||||
template <typename IdxT = int64_t>
|
template <typename IdxT = int64_t>
|
||||||
inline __host__ __device__ IdxT
|
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc(
|
||||||
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
|
||||||
IdxT loc = 0;
|
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
|
||||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
|
||||||
elem /= shape[i];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename IdxT = int64_t>
|
|
||||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
|
||||||
IdxT elem,
|
IdxT elem,
|
||||||
const int* shape,
|
const int* shape,
|
||||||
const int64_t* a_strides,
|
const int64_t* a_strides,
|
||||||
@@ -235,7 +237,7 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename IdxT = int64_t>
|
template <typename IdxT = int64_t>
|
||||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc(
|
||||||
IdxT elem,
|
IdxT elem,
|
||||||
const int* shape,
|
const int* shape,
|
||||||
const int64_t* a_strides,
|
const int64_t* a_strides,
|
||||||
|
|||||||
@@ -19,8 +19,6 @@ void new_stream(Stream s) {
|
|||||||
cudaFree(nullptr);
|
cudaFree(nullptr);
|
||||||
// Ensure the static stream objects get created.
|
// Ensure the static stream objects get created.
|
||||||
cu::get_command_encoder(s);
|
cu::get_command_encoder(s);
|
||||||
// The main thread is safe to free buffers.
|
|
||||||
cu::allocator().register_this_thread();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void eval(array& arr) {
|
void eval(array& arr) {
|
||||||
|
|||||||
@@ -110,24 +110,26 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
|||||||
event_signal(ac, value);
|
event_signal(ac, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
|
||||||
|
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
SharedEvent::SharedEvent() {
|
SharedEvent::SharedEvent() {
|
||||||
// Allocate cuda::atomic on managed memory.
|
buf_ = std::shared_ptr<Buffer>(
|
||||||
Atomic* ac;
|
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
|
||||||
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
|
allocator().free(*ptr);
|
||||||
new (ac) Atomic(0);
|
delete ptr;
|
||||||
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
|
});
|
||||||
ptr->~Atomic();
|
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
|
||||||
allocator().cuda_free(ptr);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::wait(uint64_t value) {
|
void SharedEvent::wait(uint64_t value) {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::wait");
|
nvtx3::scoped_range r("cu::SharedEvent::wait");
|
||||||
event_wait(ac_.get(), value);
|
event_wait(to_atomic(buf_), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
|
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
|
||||||
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::wait(Stream s, uint64_t value) {
|
void SharedEvent::wait(Stream s, uint64_t value) {
|
||||||
@@ -138,17 +140,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
|
|||||||
auto& encoder = get_command_encoder(s);
|
auto& encoder = get_command_encoder(s);
|
||||||
encoder.commit();
|
encoder.commit();
|
||||||
wait(encoder.stream(), value);
|
wait(encoder.stream(), value);
|
||||||
encoder.add_completed_handler([ac = ac_]() {});
|
encoder.add_completed_handler([buf = buf_]() {});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::signal(uint64_t value) {
|
void SharedEvent::signal(uint64_t value) {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
||||||
event_signal(ac_.get(), value);
|
event_signal(to_atomic(buf_), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
||||||
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::signal(Stream s, uint64_t value) {
|
void SharedEvent::signal(Stream s, uint64_t value) {
|
||||||
@@ -162,18 +164,18 @@ void SharedEvent::signal(Stream s, uint64_t value) {
|
|||||||
auto& encoder = get_command_encoder(s);
|
auto& encoder = get_command_encoder(s);
|
||||||
encoder.commit();
|
encoder.commit();
|
||||||
signal(encoder.stream(), value);
|
signal(encoder.stream(), value);
|
||||||
encoder.add_completed_handler([ac = ac_]() {});
|
encoder.add_completed_handler([buf = buf_]() {});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SharedEvent::is_signaled(uint64_t value) const {
|
bool SharedEvent::is_signaled(uint64_t value) const {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
||||||
return ac_->load() >= value;
|
return to_atomic(buf_)->load() >= value;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t SharedEvent::value() const {
|
uint64_t SharedEvent::value() const {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::value");
|
nvtx3::scoped_range r("cu::SharedEvent::value");
|
||||||
return ac_->load();
|
return to_atomic(buf_)->load();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace cu
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
@@ -55,12 +56,8 @@ class SharedEvent {
|
|||||||
bool is_signaled(uint64_t value) const;
|
bool is_signaled(uint64_t value) const;
|
||||||
uint64_t value() const;
|
uint64_t value() const;
|
||||||
|
|
||||||
const std::shared_ptr<Atomic>& atomic() const {
|
|
||||||
return ac_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Atomic> ac_;
|
std::shared_ptr<mlx::core::allocator::Buffer> buf_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
|||||||
147
mlx/backend/cuda/gemv.cu
Normal file
147
mlx/backend/cuda/gemv.cu
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/gemv.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <cooperative_groups/reduce.h>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
static constexpr int n_per_thread = 4;
|
||||||
|
static constexpr int rows_per_block = 8;
|
||||||
|
|
||||||
|
template <typename T, int rows_per_block, int n_per_thread>
|
||||||
|
__device__ void
|
||||||
|
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
|
|
||||||
|
auto g_idx = block.group_index();
|
||||||
|
auto t_idx = block.thread_index();
|
||||||
|
int row = g_idx.x * rows_per_block + t_idx.y;
|
||||||
|
|
||||||
|
if (row < rows) {
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (int col = n_per_thread * warp.thread_rank(); col < cols;
|
||||||
|
col += (WARP_SIZE * n_per_thread)) {
|
||||||
|
auto local_mat = load_vector<n_per_thread>(mat + row * cols + col, 0);
|
||||||
|
auto local_vec = load_vector<n_per_thread>(vec + col, 0);
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < n_per_thread; ++j) {
|
||||||
|
sum += static_cast<float>(local_mat.val[j]) *
|
||||||
|
static_cast<float>(local_vec.val[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sum = cg::reduce(warp, sum, cg::plus<float>{});
|
||||||
|
if (warp.thread_rank() == 0) {
|
||||||
|
out[row] = static_cast<T>(sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int rows_per_block, int n_per_thread>
|
||||||
|
__global__ void
|
||||||
|
gemv_single(const T* mat, const T* vec, T* out, int rows, int cols) {
|
||||||
|
gemv_impl<T, rows_per_block, n_per_thread>(mat, vec, out, rows, cols);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int rows_per_block, int n_per_thread>
|
||||||
|
__global__ void gemv_batched(
|
||||||
|
const T* mat,
|
||||||
|
const T* vec,
|
||||||
|
T* out,
|
||||||
|
int rows,
|
||||||
|
int cols,
|
||||||
|
const __grid_constant__ Shape batch_shape,
|
||||||
|
const __grid_constant__ Strides mat_batch_strides,
|
||||||
|
const __grid_constant__ Strides vec_batch_strides,
|
||||||
|
int batch_ndim) {
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto batch_idx = block.group_index().y;
|
||||||
|
auto [vec_offset, mat_offset] = elem_to_loc(
|
||||||
|
batch_idx,
|
||||||
|
batch_shape.data(),
|
||||||
|
vec_batch_strides.data(),
|
||||||
|
mat_batch_strides.data(),
|
||||||
|
batch_ndim);
|
||||||
|
gemv_impl<T, rows_per_block, n_per_thread>(
|
||||||
|
mat + mat_offset, vec + vec_offset, out + batch_idx * rows, rows, cols);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
|
||||||
|
return K % (WARP_SIZE * n_per_thread) == 0 &&
|
||||||
|
((M == 1 && b_transposed) || (N == 1 && !a_transposed));
|
||||||
|
}
|
||||||
|
|
||||||
|
void gemv(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
uint32_t batch_count,
|
||||||
|
const mlx::core::Shape& batch_shape,
|
||||||
|
const mlx::core::Strides& a_batch_strides,
|
||||||
|
const mlx::core::Strides& b_batch_strides,
|
||||||
|
CommandEncoder& encoder) {
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
|
||||||
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
dim3 block_dims{WARP_SIZE, rows_per_block};
|
||||||
|
const DataType* mat;
|
||||||
|
const DataType* vec;
|
||||||
|
int rows;
|
||||||
|
int cols = K;
|
||||||
|
auto mat_strides = const_param(a_batch_strides);
|
||||||
|
auto vec_strides = const_param(b_batch_strides);
|
||||||
|
|
||||||
|
if (M == 1) {
|
||||||
|
mat = b.data<DataType>();
|
||||||
|
vec = a.data<DataType>();
|
||||||
|
rows = N;
|
||||||
|
std::swap(mat_strides, vec_strides);
|
||||||
|
} else {
|
||||||
|
mat = a.data<DataType>();
|
||||||
|
vec = b.data<DataType>();
|
||||||
|
rows = M;
|
||||||
|
}
|
||||||
|
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
|
||||||
|
if (batch_count == 1) {
|
||||||
|
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread>;
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
num_blocks_x,
|
||||||
|
block_dims,
|
||||||
|
mat,
|
||||||
|
vec,
|
||||||
|
out.data<DataType>(),
|
||||||
|
rows,
|
||||||
|
cols);
|
||||||
|
} else {
|
||||||
|
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread>;
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
dim3{num_blocks_x, batch_count},
|
||||||
|
block_dims,
|
||||||
|
mat,
|
||||||
|
vec,
|
||||||
|
out.data<DataType>(),
|
||||||
|
rows,
|
||||||
|
cols,
|
||||||
|
const_param(batch_shape),
|
||||||
|
mat_strides,
|
||||||
|
vec_strides,
|
||||||
|
batch_shape.size());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
||||||
24
mlx/backend/cuda/gemv.h
Normal file
24
mlx/backend/cuda/gemv.h
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed);
|
||||||
|
|
||||||
|
void gemv(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
uint32_t batch_count,
|
||||||
|
const mlx::core::Shape& batch_shape,
|
||||||
|
const mlx::core::Strides& a_batch_strides,
|
||||||
|
const mlx::core::Strides& b_batch_strides,
|
||||||
|
CommandEncoder& encoder);
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
||||||
@@ -1,121 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <thrust/iterator/iterator_adaptor.h>
|
|
||||||
#include <cuda/std/utility>
|
|
||||||
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
|
||||||
|
|
||||||
// Iterating non-contiguous array.
|
|
||||||
template <typename Iterator, typename IdxT = int64_t>
|
|
||||||
class general_iterator
|
|
||||||
: public thrust::
|
|
||||||
iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
|
|
||||||
public:
|
|
||||||
using super_t =
|
|
||||||
thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
|
|
||||||
|
|
||||||
using reference = typename super_t::reference;
|
|
||||||
using difference_type = typename super_t::difference_type;
|
|
||||||
|
|
||||||
__host__ __device__ general_iterator(
|
|
||||||
Iterator it,
|
|
||||||
IdxT index,
|
|
||||||
int ndim,
|
|
||||||
Shape shape,
|
|
||||||
Strides strides)
|
|
||||||
: super_t(it),
|
|
||||||
index_(index),
|
|
||||||
ndim_(ndim),
|
|
||||||
shape_(cuda::std::move(shape)),
|
|
||||||
strides_(cuda::std::move(strides)) {}
|
|
||||||
|
|
||||||
__host__ __device__ IdxT index() const {
|
|
||||||
return index_;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ const Shape& shape() const {
|
|
||||||
return shape_;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ const Strides& strides() const {
|
|
||||||
return strides_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class thrust::iterator_core_access;
|
|
||||||
|
|
||||||
__host__ __device__ bool equal(const general_iterator& other) const {
|
|
||||||
return this->base() == other.base() && this->index() == other.index();
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ void advance(difference_type n) {
|
|
||||||
this->index_ += n;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ void increment() {
|
|
||||||
this->index_ += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ void decrement() {
|
|
||||||
this->index_ -= 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ difference_type
|
|
||||||
distance_to(const general_iterator& other) const {
|
|
||||||
_CCCL_ASSERT(
|
|
||||||
this->base() == other.base(),
|
|
||||||
"Underlying iterator must point to same base iterator");
|
|
||||||
return other.index() - this->index();
|
|
||||||
}
|
|
||||||
|
|
||||||
// The dereference is device-only to avoid accidental running in host.
|
|
||||||
__device__ typename super_t::reference dereference() const {
|
|
||||||
IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
|
|
||||||
return *(this->base() + offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
IdxT index_;
|
|
||||||
int ndim_;
|
|
||||||
Shape shape_;
|
|
||||||
Strides strides_;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename IdxT, typename Iterator>
|
|
||||||
__host__ __device__ auto make_general_iterator(
|
|
||||||
Iterator it,
|
|
||||||
IdxT index,
|
|
||||||
int ndim,
|
|
||||||
Shape shape,
|
|
||||||
Strides strides) {
|
|
||||||
return general_iterator<Iterator, IdxT>(
|
|
||||||
it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename IdxT, typename Iterator>
|
|
||||||
auto make_general_iterator(
|
|
||||||
Iterator it,
|
|
||||||
const std::vector<int32_t>& shape,
|
|
||||||
const std::vector<int64_t>& strides) {
|
|
||||||
return make_general_iterator<IdxT>(
|
|
||||||
it, 0, shape.size(), const_param(shape), const_param(strides));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename IdxT, typename Iterator>
|
|
||||||
auto make_general_iterators(
|
|
||||||
Iterator it,
|
|
||||||
IdxT size,
|
|
||||||
const std::vector<int32_t>& shape,
|
|
||||||
const std::vector<int64_t>& strides) {
|
|
||||||
auto ndim = shape.size();
|
|
||||||
auto shape_arg = const_param(shape);
|
|
||||||
auto strides_arg = const_param(strides);
|
|
||||||
return std::make_pair(
|
|
||||||
make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
|
|
||||||
make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <thrust/iterator/iterator_adaptor.h>
|
|
||||||
#include <thrust/iterator/iterator_facade.h>
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
|
||||||
|
|
||||||
// RandomAccessIterator for strided access to array entries.
|
|
||||||
template <typename Iterator, typename Stride = int64_t>
|
|
||||||
class strided_iterator
|
|
||||||
: public thrust::
|
|
||||||
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
|
|
||||||
public:
|
|
||||||
using super_t =
|
|
||||||
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
|
|
||||||
|
|
||||||
using reference = typename super_t::reference;
|
|
||||||
using difference_type = typename super_t::difference_type;
|
|
||||||
|
|
||||||
__host__ __device__ strided_iterator(Iterator it, Stride stride)
|
|
||||||
: super_t(it), stride_(stride) {}
|
|
||||||
|
|
||||||
__host__ __device__ Stride stride() const {
|
|
||||||
return stride_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class thrust::iterator_core_access;
|
|
||||||
|
|
||||||
__host__ __device__ bool equal(const strided_iterator& other) const {
|
|
||||||
return this->base() == other.base();
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ void advance(difference_type n) {
|
|
||||||
this->base_reference() += n * stride_;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ void increment() {
|
|
||||||
this->base_reference() += stride_;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ void decrement() {
|
|
||||||
this->base_reference() -= stride_;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ difference_type
|
|
||||||
distance_to(const strided_iterator& other) const {
|
|
||||||
const difference_type dist = other.base() - this->base();
|
|
||||||
_CCCL_ASSERT(
|
|
||||||
dist % stride() == 0,
|
|
||||||
"Underlying iterator difference must be divisible by the stride");
|
|
||||||
return dist / stride();
|
|
||||||
}
|
|
||||||
|
|
||||||
Stride stride_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
@@ -105,8 +104,8 @@ __global__ void layer_norm(
|
|||||||
T wn[N_READS];
|
T wn[N_READS];
|
||||||
T bn[N_READS];
|
T bn[N_READS];
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(b, b_stride), bn, axis_size);
|
||||||
for (int i = 0; i < N_READS; ++i) {
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
|
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||||
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
|
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
|
||||||
@@ -162,7 +161,7 @@ __global__ void layer_norm_vjp(
|
|||||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
||||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
float t = static_cast<float>(xn[i]) - mean;
|
float t = static_cast<float>(xn[i]) - mean;
|
||||||
float wi = wn[i];
|
float wi = wn[i];
|
||||||
@@ -185,7 +184,7 @@ __global__ void layer_norm_vjp(
|
|||||||
T gn[N_READS];
|
T gn[N_READS];
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
|
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||||
float wi = wn[i];
|
float wi = wn[i];
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlx/backend/common/matmul.h"
|
#include "mlx/backend/common/matmul.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/gemv.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@@ -353,6 +354,22 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
batch_shape = {1};
|
batch_shape = {1};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
|
||||||
|
cu::gemv(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
batch_count,
|
||||||
|
batch_shape,
|
||||||
|
a_batch_strides,
|
||||||
|
b_batch_strides,
|
||||||
|
encoder);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Invoke cublasLt
|
// Invoke cublasLt
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
@@ -89,7 +88,7 @@ __global__ void rms_norm(
|
|||||||
T xn[N_READS];
|
T xn[N_READS];
|
||||||
T wn[N_READS];
|
T wn[N_READS];
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; ++i) {
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
float norm = static_cast<float>(xn[i]) * normalizer;
|
float norm = static_cast<float>(xn[i]) * normalizer;
|
||||||
xn[i] = wn[i] * static_cast<T>(norm);
|
xn[i] = wn[i] * static_cast<T>(norm);
|
||||||
@@ -132,7 +131,7 @@ __global__ void rms_norm_vjp(
|
|||||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
|
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
|
||||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
float t = static_cast<float>(xn[i]);
|
float t = static_cast<float>(xn[i]);
|
||||||
float wi = wn[i];
|
float wi = wn[i];
|
||||||
@@ -154,7 +153,7 @@ __global__ void rms_norm_vjp(
|
|||||||
T gn[N_READS];
|
T gn[N_READS];
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
float xi = xn[i];
|
float xi = xn[i];
|
||||||
float wi = wn[i];
|
float wi = wn[i];
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ __global__ void ternary_g(
|
|||||||
int ndim) {
|
int ndim) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
if (index < size) {
|
||||||
auto [a_idx, b_idx, c_idx] = elem_to_loc_4d(
|
auto [a_idx, b_idx, c_idx] = elem_to_loc(
|
||||||
index,
|
index,
|
||||||
shape.data(),
|
shape.data(),
|
||||||
a_strides.data(),
|
a_strides.data(),
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
#include "mlx/backend/common/unary.h"
|
#include "mlx/backend/common/unary.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@@ -48,7 +47,7 @@ __global__ void unary_g(
|
|||||||
int ndim) {
|
int ndim) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
if (index < size) {
|
||||||
auto idx = elem_to_loc_4d(index, shape.data(), strides.data(), ndim);
|
auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim);
|
||||||
out[index] = Op{}(in[idx]);
|
out[index] = Op{}(in[idx]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/worker.h"
|
#include "mlx/backend/cuda/worker.h"
|
||||||
#include "mlx/backend/cuda/allocator.h"
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
@@ -12,10 +11,10 @@ Worker::Worker()
|
|||||||
|
|
||||||
Worker::~Worker() {
|
Worker::~Worker() {
|
||||||
{
|
{
|
||||||
std::lock_guard lock(worker_mutex_);
|
std::lock_guard lock(mtx_);
|
||||||
stop_ = true;
|
stop_ = true;
|
||||||
}
|
}
|
||||||
worker_event_.signal(batch_ + 1);
|
cond_.notify_one();
|
||||||
worker_.join();
|
worker_.join();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -23,53 +22,41 @@ void Worker::add_task(std::function<void()> task) {
|
|||||||
pending_tasks_.push_back(std::move(task));
|
pending_tasks_.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Worker::consume_in_this_thread() {
|
void Worker::signal(void* data) {
|
||||||
for (auto& task : pending_tasks_) {
|
auto w = static_cast<Worker*>(data);
|
||||||
task();
|
|
||||||
}
|
|
||||||
pending_tasks_.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
void Worker::end_batch() {
|
|
||||||
batch_++;
|
|
||||||
{
|
{
|
||||||
std::lock_guard lock(worker_mutex_);
|
std::lock_guard lock(w->mtx_);
|
||||||
worker_tasks_[batch_] = std::move(pending_tasks_);
|
w->signaled_batch_++;
|
||||||
}
|
}
|
||||||
uncommited_batches_++;
|
w->cond_.notify_one();
|
||||||
}
|
|
||||||
|
|
||||||
void Worker::commit() {
|
|
||||||
if (uncommited_batches_ == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
uncommited_batches_ = 0;
|
|
||||||
worker_event_.signal(batch_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Worker::commit(cudaStream_t stream) {
|
void Worker::commit(cudaStream_t stream) {
|
||||||
if (uncommited_batches_ == 0) {
|
// Move pending tasks into tasks
|
||||||
|
if (pending_tasks_.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
uncommited_batches_ = 0;
|
{
|
||||||
// Signal the |worker_event_| in |signal_stream_| after the kernels in
|
std::lock_guard lock(mtx_);
|
||||||
// |stream_| finish running.
|
// Move pending tasks into ready tasks
|
||||||
|
worker_tasks_[++committed_batch_] = std::move(pending_tasks_);
|
||||||
|
}
|
||||||
signal_event_.record(stream);
|
signal_event_.record(stream);
|
||||||
signal_event_.wait(signal_stream_);
|
signal_event_.wait(signal_stream_);
|
||||||
worker_event_.signal(signal_stream_, batch_);
|
cudaLaunchHostFunc(signal_stream_, signal, this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Worker::thread_fn() {
|
void Worker::thread_fn() {
|
||||||
// The worker thread is safe to free buffers.
|
|
||||||
allocator().register_this_thread();
|
|
||||||
|
|
||||||
while (!stop_) {
|
while (!stop_) {
|
||||||
uint64_t batch = worker_event_.value();
|
uint64_t current_batch = 0;
|
||||||
Tasks tasks;
|
Tasks tasks;
|
||||||
{
|
{
|
||||||
std::lock_guard lock(worker_mutex_);
|
std::unique_lock<std::mutex> lk(mtx_);
|
||||||
// Move tasks in signaled batches.
|
cond_.wait(lk, [this, ¤t_batch] {
|
||||||
auto end = worker_tasks_.upper_bound(batch);
|
return this->signaled_batch_ > current_batch || this->stop_;
|
||||||
|
});
|
||||||
|
current_batch = signaled_batch_;
|
||||||
|
auto end = worker_tasks_.upper_bound(current_batch);
|
||||||
for (auto it = worker_tasks_.begin(); it != end; ++it) {
|
for (auto it = worker_tasks_.begin(); it != end; ++it) {
|
||||||
if (tasks.empty()) {
|
if (tasks.empty()) {
|
||||||
tasks = std::move(it->second);
|
tasks = std::move(it->second);
|
||||||
@@ -85,7 +72,6 @@ void Worker::thread_fn() {
|
|||||||
auto task = std::move(tasks[i]);
|
auto task = std::move(tasks[i]);
|
||||||
task();
|
task();
|
||||||
}
|
}
|
||||||
worker_event_.wait(batch + 1);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "mlx/backend/cuda/event.h"
|
#include "mlx/backend/cuda/event.h"
|
||||||
#include "mlx/backend/cuda/utils.h"
|
#include "mlx/backend/cuda/utils.h"
|
||||||
|
|
||||||
|
#include <condition_variable>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
@@ -24,38 +25,24 @@ class Worker {
|
|||||||
// Add a pending |task| that will run when consumed or commited.
|
// Add a pending |task| that will run when consumed or commited.
|
||||||
void add_task(std::function<void()> task);
|
void add_task(std::function<void()> task);
|
||||||
|
|
||||||
// Run pending tasks immediately in current thread.
|
|
||||||
void consume_in_this_thread();
|
|
||||||
|
|
||||||
// Put pending tasks in a batch.
|
|
||||||
void end_batch();
|
|
||||||
|
|
||||||
// Inform worker thread to run current batches now.
|
|
||||||
void commit();
|
|
||||||
|
|
||||||
// Inform worker thread to run current batches after kernels in |stream|
|
// Inform worker thread to run current batches after kernels in |stream|
|
||||||
// finish running.
|
// finish running.
|
||||||
void commit(cudaStream_t stream);
|
void commit(cudaStream_t stream);
|
||||||
|
|
||||||
// Return how many batches have been added but not committed yet.
|
|
||||||
size_t uncommited_batches() const {
|
|
||||||
return uncommited_batches_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void thread_fn();
|
static void signal(void*);
|
||||||
|
|
||||||
uint64_t batch_{0};
|
void thread_fn();
|
||||||
size_t uncommited_batches_{0};
|
std::mutex mtx_;
|
||||||
|
std::condition_variable cond_;
|
||||||
|
|
||||||
|
uint64_t committed_batch_{0};
|
||||||
|
uint64_t signaled_batch_{0};
|
||||||
|
|
||||||
// Cuda stream and event for signaling kernel completion.
|
// Cuda stream and event for signaling kernel completion.
|
||||||
CudaStream signal_stream_;
|
CudaStream signal_stream_;
|
||||||
CudaEvent signal_event_;
|
CudaEvent signal_event_;
|
||||||
|
|
||||||
// Worker thread.
|
|
||||||
SharedEvent worker_event_;
|
|
||||||
std::thread worker_;
|
|
||||||
std::mutex worker_mutex_;
|
|
||||||
bool stop_{false};
|
bool stop_{false};
|
||||||
|
|
||||||
// Tasks are put in |pending_tasks_| first, and then moved to
|
// Tasks are put in |pending_tasks_| first, and then moved to
|
||||||
@@ -63,6 +50,7 @@ class Worker {
|
|||||||
using Tasks = std::vector<std::function<void()>>;
|
using Tasks = std::vector<std::function<void()>>;
|
||||||
Tasks pending_tasks_;
|
Tasks pending_tasks_;
|
||||||
std::map<uint64_t, Tasks> worker_tasks_;
|
std::map<uint64_t, Tasks> worker_tasks_;
|
||||||
|
std::thread worker_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
|||||||
@@ -128,8 +128,7 @@ Buffer MetalAllocator::malloc(size_t size) {
|
|||||||
|
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
|
|
||||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
// If we have a lot of memory pressure try to reclaim memory from the cache
|
||||||
// try to reclaim memory from the cache
|
|
||||||
if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
|
if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
|
||||||
num_resources_ -=
|
num_resources_ -=
|
||||||
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
|
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ Event::Event(Stream stream) : stream_(stream) {
|
|||||||
auto p = metal::new_scoped_memory_pool();
|
auto p = metal::new_scoped_memory_pool();
|
||||||
event_ = std::shared_ptr<void>(
|
event_ = std::shared_ptr<void>(
|
||||||
metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
|
metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
|
||||||
|
if (event_ == nullptr) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[Event::Event] Failed to create Metal shared event.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Event::wait() {
|
void Event::wait() {
|
||||||
|
|||||||
@@ -708,7 +708,10 @@ array scaled_dot_product_attention(
|
|||||||
}
|
}
|
||||||
if (mask.dtype() == bool_) {
|
if (mask.dtype() == bool_) {
|
||||||
scores = where(
|
scores = where(
|
||||||
mask, scores, array(finfo(scores.dtype()).min, scores.dtype()));
|
mask,
|
||||||
|
scores,
|
||||||
|
array(-std::numeric_limits<float>::infinity(), scores.dtype()),
|
||||||
|
s);
|
||||||
} else {
|
} else {
|
||||||
scores = add(scores, mask, s);
|
scores = add(scores, mask, s);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1271,19 +1271,6 @@ std::vector<array> Convolution::vjp(
|
|||||||
has_neg_padding |= (pd < 0);
|
has_neg_padding |= (pd < 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto padding_lo_ = std::vector<int>(padding_lo);
|
|
||||||
auto padding_hi_ = std::vector<int>(padding_hi);
|
|
||||||
|
|
||||||
// Use negative padding on the gradient output
|
|
||||||
if (has_neg_padding) {
|
|
||||||
for (auto& p : padding_lo_) {
|
|
||||||
p = std::max(0, p);
|
|
||||||
}
|
|
||||||
for (auto& p : padding_hi_) {
|
|
||||||
p = std::max(0, p);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto wt_trans = group_transpose(wt, 0, 1, -1);
|
auto wt_trans = group_transpose(wt, 0, 1, -1);
|
||||||
auto grad = conv_general(
|
auto grad = conv_general(
|
||||||
/* const array& input = */ cotan,
|
/* const array& input = */ cotan,
|
||||||
@@ -1305,12 +1292,9 @@ std::vector<array> Convolution::vjp(
|
|||||||
for (int i = 0; i < grad.ndim() - 2; i++) {
|
for (int i = 0; i < grad.ndim() - 2; i++) {
|
||||||
if (padding_lo[i] < 0) {
|
if (padding_lo[i] < 0) {
|
||||||
starts[i + 1] -= padding_lo[i];
|
starts[i + 1] -= padding_lo[i];
|
||||||
padding_lo[i] = 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (padding_hi[i] < 0) {
|
if (padding_hi[i] < 0) {
|
||||||
stops[i + 1] += padding_hi[i];
|
stops[i + 1] += padding_hi[i];
|
||||||
padding_hi[i] = 0;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -72,7 +72,12 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
|
|
||||||
// Stream events for synchronization after eval
|
// Stream events for synchronization after eval
|
||||||
std::unordered_map<uint32_t, Event> events;
|
std::unordered_map<uint32_t, Event> events;
|
||||||
events.emplace(stream.index, Event{stream});
|
{
|
||||||
|
auto e = Event{stream};
|
||||||
|
e.set_value(1);
|
||||||
|
synchronizer.attach_event(e);
|
||||||
|
events.emplace(stream.index, std::move(e));
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// Record the degree of each input
|
// Record the degree of each input
|
||||||
@@ -184,21 +189,26 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unordered_set<int> open_streams;
|
||||||
|
|
||||||
while (!tape.empty()) {
|
while (!tape.empty()) {
|
||||||
auto arr = std::move(tape.back());
|
auto arr = std::move(tape.back());
|
||||||
tape.pop_back();
|
tape.pop_back();
|
||||||
|
|
||||||
auto stream = arr.primitive().stream();
|
auto stream = arr.primitive().stream();
|
||||||
|
open_streams.insert(stream.index);
|
||||||
|
|
||||||
// Lookup corresponding event
|
if (async) {
|
||||||
auto e = events.find(stream.index);
|
// Lookup corresponding event
|
||||||
if (e == events.end()) {
|
auto e = events.find(stream.index);
|
||||||
e = events.emplace(stream.index, Event{stream}).first;
|
if (e == events.end()) {
|
||||||
}
|
e = events.emplace(stream.index, Event{stream}).first;
|
||||||
e->second.set_value(1);
|
}
|
||||||
arr.attach_event(e->second);
|
e->second.set_value(1);
|
||||||
for (auto& s : arr.siblings()) {
|
arr.attach_event(e->second);
|
||||||
s.attach_event(e->second);
|
for (auto& s : arr.siblings()) {
|
||||||
|
s.attach_event(e->second);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& in : arr.inputs()) {
|
for (auto& in : arr.inputs()) {
|
||||||
@@ -227,9 +237,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
(get_active_memory() > get_memory_limit() &&
|
(get_active_memory() > get_memory_limit() &&
|
||||||
scheduler::n_active_tasks() > 0)) {
|
scheduler::n_active_tasks() > 0)) {
|
||||||
// Commit any open streams
|
// Commit any open streams
|
||||||
for (auto& [_, e] : events) {
|
for (auto i : open_streams) {
|
||||||
if (e.stream().device == Device::gpu) {
|
auto s = get_stream(i);
|
||||||
gpu::finalize(e.stream());
|
if (s.device == Device::gpu) {
|
||||||
|
gpu::finalize(s);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
scheduler::wait_for_one();
|
scheduler::wait_for_one();
|
||||||
@@ -263,9 +274,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Signal the event in its stream
|
// Signal the event in its stream
|
||||||
for (auto& [_, e] : events) {
|
for (auto i : open_streams) {
|
||||||
auto s = e.stream();
|
auto s = get_stream(i);
|
||||||
e.signal(s);
|
if (auto e = events.find(i); e != events.end()) {
|
||||||
|
e->second.signal(s);
|
||||||
|
}
|
||||||
if (s.device == Device::gpu) {
|
if (s.device == Device::gpu) {
|
||||||
gpu::finalize(s);
|
gpu::finalize(s);
|
||||||
}
|
}
|
||||||
@@ -302,7 +315,7 @@ void eval(std::vector<array> outputs) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
eval_impl(std::move(outputs), false).event().wait();
|
eval_impl(std::move(outputs), false).wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<array>> vjp(
|
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
auditwheel repair dist/* \
|
auditwheel repair dist/* \
|
||||||
--plat manylinux_2_39_x86_64 \
|
--plat manylinux_2_35_x86_64 \
|
||||||
--exclude libcublas* \
|
--exclude libcublas* \
|
||||||
--exclude libnvrtc* \
|
--exclude libnvrtc* \
|
||||||
|
--exclude libcuda* \
|
||||||
-w wheel_tmp
|
-w wheel_tmp
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4022,8 +4022,9 @@ void init_ops(nb::module_& m) {
|
|||||||
Args:
|
Args:
|
||||||
file (file, str): File in which the array is saved.
|
file (file, str): File in which the array is saved.
|
||||||
arrays (dict(str, array)): The dictionary of names to arrays to
|
arrays (dict(str, array)): The dictionary of names to arrays to
|
||||||
be saved. metadata (dict(str, str), optional): The dictionary of
|
be saved.
|
||||||
metadata to be saved.
|
metadata (dict(str, str), optional): The dictionary of
|
||||||
|
metadata to be saved.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"save_gguf",
|
"save_gguf",
|
||||||
@@ -4258,7 +4259,7 @@ void init_ops(nb::module_& m) {
|
|||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
|
|
||||||
w_i = s \hat{w_i} - \beta
|
w_i = s \hat{w_i} + \beta
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
w (array): Matrix to be quantized
|
w (array): Matrix to be quantized
|
||||||
|
|||||||
@@ -398,6 +398,18 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
|
def test_fully_masked(self):
|
||||||
|
Lkv = 8
|
||||||
|
mask = mx.array(False)
|
||||||
|
for D in [4, 128]:
|
||||||
|
for Lq in [1, 8]:
|
||||||
|
q = mx.random.normal(shape=(1, 4, Lq, D))
|
||||||
|
k = mx.random.normal(shape=(1, 4, Lkv, D))
|
||||||
|
v = mx.random.normal(shape=(1, 4, Lkv, D))
|
||||||
|
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1)
|
||||||
|
self.assertTrue(mx.all(mx.isnan(out)))
|
||||||
|
|
||||||
def test_fast_sdpa_few_query(self):
|
def test_fast_sdpa_few_query(self):
|
||||||
D = 64
|
D = 64
|
||||||
L = 43
|
L = 43
|
||||||
|
|||||||
46
setup.py
46
setup.py
@@ -9,7 +9,7 @@ from functools import partial
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from subprocess import run
|
from subprocess import run
|
||||||
|
|
||||||
from setuptools import Command, Extension, setup
|
from setuptools import Command, Extension, find_namespace_packages, setup
|
||||||
from setuptools.command.bdist_wheel import bdist_wheel
|
from setuptools.command.bdist_wheel import bdist_wheel
|
||||||
from setuptools.command.build_ext import build_ext
|
from setuptools.command.build_ext import build_ext
|
||||||
|
|
||||||
@@ -166,6 +166,10 @@ class GenerateStubs(Command):
|
|||||||
# Run again without recursive to specify output file name
|
# Run again without recursive to specify output file name
|
||||||
subprocess.run(["rm", f"{out_path}/mlx.pyi"])
|
subprocess.run(["rm", f"{out_path}/mlx.pyi"])
|
||||||
subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
|
subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
|
||||||
|
# mx.bool_ gets filtered by nanobind because of the trailing
|
||||||
|
# underscore, add it manually:
|
||||||
|
with open(f"{out_path}/__init__.pyi", "a") as fid:
|
||||||
|
fid.write("\nbool_: Dtype = ...")
|
||||||
|
|
||||||
|
|
||||||
class MLXBdistWheel(bdist_wheel):
|
class MLXBdistWheel(bdist_wheel):
|
||||||
@@ -184,19 +188,23 @@ with open(Path(__file__).parent / "README.md", encoding="utf-8") as f:
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
package_dir = {"": "python"}
|
package_dir = {"": "python"}
|
||||||
packages = [
|
packages = find_namespace_packages(
|
||||||
"mlx",
|
where="python",
|
||||||
"mlx.nn",
|
exclude=[
|
||||||
"mlx.nn.layers",
|
"src",
|
||||||
"mlx.optimizers",
|
"tests",
|
||||||
]
|
"scripts",
|
||||||
|
"mlx.lib",
|
||||||
|
"mlx.include",
|
||||||
|
"mlx.share",
|
||||||
|
"mlx.share.**",
|
||||||
|
"mlx.include.**",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
build_macos = platform.system() == "Darwin"
|
build_macos = platform.system() == "Darwin"
|
||||||
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
|
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
|
||||||
|
|
||||||
install_requires = []
|
|
||||||
if build_cuda:
|
|
||||||
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
|
|
||||||
version = get_version()
|
version = get_version()
|
||||||
|
|
||||||
_setup = partial(
|
_setup = partial(
|
||||||
@@ -221,7 +229,7 @@ if __name__ == "__main__":
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
|
package_data = {"mlx.core": ["*.pyi"]}
|
||||||
|
|
||||||
extras = {
|
extras = {
|
||||||
"dev": [
|
"dev": [
|
||||||
@@ -239,6 +247,7 @@ if __name__ == "__main__":
|
|||||||
"mlx.distributed_config = mlx.distributed_run:distributed_config",
|
"mlx.distributed_config = mlx.distributed_run:distributed_config",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
install_requires = []
|
||||||
|
|
||||||
# Release builds for PyPi are in two stages.
|
# Release builds for PyPi are in two stages.
|
||||||
# Each stage should be run from a clean build:
|
# Each stage should be run from a clean build:
|
||||||
@@ -258,11 +267,11 @@ if __name__ == "__main__":
|
|||||||
# - Package name is back-end specific, e.g mlx-metal
|
# - Package name is back-end specific, e.g mlx-metal
|
||||||
if build_stage != 2:
|
if build_stage != 2:
|
||||||
if build_stage == 1:
|
if build_stage == 1:
|
||||||
if build_macos:
|
install_requires.append(
|
||||||
install_requires += [f"mlx-metal=={version}"]
|
f'mlx-metal=={version}; platform_system == "Darwin"'
|
||||||
else:
|
)
|
||||||
extras["cuda"] = [f"mlx-cuda=={version}"]
|
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
|
||||||
extras["cpu"] = [f"mlx-cpu=={version}"]
|
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
|
||||||
|
|
||||||
_setup(
|
_setup(
|
||||||
name="mlx",
|
name="mlx",
|
||||||
@@ -277,9 +286,14 @@ if __name__ == "__main__":
|
|||||||
name = "mlx-metal"
|
name = "mlx-metal"
|
||||||
elif build_cuda:
|
elif build_cuda:
|
||||||
name = "mlx-cuda"
|
name = "mlx-cuda"
|
||||||
|
install_requires += [
|
||||||
|
"nvidia-cublas-cu12==12.9.*",
|
||||||
|
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
name = "mlx-cpu"
|
name = "mlx-cpu"
|
||||||
_setup(
|
_setup(
|
||||||
name=name,
|
name=name,
|
||||||
packages=["mlx"],
|
packages=["mlx"],
|
||||||
|
install_requires=install_requires,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user