Compare commits

...

15 Commits

Author SHA1 Message Date
Awni Hannun
70dc336785 Test on cuda 12.2 and 12.9 (#2413) 2025-07-24 06:06:15 -07:00
Awni Hannun
4e504039f5 [Metal] Release metal events (#2412)
* release metal events

* fix

* fix
2025-07-23 19:53:42 -07:00
Awni Hannun
d1f4d291e8 Fix uv install and add dev release (#2411)
* fix uv install and add dev release

* fix docstring

* pin cuda deps

* cuda release on cpu-only machine
2025-07-23 16:54:19 -07:00
Awni Hannun
e1840853ce full row mask in sdpa consistently gives nan (#2406) 2025-07-23 16:37:03 -07:00
Cheng
0f5ce173da [CUDA] --compress-mode requires CUDA 12.8 (#2407) 2025-07-23 06:11:11 -07:00
Cheng
588854195f Remove unused code in Convolution::vjp (#2408) 2025-07-23 06:11:00 -07:00
Fangjun Kuang
28d068bce6 Fix an error in the comment for mx.dequantize (#2409) 2025-07-23 06:10:50 -07:00
Awni Hannun
d107d8d495 add cuda gemv (#2400) 2025-07-22 08:24:13 -07:00
Awni Hannun
1e496ddb82 [CUDA] Simplify allocator (#2392)
* simplify allocator and fix race with small pool

* Don't use shared event in worker

* use cuda buffer in small pool

* comment

* comment
2025-07-22 08:24:01 -07:00
Awni Hannun
74eccbf3fa use size option in binary (#2399) 2025-07-22 07:00:53 -07:00
Awni Hannun
08638223ca Fix including stubs in wheel (#2398)
* fix including stubs in wheel

* fix bool_
2025-07-22 06:30:17 -07:00
Cheng
56cc858af9 Add contiguous_copy_cpu util for copying array (#2397) 2025-07-21 07:30:35 -07:00
Cheng
f55c4ed1d6 Remove thrust iterators (#2396) 2025-07-21 07:30:27 -07:00
Awni Hannun
93d70419e7 [CUDA] speedup handling scalars (#2389)
* speedup scalars in cuda

* comment
2025-07-18 21:47:31 -07:00
Awni Hannun
63f663d9c6 fix cuda manylinux version to match others (#2388) 2025-07-18 21:02:16 -07:00
43 changed files with 599 additions and 433 deletions


@@ -7,6 +7,9 @@ parameters:
nightly_build: nightly_build:
type: boolean type: boolean
default: false default: false
test_release:
type: boolean
default: false
jobs: jobs:
build_documentation: build_documentation:
@@ -200,8 +203,12 @@ jobs:
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
cuda_build_and_test: cuda_build_and_test:
parameters:
image_date:
type: string
default: "2023.11.1"
machine: machine:
image: linux-cuda-12:2023.11.1 image: "linux-cuda-12:<< parameters.image_date >>"
resource_class: gpu.nvidia.small.gen2 resource_class: gpu.nvidia.small.gen2
steps: steps:
- checkout - checkout
@@ -366,22 +373,27 @@ jobs:
type: string type: string
default: "" default: ""
machine: machine:
image: linux-cuda-12:default image: ubuntu-2204:current
resource_class: gpu.nvidia.small.gen2 resource_class: large
steps: steps:
- checkout - checkout
- run: - run:
name: Build wheel name: Build wheel
command: | command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update sudo apt-get update
sudo apt install cuda-toolkit-12-9
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip sudo apt-get install zip
python -m venv env
source env/bin/activate
pip install auditwheel pip install auditwheel
pip install patchelf pip install patchelf
pip install build pip install build
pip install twine pip install twine
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
<< parameters.build_env >> MLX_BUILD_STAGE=2 \ << parameters.build_env >> MLX_BUILD_STAGE=2 \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w python -m build -w
@@ -392,7 +404,6 @@ jobs:
- run: - run:
name: Upload package name: Upload package
command: | command: |
source env/bin/activate
twine upload wheelhouse/*.whl twine upload wheelhouse/*.whl
- store_artifacts: - store_artifacts:
path: wheelhouse/ path: wheelhouse/
@@ -405,19 +416,24 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$" pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >> value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs: jobs:
- mac_build_and_test: - mac_build_and_test:
matrix: matrix:
parameters: parameters:
macosx_deployment_target: ["13.5", "14.0"] macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test - linux_build_and_test
- cuda_build_and_test - cuda_build_and_test:
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
- build_documentation - build_documentation
build_pypi_release: build_pypi_release:
when: when:
and: and:
- not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs: jobs:
- build_release: - build_release:
filters: filters:
@@ -601,3 +617,87 @@ workflows:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
- build_cuda_release - build_cuda_release
build_dev_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:
parameters:
build_env: ["DEV_RELEASE=1"]


@@ -377,4 +377,10 @@ void copy_cpu_inplace(
}); });
} }
array contiguous_copy_cpu(const array& arr, Stream stream) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, stream);
return arr_copy;
}
} // namespace mlx::core } // namespace mlx::core


@@ -30,4 +30,7 @@ void copy_cpu_inplace(
const std::optional<array>& dynamic_i_offset = std::nullopt, const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt); const std::optional<array>& dynamic_o_offset = std::nullopt);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_cpu(const array& arr, Stream stream);
} // namespace mlx::core } // namespace mlx::core


@@ -13,9 +13,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
if (arr.flags().row_contiguous) { if (arr.flags().row_contiguous) {
return {arr, false}; return {arr, false};
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); return {contiguous_copy_cpu(arr, stream), true};
copy_cpu(arr, arr_copy, CopyType::General, stream);
return {arr_copy, true};
} }
}; };
@@ -34,8 +32,7 @@ void AllReduce::eval_cpu(
} }
return in; return in;
} else { } else {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); array arr_copy = contiguous_copy_cpu(in, s);
copy_cpu(in, arr_copy, CopyType::General, s);
out.copy_shared_buffer(arr_copy); out.copy_shared_buffer(arr_copy);
return arr_copy; return arr_copy;
} }


@@ -87,8 +87,7 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x; return x;
} else { } else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_cpu(x, s);
copy_cpu(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy); encoder.add_temporary(x_copy);
return x_copy; return x_copy;
} }


@@ -136,9 +136,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
return std::make_tuple(true, sty, arr, false); return std::make_tuple(true, sty, arr, false);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, s);
int64_t stx = arr.shape(-1); int64_t stx = arr.shape(-1);
array arr_copy = contiguous_copy_cpu(arr, s);
return std::make_tuple(false, stx, arr_copy, true); return std::make_tuple(false, stx, arr_copy, true);
} }
}; };


@@ -712,9 +712,7 @@ void fast::AffineQuantize::eval_cpu(
if (arr.flags().row_contiguous) { if (arr.flags().row_contiguous) {
return std::make_pair(arr, false); return std::make_pair(arr, false);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); return std::make_pair(contiguous_copy_cpu(arr, s), true);
copy_cpu(arr, arr_copy, CopyType::General, s);
return std::make_pair(arr_copy, true);
} }
}; };


@@ -250,10 +250,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
// Ensure contiguity // Ensure contiguity
auto in = inputs[0]; auto in = inputs[0];
if (!in.flags().row_contiguous) { if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); in = contiguous_copy_cpu(in, stream());
copy_cpu(in, arr_copy, CopyType::General, stream()); encoder.add_temporary(in);
in = arr_copy;
encoder.add_temporary(arr_copy);
} }
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));


@@ -131,8 +131,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
return x; return x;
} else { } else {
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_cpu(x, s);
copy_cpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }


@@ -20,6 +20,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemv.cu
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
@@ -87,6 +88,13 @@ endif()
target_compile_options( target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>") mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
# Use stronger binaries compression. This feature was introduced in CUDA 12.8
# and requires drivers released after CUDA 12.4.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()
# Compute capability 7 is required for synchronization between CPU/GPU with # Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain. # managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES set(MLX_CUDA_ARCHITECTURES


@@ -2,7 +2,6 @@
#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h" #include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/utils.h" #include "mlx/utils.h"
#include <cuda_runtime.h> #include <cuda_runtime.h>
@@ -17,14 +16,66 @@ namespace cu {
constexpr int page_size = 16384; constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
SmallSizePool::SmallSizePool() {
auto num_blocks = small_pool_size / small_block_size;
buffer_ = new Block[num_blocks];
next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) {
curr->next = buffer_ + i;
curr = curr->next;
}
curr->next = nullptr;
}
SmallSizePool::~SmallSizePool() {
CHECK_CUDA_ERROR(cudaFree(data_));
delete[] buffer_;
}
CudaBuffer* SmallSizePool::malloc() {
if (next_free_ == nullptr) {
return nullptr;
}
Block* b = next_free_;
uint64_t i = next_free_ - buffer_;
next_free_ = next_free_->next;
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
b->buf.size = small_block_size;
return &b->buf;
}
void SmallSizePool::free(CudaBuffer* buf) {
auto b = reinterpret_cast<Block*>(buf);
b->next = next_free_;
next_free_ = b;
}
bool SmallSizePool::in_pool(CudaBuffer* buf) {
constexpr int num_blocks = (small_pool_size / small_block_size);
auto b = reinterpret_cast<Block*>(buf);
int64_t block_num = b - buffer_;
return block_num >= 0 && block_num < num_blocks;
}
CudaAllocator::CudaAllocator() CudaAllocator::CudaAllocator()
: buffer_cache_( : buffer_cache_(
page_size, page_size,
[](CudaBuffer* buf) { return buf->size; }, [](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { [this](CudaBuffer* buf) { cuda_free(buf); }) {
cuda_free(buf->data);
delete buf;
}) {
// TODO: Set memory limit for multi-device. // TODO: Set memory limit for multi-device.
size_t free, total; size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
@@ -36,7 +87,9 @@ Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache. // Find available buffer from cache.
auto orig_size = size; auto orig_size = size;
std::unique_lock lock(mutex_); std::unique_lock lock(mutex_);
if (size < page_size) { if (size <= small_block_size) {
size = 8;
} else if (size < page_size) {
size = next_power_of_2(size); size = next_power_of_2(size);
} else { } else {
size = page_size * ((size + page_size - 1) / page_size); size = page_size * ((size + page_size - 1) / page_size);
@@ -44,19 +97,25 @@ Buffer CudaAllocator::malloc(size_t size) {
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size); CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) { if (!buf) {
// If we have a lot of memory pressure or are over the maximum cache size, // If we have a lot of memory pressure try to reclaim memory from the cache.
// try to reclaim memory from the cache. int64_t mem_to_free =
size_t mem_required = get_active_memory() + get_cache_memory() + size; get_active_memory() + get_cache_memory() + size - memory_limit_;
if (mem_required >= memory_limit_) { if (mem_to_free > 0) {
buffer_cache_.release_cached_buffers(mem_required - memory_limit_); buffer_cache_.release_cached_buffers(mem_to_free);
} }
// Try the scalar pool first
if (size <= small_block_size) {
buf = scalar_pool_.malloc();
}
lock.unlock(); lock.unlock();
buf = new CudaBuffer{nullptr, size}; if (!buf) {
cudaError_t err = cudaMallocManaged(&buf->data, size); buf = new CudaBuffer{nullptr, size};
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { cudaError_t err = cudaMallocManaged(&buf->data, size);
throw std::runtime_error(fmt::format( if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
"cudaMallocManaged failed: {}.", cudaGetErrorString(err))); throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
} }
lock.lock(); lock.lock();
} }
@@ -67,7 +126,6 @@ Buffer CudaAllocator::malloc(size_t size) {
if (get_cache_memory() > max_pool_size_) { if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
} }
return Buffer{buf}; return Buffer{buf};
} }
@@ -82,9 +140,7 @@ void CudaAllocator::free(Buffer buffer) {
if (get_cache_memory() < max_pool_size_) { if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf); buffer_cache_.recycle_to_cache(buf);
} else { } else {
lock.unlock(); cuda_free(buf);
cuda_free(buf->data);
delete buf;
} }
} }
@@ -96,27 +152,14 @@ size_t CudaAllocator::size(Buffer buffer) const {
return buf->size; return buf->size;
} }
void CudaAllocator::register_this_thread() { // This must be called with mutex_ aquired
std::lock_guard lock(worker_mutex_); void CudaAllocator::cuda_free(CudaBuffer* buf) {
allowed_threads_.insert(std::this_thread::get_id()); if (scalar_pool_.in_pool(buf)) {
} scalar_pool_.free(buf);
} else {
void CudaAllocator::cuda_free(void* buf) { cudaFree(buf->data);
// If cuda_free() is called from a unregistered thread, reschedule the call to delete buf;
// worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([this, buf]() { this->cuda_free(buf); });
worker_->end_batch();
worker_->commit();
return;
}
} }
cudaFree(buf);
} }
size_t CudaAllocator::get_active_memory() const { size_t CudaAllocator::get_active_memory() const {


@@ -7,13 +7,10 @@
#include <mutex> #include <mutex>
#include <set> #include <set>
#include <thread>
#include <utility> #include <utility>
namespace mlx::core::cu { namespace mlx::core::cu {
class Worker;
using allocator::Buffer; using allocator::Buffer;
// Stores cuda-managed unified memory. // Stores cuda-managed unified memory.
@@ -22,21 +19,35 @@ struct CudaBuffer {
size_t size; size_t size;
}; };
class SmallSizePool {
private:
union Block {
Block* next;
CudaBuffer buf;
};
Block* buffer_{nullptr};
void* data_{nullptr};
Block* next_free_{nullptr};
public:
SmallSizePool();
~SmallSizePool();
SmallSizePool(const SmallSizePool&) = delete;
SmallSizePool& operator=(const SmallSizePool&) = delete;
CudaBuffer* malloc();
void free(CudaBuffer* buf);
bool in_pool(CudaBuffer* buf);
};
class CudaAllocator : public allocator::Allocator { class CudaAllocator : public allocator::Allocator {
public: public:
Buffer malloc(size_t size) override; Buffer malloc(size_t size) override;
void free(Buffer buffer) override; void free(Buffer buffer) override;
size_t size(Buffer buffer) const override; size_t size(Buffer buffer) const override;
// Register current thread as safe to free buffers.
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
// that may be waited by gpu stream (for example cpu stream threads), freeing
// buffers there would result in dead lock.
void register_this_thread();
// Call cudaFree in the safe thread.
void cuda_free(void* buf);
size_t get_active_memory() const; size_t get_active_memory() const;
size_t get_peak_memory() const; size_t get_peak_memory() const;
void reset_peak_memory(); void reset_peak_memory();
@@ -47,19 +58,18 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache(); void clear_cache();
private: private:
void cuda_free(CudaBuffer* buf);
CudaAllocator(); CudaAllocator();
friend CudaAllocator& allocator(); friend CudaAllocator& allocator();
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_; std::mutex mutex_;
size_t memory_limit_; size_t memory_limit_;
size_t max_pool_size_; size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_; BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0}; size_t active_memory_{0};
size_t peak_memory_{0}; size_t peak_memory_{0};
SmallSizePool scalar_pool_;
}; };
CudaAllocator& allocator(); CudaAllocator& allocator();
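
The new SmallSizePool above serves allocations of at most 8 bytes (scalars) from one managed-memory arena and threads an intrusive free list through the unused blocks: while a block is free, its storage holds the pointer to the next free block; once handed out, the same storage holds the CudaBuffer. Below is a minimal host-only sketch of that free-list idea; FixedSizePool and PayloadBuffer are made-up names, and plain heap memory stands in for cudaMallocManaged. It illustrates the data structure, it is not MLX code.

// Host-only sketch of a fixed-size block pool with an intrusive free list,
// mirroring the idea behind SmallSizePool.
#include <cassert>
#include <cstddef>
#include <cstdio>

struct PayloadBuffer { // stand-in for CudaBuffer
  void* data;
  size_t size;
};

class FixedSizePool {
  union Block {
    Block* next;       // valid while the block sits on the free list
    PayloadBuffer buf; // valid while the block is handed out
  };

 public:
  FixedSizePool(size_t num_blocks, size_t block_size)
      : num_blocks_(num_blocks), block_size_(block_size) {
    blocks_ = new Block[num_blocks_];
    data_ = new char[num_blocks_ * block_size_];
    next_free_ = blocks_;
    // Thread the free list through every block.
    for (size_t i = 0; i + 1 < num_blocks_; ++i) {
      blocks_[i].next = &blocks_[i + 1];
    }
    blocks_[num_blocks_ - 1].next = nullptr;
  }
  ~FixedSizePool() {
    delete[] data_;
    delete[] blocks_;
  }

  PayloadBuffer* malloc() {
    if (next_free_ == nullptr) {
      return nullptr; // pool exhausted; caller falls back to the general path
    }
    Block* b = next_free_;
    next_free_ = next_free_->next;
    size_t i = b - blocks_;
    b->buf.data = data_ + i * block_size_;
    b->buf.size = block_size_;
    return &b->buf;
  }

  void free(PayloadBuffer* buf) {
    Block* b = reinterpret_cast<Block*>(buf); // buf shares the union's address
    b->next = next_free_;
    next_free_ = b;
  }

  bool in_pool(PayloadBuffer* buf) const {
    Block* b = reinterpret_cast<Block*>(buf);
    return b >= blocks_ && b < blocks_ + num_blocks_;
  }

 private:
  size_t num_blocks_;
  size_t block_size_;
  Block* blocks_{nullptr};
  char* data_{nullptr};
  Block* next_free_{nullptr};
};

int main() {
  FixedSizePool pool(4, 8);
  PayloadBuffer* a = pool.malloc();
  PayloadBuffer* b = pool.malloc();
  assert(pool.in_pool(a) && pool.in_pool(b));
  pool.free(b);
  pool.free(a);
  std::printf("pool ok\n");
  return 0;
}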


@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -115,7 +115,7 @@ __global__ void arg_reduce_general(
T vals[N_READS]; T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().x; auto tid = r * BLOCK_DIM + block.thread_index().x;
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init); tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS); best = op.reduce_many(best, vals, tid * N_READS);
} }


@@ -128,7 +128,7 @@ __global__ void binary_g(
int ndim) { int ndim) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d( auto [a_idx, b_idx] = elem_to_loc(
index, shape.data(), a_strides.data(), b_strides.data(), ndim); index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]); out[index] = Op{}(a[a_idx], b[b_idx]);
} }


@@ -160,7 +160,7 @@ __global__ void binary_two_g(
int ndim) { int ndim) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d( auto [a_idx, b_idx] = elem_to_loc(
index, shape.data(), a_strides.data(), b_strides.data(), ndim); index, shape.data(), a_strides.data(), b_strides.data(), ndim);
auto out = Op{}(a[a_idx], b[b_idx]); auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0]; out_a[index] = out[0];


@@ -37,7 +37,7 @@ __global__ void copy_gg(
int ndim) { int ndim) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_4d( auto [idx_in, idx_out] = elem_to_loc(
index, shape.data(), strides_in.data(), strides_out.data(), ndim); index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out] = CastOp<In, Out>{}(in[idx_in]); out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
} }


@@ -41,7 +41,7 @@ __global__ void copy_gg_dynamic(
const int64_t* offset_out) { const int64_t* offset_out) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_4d( auto [idx_in, idx_out] = elem_to_loc(
index, shape.data(), strides_in.data(), strides_out.data(), ndim); index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]); out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
} }


@@ -34,7 +34,7 @@ __global__ void copy_g(
int ndim) { int ndim) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim); IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
out[index] = CastOp<In, Out>{}(in[idx_in]); out[index] = CastOp<In, Out>{}(in[idx_in]);
} }
} }


@@ -306,7 +306,6 @@ void CommandEncoder::commit() {
} }
// Put completion handlers in a batch. // Put completion handlers in a batch.
worker_.end_batch();
worker_.commit(stream_); worker_.commit(stream_);
} }
@@ -315,7 +314,6 @@ void CommandEncoder::synchronize() {
auto p = std::make_shared<std::promise<void>>(); auto p = std::make_shared<std::promise<void>>();
std::future<void> f = p->get_future(); std::future<void> f = p->get_future();
add_completed_handler([p = std::move(p)]() { p->set_value(); }); add_completed_handler([p = std::move(p)]() { p->set_value(); });
worker_.end_batch();
commit(); commit();
f.wait(); f.wait();
} }


@@ -49,6 +49,20 @@ store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
to[offset] = vec; to[offset] = vec;
} }
// Helper for accessing strided data.
template <typename T>
struct StridedIterator {
T it;
int64_t stride;
__host__ __device__ StridedIterator(T it, int64_t stride)
: it(it), stride(stride) {}
__host__ __device__ auto operator[](int i) const {
return it[i * stride];
}
};
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Type limits utils // Type limits utils
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@@ -204,20 +218,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
return cuda::std::make_tuple(a_loc, b_loc, c_loc); return cuda::std::make_tuple(a_loc, b_loc, c_loc);
} }
// Optimized version when ndim is larger than 4.
template <typename IdxT = int64_t> template <typename IdxT = int64_t>
inline __host__ __device__ IdxT inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc(
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
IdxT elem, IdxT elem,
const int* shape, const int* shape,
const int64_t* a_strides, const int64_t* a_strides,
@@ -235,7 +237,7 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
} }
template <typename IdxT = int64_t> template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d( inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc(
IdxT elem, IdxT elem,
const int* shape, const int* shape,
const int64_t* a_strides, const int64_t* a_strides,
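
The renamed elem_to_loc helpers above (previously elem_to_loc_4d) unravel a flat element index into per-array offsets by walking the shape from the innermost dimension outward and accumulating index-times-stride for each input. Below is a host-side sketch of that index arithmetic for two arrays sharing a logical shape; the std::vector-based signature and the broadcast example are assumptions for illustration, not the kernel code.

// Host-side sketch of the elem_to_loc index math: unravel a flat element
// index into offsets for two arrays with different strides.
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

std::pair<int64_t, int64_t> elem_to_loc(
    int64_t elem,
    const std::vector<int>& shape,
    const std::vector<int64_t>& a_strides,
    const std::vector<int64_t>& b_strides) {
  int64_t a_loc = 0;
  int64_t b_loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    int64_t dim_idx = elem % shape[i];
    a_loc += dim_idx * a_strides[i];
    b_loc += dim_idx * b_strides[i];
    elem /= shape[i];
  }
  return {a_loc, b_loc};
}

int main() {
  // Shape (2, 3): `a` is contiguous, `b` is broadcast along the first axis.
  std::vector<int> shape{2, 3};
  std::vector<int64_t> a_strides{3, 1};
  std::vector<int64_t> b_strides{0, 1};
  for (int64_t e = 0; e < 6; ++e) {
    auto [a, b] = elem_to_loc(e, shape, a_strides, b_strides);
    std::printf("elem %lld -> a %lld, b %lld\n",
                (long long)e, (long long)a, (long long)b);
  }
  return 0;
}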


@@ -19,8 +19,6 @@ void new_stream(Stream s) {
cudaFree(nullptr); cudaFree(nullptr);
// Ensure the static stream objects get created. // Ensure the static stream objects get created.
cu::get_command_encoder(s); cu::get_command_encoder(s);
// The main thread is safe to free buffers.
cu::allocator().register_this_thread();
} }
void eval(array& arr) { void eval(array& arr) {


@@ -110,24 +110,26 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value); event_signal(ac, value);
} }
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
}
SharedEvent::SharedEvent() { SharedEvent::SharedEvent() {
// Allocate cuda::atomic on managed memory. buf_ = std::shared_ptr<Buffer>(
Atomic* ac; new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic))); allocator().free(*ptr);
new (ac) Atomic(0); delete ptr;
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) { });
ptr->~Atomic(); *static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
allocator().cuda_free(ptr);
});
} }
void SharedEvent::wait(uint64_t value) { void SharedEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait"); nvtx3::scoped_range r("cu::SharedEvent::wait");
event_wait(ac_.get(), value); event_wait(to_atomic(buf_), value);
} }
void SharedEvent::wait(cudaStream_t stream, uint64_t value) { void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
} }
void SharedEvent::wait(Stream s, uint64_t value) { void SharedEvent::wait(Stream s, uint64_t value) {
@@ -138,17 +140,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
auto& encoder = get_command_encoder(s); auto& encoder = get_command_encoder(s);
encoder.commit(); encoder.commit();
wait(encoder.stream(), value); wait(encoder.stream(), value);
encoder.add_completed_handler([ac = ac_]() {}); encoder.add_completed_handler([buf = buf_]() {});
} }
} }
void SharedEvent::signal(uint64_t value) { void SharedEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal"); nvtx3::scoped_range r("cu::SharedEvent::signal");
event_signal(ac_.get(), value); event_signal(to_atomic(buf_), value);
} }
void SharedEvent::signal(cudaStream_t stream, uint64_t value) { void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
} }
void SharedEvent::signal(Stream s, uint64_t value) { void SharedEvent::signal(Stream s, uint64_t value) {
@@ -162,18 +164,18 @@ void SharedEvent::signal(Stream s, uint64_t value) {
auto& encoder = get_command_encoder(s); auto& encoder = get_command_encoder(s);
encoder.commit(); encoder.commit();
signal(encoder.stream(), value); signal(encoder.stream(), value);
encoder.add_completed_handler([ac = ac_]() {}); encoder.add_completed_handler([buf = buf_]() {});
} }
} }
bool SharedEvent::is_signaled(uint64_t value) const { bool SharedEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::SharedEvent::is_signaled"); nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
return ac_->load() >= value; return to_atomic(buf_)->load() >= value;
} }
uint64_t SharedEvent::value() const { uint64_t SharedEvent::value() const {
nvtx3::scoped_range r("cu::SharedEvent::value"); nvtx3::scoped_range r("cu::SharedEvent::value");
return ac_->load(); return to_atomic(buf_)->load();
} }
} // namespace cu } // namespace cu


@@ -2,6 +2,7 @@
#pragma once #pragma once
#include "mlx/allocator.h"
#include "mlx/stream.h" #include "mlx/stream.h"
#include <cuda_runtime.h> #include <cuda_runtime.h>
@@ -55,12 +56,8 @@ class SharedEvent {
bool is_signaled(uint64_t value) const; bool is_signaled(uint64_t value) const;
uint64_t value() const; uint64_t value() const;
const std::shared_ptr<Atomic>& atomic() const {
return ac_;
}
private: private:
std::shared_ptr<Atomic> ac_; std::shared_ptr<mlx::core::allocator::Buffer> buf_;
}; };
} // namespace mlx::core::cu } // namespace mlx::core::cu

mlx/backend/cuda/gemv.cu (new file, 147 lines)

@@ -0,0 +1,147 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/gemv.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace mlx::core::cu {
namespace cg = cooperative_groups;
static constexpr int n_per_thread = 4;
static constexpr int rows_per_block = 8;
template <typename T, int rows_per_block, int n_per_thread>
__device__ void
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
auto g_idx = block.group_index();
auto t_idx = block.thread_index();
int row = g_idx.x * rows_per_block + t_idx.y;
if (row < rows) {
float sum = 0.0f;
for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) {
auto local_mat = load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum += static_cast<float>(local_mat.val[j]) *
static_cast<float>(local_vec.val[j]);
}
}
sum = cg::reduce(warp, sum, cg::plus<float>{});
if (warp.thread_rank() == 0) {
out[row] = static_cast<T>(sum);
}
}
}
template <typename T, int rows_per_block, int n_per_thread>
__global__ void
gemv_single(const T* mat, const T* vec, T* out, int rows, int cols) {
gemv_impl<T, rows_per_block, n_per_thread>(mat, vec, out, rows, cols);
}
template <typename T, int rows_per_block, int n_per_thread>
__global__ void gemv_batched(
const T* mat,
const T* vec,
T* out,
int rows,
int cols,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides mat_batch_strides,
const __grid_constant__ Strides vec_batch_strides,
int batch_ndim) {
auto block = cg::this_thread_block();
auto batch_idx = block.group_index().y;
auto [vec_offset, mat_offset] = elem_to_loc(
batch_idx,
batch_shape.data(),
vec_batch_strides.data(),
mat_batch_strides.data(),
batch_ndim);
gemv_impl<T, rows_per_block, n_per_thread>(
mat + mat_offset, vec + vec_offset, out + batch_idx * rows, rows, cols);
}
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
return K % (WARP_SIZE * n_per_thread) == 0 &&
((M == 1 && b_transposed) || (N == 1 && !a_transposed));
}
void gemv(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
uint32_t batch_count,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
CommandEncoder& encoder) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dim3 block_dims{WARP_SIZE, rows_per_block};
const DataType* mat;
const DataType* vec;
int rows;
int cols = K;
auto mat_strides = const_param(a_batch_strides);
auto vec_strides = const_param(b_batch_strides);
if (M == 1) {
mat = b.data<DataType>();
vec = a.data<DataType>();
rows = N;
std::swap(mat_strides, vec_strides);
} else {
mat = a.data<DataType>();
vec = b.data<DataType>();
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
if (batch_count == 1) {
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread>;
encoder.add_kernel_node(
kernel,
num_blocks_x,
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols);
} else {
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread>;
encoder.add_kernel_node(
kernel,
dim3{num_blocks_x, batch_count},
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols,
const_param(batch_shape),
mat_strides,
vec_strides,
batch_shape.size());
}
});
}
} // namespace mlx::core::cu
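
In the gemv kernel above, each thread block covers rows_per_block (8) rows with one warp per row; every lane loads n_per_thread (4) consecutive elements at a time, strides by WARP_SIZE * n_per_thread, and the per-lane partial sums are combined with cg::reduce. Below is a sequential host-side model of that tiling; the warp is emulated with a loop and gemv_row is a made-up name, so this only shows the work partitioning, not the actual kernel.

// Host-side model of the gemv row tiling: lane-strided accumulation of
// n_per_thread elements per step, followed by a warp-wide sum.
#include <cstdio>
#include <vector>

constexpr int WARP_SIZE = 32;
constexpr int n_per_thread = 4;

float gemv_row(const float* mat_row, const float* vec, int cols) {
  float lane_sums[WARP_SIZE] = {};
  for (int lane = 0; lane < WARP_SIZE; ++lane) {
    for (int col = n_per_thread * lane; col < cols;
         col += WARP_SIZE * n_per_thread) {
      for (int j = 0; j < n_per_thread; ++j) {
        lane_sums[lane] += mat_row[col + j] * vec[col + j];
      }
    }
  }
  // In the kernel this is cg::reduce(warp, sum, cg::plus<float>{}).
  float sum = 0.0f;
  for (int lane = 0; lane < WARP_SIZE; ++lane) {
    sum += lane_sums[lane];
  }
  return sum;
}

int main() {
  // can_use_gemv requires K to be a multiple of WARP_SIZE * n_per_thread.
  int cols = WARP_SIZE * n_per_thread;
  std::vector<float> row(cols, 1.0f), vec(cols, 2.0f);
  std::printf("dot = %f (expect %f)\n",
              gemv_row(row.data(), vec.data(), cols), 2.0f * cols);
  return 0;
}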

mlx/backend/cuda/gemv.h (new file, 24 lines)

@@ -0,0 +1,24 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed);
void gemv(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
uint32_t batch_count,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
CommandEncoder& encoder);
} // namespace mlx::core::cu


@@ -1,121 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <cuda/std/utility>
#include "mlx/backend/cuda/kernel_utils.cuh"
namespace mlx::core::cu {
// Iterating non-contiguous array.
template <typename Iterator, typename IdxT = int64_t>
class general_iterator
: public thrust::
iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides)
: super_t(it),
index_(index),
ndim_(ndim),
shape_(cuda::std::move(shape)),
strides_(cuda::std::move(strides)) {}
__host__ __device__ IdxT index() const {
return index_;
}
__host__ __device__ const Shape& shape() const {
return shape_;
}
__host__ __device__ const Strides& strides() const {
return strides_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const general_iterator& other) const {
return this->base() == other.base() && this->index() == other.index();
}
__host__ __device__ void advance(difference_type n) {
this->index_ += n;
}
__host__ __device__ void increment() {
this->index_ += 1;
}
__host__ __device__ void decrement() {
this->index_ -= 1;
}
__host__ __device__ difference_type
distance_to(const general_iterator& other) const {
_CCCL_ASSERT(
this->base() == other.base(),
"Underlying iterator must point to same base iterator");
return other.index() - this->index();
}
// The dereference is device-only to avoid accidental running in host.
__device__ typename super_t::reference dereference() const {
IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
return *(this->base() + offset);
}
IdxT index_;
int ndim_;
Shape shape_;
Strides strides_;
};
template <typename IdxT, typename Iterator>
__host__ __device__ auto make_general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides) {
return general_iterator<Iterator, IdxT>(
it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterator(
Iterator it,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
return make_general_iterator<IdxT>(
it, 0, shape.size(), const_param(shape), const_param(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterators(
Iterator it,
IdxT size,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
auto ndim = shape.size();
auto shape_arg = const_param(shape);
auto strides_arg = const_param(strides);
return std::make_pair(
make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
}
} // namespace mlx::core::cu


@@ -1,60 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_facade.h>
namespace mlx::core::cu {
// RandomAccessIterator for strided access to array entries.
template <typename Iterator, typename Stride = int64_t>
class strided_iterator
: public thrust::
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ strided_iterator(Iterator it, Stride stride)
: super_t(it), stride_(stride) {}
__host__ __device__ Stride stride() const {
return stride_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const strided_iterator& other) const {
return this->base() == other.base();
}
__host__ __device__ void advance(difference_type n) {
this->base_reference() += n * stride_;
}
__host__ __device__ void increment() {
this->base_reference() += stride_;
}
__host__ __device__ void decrement() {
this->base_reference() -= stride_;
}
__host__ __device__ difference_type
distance_to(const strided_iterator& other) const {
const difference_type dist = other.base() - this->base();
_CCCL_ASSERT(
dist % stride() == 0,
"Underlying iterator difference must be divisible by the stride");
return dist / stride();
}
Stride stride_;
};
} // namespace mlx::core::cu


@@ -1,7 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
@@ -105,8 +104,8 @@ __global__ void layer_norm(
T wn[N_READS]; T wn[N_READS];
T bn[N_READS]; T bn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size); cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(b, b_stride), bn, axis_size);
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
float norm = (static_cast<float>(xn[i]) - mean) * normalizer; float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm) + bn[i]; xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
@@ -162,7 +161,7 @@ __global__ void layer_norm_vjp(
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, mean); cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]) - mean; float t = static_cast<float>(xn[i]) - mean;
float wi = wn[i]; float wi = wn[i];
@@ -185,7 +184,7 @@ __global__ void layer_norm_vjp(
T gn[N_READS]; T gn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size); cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float xi = (static_cast<float>(xn[i]) - mean) * normalizer; float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
float wi = wn[i]; float wi = wn[i];


@@ -2,6 +2,7 @@
#include "mlx/backend/common/matmul.h" #include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemv.h"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -353,6 +354,22 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape = {1}; batch_shape = {1};
} }
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt // Invoke cublasLt


@@ -1,7 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
@@ -89,7 +88,7 @@ __global__ void rms_norm(
T xn[N_READS]; T xn[N_READS];
T wn[N_READS]; T wn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size); cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
float norm = static_cast<float>(xn[i]) * normalizer; float norm = static_cast<float>(xn[i]) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm); xn[i] = wn[i] * static_cast<T>(norm);
@@ -132,7 +131,7 @@ __global__ void rms_norm_vjp(
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0)); cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]); float t = static_cast<float>(xn[i]);
float wi = wn[i]; float wi = wn[i];
@@ -154,7 +153,7 @@ __global__ void rms_norm_vjp(
T gn[N_READS]; T gn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size); cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float xi = xn[i]; float xi = xn[i];
float wi = wn[i]; float wi = wn[i];


@@ -76,7 +76,7 @@ __global__ void ternary_g(
int ndim) { int ndim) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
auto [a_idx, b_idx, c_idx] = elem_to_loc_4d( auto [a_idx, b_idx, c_idx] = elem_to_loc(
index, index,
shape.data(), shape.data(),
a_strides.data(), a_strides.data(),


@@ -3,7 +3,6 @@
#include "mlx/backend/common/unary.h" #include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/unary_ops.cuh" #include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -48,7 +47,7 @@ __global__ void unary_g(
int ndim) { int ndim) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
auto idx = elem_to_loc_4d(index, shape.data(), strides.data(), ndim); auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim);
out[index] = Op{}(in[idx]); out[index] = Op{}(in[idx]);
} }
} }


@@ -1,7 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/worker.h" #include "mlx/backend/cuda/worker.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
namespace mlx::core::cu { namespace mlx::core::cu {
@@ -12,10 +11,10 @@ Worker::Worker()
Worker::~Worker() { Worker::~Worker() {
{ {
std::lock_guard lock(worker_mutex_); std::lock_guard lock(mtx_);
stop_ = true; stop_ = true;
} }
worker_event_.signal(batch_ + 1); cond_.notify_one();
worker_.join(); worker_.join();
} }
@@ -23,53 +22,41 @@ void Worker::add_task(std::function<void()> task) {
pending_tasks_.push_back(std::move(task)); pending_tasks_.push_back(std::move(task));
} }
void Worker::consume_in_this_thread() { void Worker::signal(void* data) {
for (auto& task : pending_tasks_) { auto w = static_cast<Worker*>(data);
task();
}
pending_tasks_.clear();
}
void Worker::end_batch() {
batch_++;
{ {
std::lock_guard lock(worker_mutex_); std::lock_guard lock(w->mtx_);
worker_tasks_[batch_] = std::move(pending_tasks_); w->signaled_batch_++;
} }
uncommited_batches_++; w->cond_.notify_one();
}
void Worker::commit() {
if (uncommited_batches_ == 0) {
return;
}
uncommited_batches_ = 0;
worker_event_.signal(batch_);
} }
void Worker::commit(cudaStream_t stream) { void Worker::commit(cudaStream_t stream) {
if (uncommited_batches_ == 0) { // Move pending tasks into tasks
if (pending_tasks_.empty()) {
return; return;
} }
uncommited_batches_ = 0; {
// Signal the |worker_event_| in |signal_stream_| after the kernels in std::lock_guard lock(mtx_);
// |stream_| finish running. // Move pending tasks into ready tasks
worker_tasks_[++committed_batch_] = std::move(pending_tasks_);
}
signal_event_.record(stream); signal_event_.record(stream);
signal_event_.wait(signal_stream_); signal_event_.wait(signal_stream_);
worker_event_.signal(signal_stream_, batch_); cudaLaunchHostFunc(signal_stream_, signal, this);
} }
void Worker::thread_fn() { void Worker::thread_fn() {
// The worker thread is safe to free buffers.
allocator().register_this_thread();
while (!stop_) { while (!stop_) {
uint64_t batch = worker_event_.value(); uint64_t current_batch = 0;
Tasks tasks; Tasks tasks;
{ {
std::lock_guard lock(worker_mutex_); std::unique_lock<std::mutex> lk(mtx_);
// Move tasks in signaled batches. cond_.wait(lk, [this, &current_batch] {
auto end = worker_tasks_.upper_bound(batch); return this->signaled_batch_ > current_batch || this->stop_;
});
current_batch = signaled_batch_;
auto end = worker_tasks_.upper_bound(current_batch);
for (auto it = worker_tasks_.begin(); it != end; ++it) { for (auto it = worker_tasks_.begin(); it != end; ++it) {
if (tasks.empty()) { if (tasks.empty()) {
tasks = std::move(it->second); tasks = std::move(it->second);
@@ -85,7 +72,6 @@ void Worker::thread_fn() {
auto task = std::move(tasks[i]); auto task = std::move(tasks[i]);
task(); task();
} }
worker_event_.wait(batch + 1);
} }
} }


@@ -5,6 +5,7 @@
#include "mlx/backend/cuda/event.h" #include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h" #include "mlx/backend/cuda/utils.h"
#include <condition_variable>
#include <functional> #include <functional>
#include <map> #include <map>
#include <mutex> #include <mutex>
@@ -24,38 +25,24 @@ class Worker {
// Add a pending |task| that will run when consumed or commited. // Add a pending |task| that will run when consumed or commited.
void add_task(std::function<void()> task); void add_task(std::function<void()> task);
// Run pending tasks immediately in current thread.
void consume_in_this_thread();
// Put pending tasks in a batch.
void end_batch();
// Inform worker thread to run current batches now.
void commit();
// Inform worker thread to run current batches after kernels in |stream| // Inform worker thread to run current batches after kernels in |stream|
// finish running. // finish running.
void commit(cudaStream_t stream); void commit(cudaStream_t stream);
// Return how many batches have been added but not committed yet.
size_t uncommited_batches() const {
return uncommited_batches_;
}
private: private:
void thread_fn(); static void signal(void*);
uint64_t batch_{0}; void thread_fn();
size_t uncommited_batches_{0}; std::mutex mtx_;
std::condition_variable cond_;
uint64_t committed_batch_{0};
uint64_t signaled_batch_{0};
// Cuda stream and event for signaling kernel completion. // Cuda stream and event for signaling kernel completion.
CudaStream signal_stream_; CudaStream signal_stream_;
CudaEvent signal_event_; CudaEvent signal_event_;
// Worker thread.
SharedEvent worker_event_;
std::thread worker_;
std::mutex worker_mutex_;
bool stop_{false}; bool stop_{false};
// Tasks are put in |pending_tasks_| first, and then moved to // Tasks are put in |pending_tasks_| first, and then moved to
@@ -63,6 +50,7 @@ class Worker {
using Tasks = std::vector<std::function<void()>>; using Tasks = std::vector<std::function<void()>>;
Tasks pending_tasks_; Tasks pending_tasks_;
std::map<uint64_t, Tasks> worker_tasks_; std::map<uint64_t, Tasks> worker_tasks_;
std::thread worker_;
}; };
} // namespace mlx::core::cu } // namespace mlx::core::cu
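
The reworked Worker above drops the SharedEvent/end_batch dance: commit(stream) publishes the pending tasks as a numbered batch under a mutex, a host callback queued with cudaLaunchHostFunc bumps signaled_batch_ once the stream's kernels finish, and a condition variable wakes the worker thread to run every batch up to the signaled one. Below is a minimal standalone sketch of that hand-off pattern; BatchedWorker is a made-up name and the stream callback is replaced by a direct call since the sketch has no CUDA.

// Minimal sketch of the batched worker hand-off: publish tasks as a batch,
// signal the batch number, and let a worker thread drain up to the signal.
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <map>
#include <mutex>
#include <thread>
#include <vector>

class BatchedWorker {
 public:
  BatchedWorker() : worker_([this] { thread_fn(); }) {}
  ~BatchedWorker() {
    {
      std::lock_guard<std::mutex> lock(mtx_);
      stop_ = true;
    }
    cond_.notify_one();
    worker_.join();
  }

  void add_task(std::function<void()> task) {
    pending_.push_back(std::move(task));
  }

  // In the real code this is commit(cudaStream_t): the batch is published and
  // the signal arrives from a cudaLaunchHostFunc callback on a signal stream.
  void commit_and_signal() {
    if (pending_.empty()) {
      return;
    }
    {
      std::lock_guard<std::mutex> lock(mtx_);
      batches_[++committed_batch_] = std::move(pending_);
      signaled_batch_ = committed_batch_; // stand-in for the host callback
    }
    cond_.notify_one();
  }

 private:
  void thread_fn() {
    uint64_t current_batch = 0;
    while (true) {
      std::vector<std::function<void()>> tasks;
      {
        std::unique_lock<std::mutex> lk(mtx_);
        cond_.wait(lk, [&] { return signaled_batch_ > current_batch || stop_; });
        if (stop_ && signaled_batch_ <= current_batch) {
          return; // nothing left to drain
        }
        current_batch = signaled_batch_;
        auto end = batches_.upper_bound(current_batch);
        for (auto it = batches_.begin(); it != end; ++it) {
          for (auto& t : it->second) {
            tasks.push_back(std::move(t));
          }
        }
        batches_.erase(batches_.begin(), end);
      }
      for (auto& t : tasks) {
        t(); // run completion handlers outside the lock
      }
    }
  }

  std::mutex mtx_;
  std::condition_variable cond_;
  uint64_t committed_batch_{0};
  uint64_t signaled_batch_{0};
  bool stop_{false};
  std::vector<std::function<void()>> pending_;
  std::map<uint64_t, std::vector<std::function<void()>>> batches_;
  std::thread worker_;
};

int main() {
  BatchedWorker w;
  w.add_task([] { std::printf("completion handler ran\n"); });
  w.commit_and_signal();
  return 0; // destructor drains the signaled batch, then joins
}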


@@ -128,8 +128,7 @@ Buffer MetalAllocator::malloc(size_t size) {
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure or are over the maximum cache size, // If we have a lot of memory pressure try to reclaim memory from the cache
// try to reclaim memory from the cache
if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) { if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
num_resources_ -= num_resources_ -=
buffer_cache_.release_cached_buffers(mem_required - gc_limit_); buffer_cache_.release_cached_buffers(mem_required - gc_limit_);


@@ -14,6 +14,10 @@ Event::Event(Stream stream) : stream_(stream) {
auto p = metal::new_scoped_memory_pool(); auto p = metal::new_scoped_memory_pool();
event_ = std::shared_ptr<void>( event_ = std::shared_ptr<void>(
metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor); metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
if (event_ == nullptr) {
throw std::runtime_error(
"[Event::Event] Failed to create Metal shared event.");
}
} }
void Event::wait() { void Event::wait() {


@@ -708,7 +708,10 @@ array scaled_dot_product_attention(
} }
if (mask.dtype() == bool_) { if (mask.dtype() == bool_) {
scores = where( scores = where(
mask, scores, array(finfo(scores.dtype()).min, scores.dtype())); mask,
scores,
array(-std::numeric_limits<float>::infinity(), scores.dtype()),
s);
} else { } else {
scores = add(scores, mask, s); scores = add(scores, mask, s);
} }


@@ -1271,19 +1271,6 @@ std::vector<array> Convolution::vjp(
has_neg_padding |= (pd < 0); has_neg_padding |= (pd < 0);
} }
auto padding_lo_ = std::vector<int>(padding_lo);
auto padding_hi_ = std::vector<int>(padding_hi);
// Use negative padding on the gradient output
if (has_neg_padding) {
for (auto& p : padding_lo_) {
p = std::max(0, p);
}
for (auto& p : padding_hi_) {
p = std::max(0, p);
}
}
auto wt_trans = group_transpose(wt, 0, 1, -1); auto wt_trans = group_transpose(wt, 0, 1, -1);
auto grad = conv_general( auto grad = conv_general(
/* const array& input = */ cotan, /* const array& input = */ cotan,
@@ -1305,12 +1292,9 @@ std::vector<array> Convolution::vjp(
for (int i = 0; i < grad.ndim() - 2; i++) { for (int i = 0; i < grad.ndim() - 2; i++) {
if (padding_lo[i] < 0) { if (padding_lo[i] < 0) {
starts[i + 1] -= padding_lo[i]; starts[i + 1] -= padding_lo[i];
padding_lo[i] = 0;
} }
if (padding_hi[i] < 0) { if (padding_hi[i] < 0) {
stops[i + 1] += padding_hi[i]; stops[i + 1] += padding_hi[i];
padding_hi[i] = 0;
} }
} }


@@ -72,7 +72,12 @@ array eval_impl(std::vector<array> outputs, bool async) {
// Stream events for synchronization after eval // Stream events for synchronization after eval
std::unordered_map<uint32_t, Event> events; std::unordered_map<uint32_t, Event> events;
events.emplace(stream.index, Event{stream}); {
auto e = Event{stream};
e.set_value(1);
synchronizer.attach_event(e);
events.emplace(stream.index, std::move(e));
}
{ {
// Record the degree of each input // Record the degree of each input
@@ -184,21 +189,26 @@ array eval_impl(std::vector<array> outputs, bool async) {
} }
} }
std::unordered_set<int> open_streams;
while (!tape.empty()) { while (!tape.empty()) {
auto arr = std::move(tape.back()); auto arr = std::move(tape.back());
tape.pop_back(); tape.pop_back();
auto stream = arr.primitive().stream(); auto stream = arr.primitive().stream();
open_streams.insert(stream.index);
// Lookup corresponding event if (async) {
auto e = events.find(stream.index); // Lookup corresponding event
if (e == events.end()) { auto e = events.find(stream.index);
e = events.emplace(stream.index, Event{stream}).first; if (e == events.end()) {
} e = events.emplace(stream.index, Event{stream}).first;
e->second.set_value(1); }
arr.attach_event(e->second); e->second.set_value(1);
for (auto& s : arr.siblings()) { arr.attach_event(e->second);
s.attach_event(e->second); for (auto& s : arr.siblings()) {
s.attach_event(e->second);
}
} }
for (auto& in : arr.inputs()) { for (auto& in : arr.inputs()) {
@@ -227,9 +237,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
(get_active_memory() > get_memory_limit() && (get_active_memory() > get_memory_limit() &&
scheduler::n_active_tasks() > 0)) { scheduler::n_active_tasks() > 0)) {
// Commit any open streams // Commit any open streams
for (auto& [_, e] : events) { for (auto i : open_streams) {
if (e.stream().device == Device::gpu) { auto s = get_stream(i);
gpu::finalize(e.stream()); if (s.device == Device::gpu) {
gpu::finalize(s);
} }
} }
scheduler::wait_for_one(); scheduler::wait_for_one();
@@ -263,9 +274,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
} }
// Signal the event in its stream // Signal the event in its stream
for (auto& [_, e] : events) { for (auto i : open_streams) {
auto s = e.stream(); auto s = get_stream(i);
e.signal(s); if (auto e = events.find(i); e != events.end()) {
e->second.signal(s);
}
if (s.device == Device::gpu) { if (s.device == Device::gpu) {
gpu::finalize(s); gpu::finalize(s);
} }
@@ -302,7 +315,7 @@ void eval(std::vector<array> outputs) {
return; return;
} }
eval_impl(std::move(outputs), false).event().wait(); eval_impl(std::move(outputs), false).wait();
} }
std::pair<std::vector<array>, std::vector<array>> vjp( std::pair<std::vector<array>, std::vector<array>> vjp(


@@ -1,9 +1,10 @@
#!/bin/bash #!/bin/bash
auditwheel repair dist/* \ auditwheel repair dist/* \
--plat manylinux_2_39_x86_64 \ --plat manylinux_2_35_x86_64 \
--exclude libcublas* \ --exclude libcublas* \
--exclude libnvrtc* \ --exclude libnvrtc* \
--exclude libcuda* \
-w wheel_tmp -w wheel_tmp


@@ -4022,8 +4022,9 @@ void init_ops(nb::module_& m) {
Args: Args:
file (file, str): File in which the array is saved. file (file, str): File in which the array is saved.
arrays (dict(str, array)): The dictionary of names to arrays to arrays (dict(str, array)): The dictionary of names to arrays to
be saved. metadata (dict(str, str), optional): The dictionary of be saved.
metadata to be saved. metadata (dict(str, str), optional): The dictionary of
metadata to be saved.
)pbdoc"); )pbdoc");
m.def( m.def(
"save_gguf", "save_gguf",
@@ -4258,7 +4259,7 @@ void init_ops(nb::module_& m) {
.. math:: .. math::
w_i = s \hat{w_i} - \beta w_i = s \hat{w_i} + \beta
Args: Args:
w (array): Matrix to be quantized w (array): Matrix to be quantized
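
The docstring fix above corrects the dequantization formula to w_i = s * w_hat_i + beta (the comment previously showed a minus before beta). A tiny numeric check of that affine reconstruction, standalone and for illustration only:

// Corrected affine dequantization: w = s * w_hat + beta.
#include <cstdio>

int main() {
  float s = 0.5f;     // scale
  float beta = -1.0f; // bias
  int w_hat = 3;      // quantized value
  float w = s * static_cast<float>(w_hat) + beta;
  std::printf("dequantized w = %f\n", w); // 0.5
  return 0;
}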


@@ -398,6 +398,18 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
) )
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_fully_masked(self):
Lkv = 8
mask = mx.array(False)
for D in [4, 128]:
for Lq in [1, 8]:
q = mx.random.normal(shape=(1, 4, Lq, D))
k = mx.random.normal(shape=(1, 4, Lkv, D))
v = mx.random.normal(shape=(1, 4, Lkv, D))
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1)
self.assertTrue(mx.all(mx.isnan(out)))
def test_fast_sdpa_few_query(self): def test_fast_sdpa_few_query(self):
D = 64 D = 64
L = 43 L = 43


@@ -9,7 +9,7 @@ from functools import partial
from pathlib import Path from pathlib import Path
from subprocess import run from subprocess import run
from setuptools import Command, Extension, setup from setuptools import Command, Extension, find_namespace_packages, setup
from setuptools.command.bdist_wheel import bdist_wheel from setuptools.command.bdist_wheel import bdist_wheel
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
@@ -166,6 +166,10 @@ class GenerateStubs(Command):
# Run again without recursive to specify output file name # Run again without recursive to specify output file name
subprocess.run(["rm", f"{out_path}/mlx.pyi"]) subprocess.run(["rm", f"{out_path}/mlx.pyi"])
subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"]) subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
# mx.bool_ gets filtered by nanobind because of the trailing
# underscore, add it manually:
with open(f"{out_path}/__init__.pyi", "a") as fid:
fid.write("\nbool_: Dtype = ...")
class MLXBdistWheel(bdist_wheel): class MLXBdistWheel(bdist_wheel):
@@ -184,19 +188,23 @@ with open(Path(__file__).parent / "README.md", encoding="utf-8") as f:
if __name__ == "__main__": if __name__ == "__main__":
package_dir = {"": "python"} package_dir = {"": "python"}
packages = [ packages = find_namespace_packages(
"mlx", where="python",
"mlx.nn", exclude=[
"mlx.nn.layers", "src",
"mlx.optimizers", "tests",
] "scripts",
"mlx.lib",
"mlx.include",
"mlx.share",
"mlx.share.**",
"mlx.include.**",
],
)
build_macos = platform.system() == "Darwin" build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "") build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
install_requires = []
if build_cuda:
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
version = get_version() version = get_version()
_setup = partial( _setup = partial(
@@ -221,7 +229,7 @@ if __name__ == "__main__":
}, },
) )
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} package_data = {"mlx.core": ["*.pyi"]}
extras = { extras = {
"dev": [ "dev": [
@@ -239,6 +247,7 @@ if __name__ == "__main__":
"mlx.distributed_config = mlx.distributed_run:distributed_config", "mlx.distributed_config = mlx.distributed_run:distributed_config",
] ]
} }
install_requires = []
# Release builds for PyPi are in two stages. # Release builds for PyPi are in two stages.
# Each stage should be run from a clean build: # Each stage should be run from a clean build:
@@ -258,11 +267,11 @@ if __name__ == "__main__":
# - Package name is back-end specific, e.g mlx-metal # - Package name is back-end specific, e.g mlx-metal
if build_stage != 2: if build_stage != 2:
if build_stage == 1: if build_stage == 1:
if build_macos: install_requires.append(
install_requires += [f"mlx-metal=={version}"] f'mlx-metal=={version}; platform_system == "Darwin"'
else: )
extras["cuda"] = [f"mlx-cuda=={version}"] extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
extras["cpu"] = [f"mlx-cpu=={version}"] extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
_setup( _setup(
name="mlx", name="mlx",
@@ -277,9 +286,14 @@ if __name__ == "__main__":
name = "mlx-metal" name = "mlx-metal"
elif build_cuda: elif build_cuda:
name = "mlx-cuda" name = "mlx-cuda"
install_requires += [
"nvidia-cublas-cu12==12.9.*",
"nvidia-cuda-nvrtc-cu12==12.9.*",
]
else: else:
name = "mlx-cpu" name = "mlx-cpu"
_setup( _setup(
name=name, name=name,
packages=["mlx"], packages=["mlx"],
install_requires=install_requires,
) )