Mirror of https://github.com/ml-explore/mlx.git
Synced 2025-12-16 01:49:05 +08:00

Compare commits: cd4b12ce1b ... ibv-backen
13 Commits

SHA1:
dd91ee9534
8fab4f0929
47af2c8cb0
f40152ebc1
5d7e6a0642
b9b78b1059
45727b0c02
2444fbdfe9
f3b605e53c
0388ae3aaf
d4c1de4a8b
4dbffb3954
b1a60b2d2d
.github/actions/build-cuda-release/action.yml (vendored): 11 lines changed

@@ -1,15 +1,6 @@
 name: 'Build CUDA wheel'
 description: 'Build CUDA wheel'

-inputs:
-  arch:
-    description: 'Platform architecture tag'
-    required: true
-    type: choice
-    options:
-      - x86_64
-      - aarch64
-
 runs:
   using: "composite"
   steps:
@@ -21,4 +12,4 @@ runs:
       pip install auditwheel build patchelf setuptools
       python setup.py clean --all
       MLX_BUILD_STAGE=2 python -m build -w
-      bash python/scripts/repair_cuda.sh ${{ inputs.arch }}
+      bash python/scripts/repair_cuda.sh
.github/actions/setup-linux/action.yml (vendored): 1 line changed

@@ -15,7 +15,6 @@ runs:
   using: "composite"
   steps:
     - name: Use ccache
-      if: ${{ runner.arch == 'x86_64' }}
       uses: hendrikmuhs/ccache-action@v1.2
       with:
         key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
.github/workflows/release.yml (vendored): 10 lines changed

@@ -128,11 +128,7 @@ jobs:

   build_cuda_release:
     if: github.repository == 'ml-explore/mlx'
-    strategy:
-      matrix:
-        arch: ['x86_64', 'aarch64']
-        toolkit: ['cuda-12.9', 'cuda-13.0']
-    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
+    runs-on: ubuntu-22-large
     env:
       PYPI_RELEASE: 1
       DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
@@ -140,11 +136,9 @@ jobs:
       - uses: actions/checkout@v6
       - uses: ./.github/actions/setup-linux
         with:
-          toolkit: ${{ matrix.toolkit }}
+          toolkit: 'cuda-12.9'
       - name: Build Python package
         uses: ./.github/actions/build-cuda-release
-        with:
-          arch: ${{ matrix.arch }}
       - name: Upload artifacts
         uses: actions/upload-artifact@v5
         with:
@@ -29,20 +29,17 @@ MLX has a CUDA backend which you can install with:

 .. code-block:: shell

-    pip install mlx[cuda12]
+    pip install mlx[cuda]


 To install the CUDA package from PyPi your system must meet the following
 requirements:

-- Nvidia architecture >= SM 7.5
+- Nvidia architecture >= SM 7.0 (Volta)
 - Nvidia driver >= 550.54.14
 - CUDA toolkit >= 12.0
 - Linux distribution with glibc >= 2.35
 - Python >= 3.10

-For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
-an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
-
 CPU-only (Linux)
 ^^^^^^^^^^^^^^^^
@@ -1,6 +1,7 @@
 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
mlx/allocator.cpp (new file): 24 lines

@@ -0,0 +1,24 @@
+// Copyright © 2023 Apple Inc.
+
+#include <cstdlib>
+#include <sstream>
+
+#include "mlx/allocator.h"
+
+namespace mlx::core::allocator {
+
+Buffer malloc(size_t size) {
+  auto buffer = allocator().malloc(size);
+  if (size && !buffer.ptr()) {
+    std::ostringstream msg;
+    msg << "[malloc] Unable to allocate " << size << " bytes.";
+    throw std::runtime_error(msg.str());
+  }
+  return buffer;
+}
+
+void free(Buffer buffer) {
+  allocator().free(buffer);
+}
+
+} // namespace mlx::core::allocator
@@ -28,16 +28,16 @@ class Buffer {
 };
 };

+Buffer malloc(size_t size);
+
+void free(Buffer buffer);
+
 class Allocator {
   /** Abstract base class for a memory allocator. */
  public:
   virtual Buffer malloc(size_t size) = 0;
   virtual void free(Buffer buffer) = 0;
   virtual size_t size(Buffer buffer) const = 0;
-  virtual Buffer make_buffer(void* ptr, size_t size) {
-    return Buffer{nullptr};
-  };
-  virtual void release(Buffer buffer) {}

   Allocator() = default;
   Allocator(const Allocator& other) = delete;
@@ -49,25 +49,4 @@ class Allocator {

 Allocator& allocator();

-inline Buffer malloc(size_t size) {
-  return allocator().malloc(size);
-}
-
-inline void free(Buffer buffer) {
-  allocator().free(buffer);
-}
-
-// Make a Buffer from a raw pointer of the given size without a copy. If a
-// no-copy conversion is not possible then the returned buffer.ptr() will be
-// nullptr. Any buffer created with this function must be released with
-// release(buffer)
-inline Buffer make_buffer(void* ptr, size_t size) {
-  return allocator().make_buffer(ptr, size);
-};
-
-// Release a buffer from the allocator made with make_buffer
-inline void release(Buffer buffer) {
-  allocator().release(buffer);
-}
-
 } // namespace mlx::core::allocator
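The comment on the removed make_buffer/release helpers above describes a no-copy path for wrapping externally owned memory. Below is a minimal sketch of that call pattern, based only on the declarations and comments in this hunk; the wrap_or_copy function and the raw_ptr() accessor are assumptions for illustration, not part of the diff.

#include <cstring>

#include "mlx/allocator.h"

// Sketch: try the no-copy wrap first; fall back to an owned allocation plus copy.
mlx::core::allocator::Buffer wrap_or_copy(void* user_ptr, size_t nbytes) {
  namespace alloc = mlx::core::allocator;
  auto buf = alloc::make_buffer(user_ptr, nbytes);
  if (buf.ptr() != nullptr) {
    // Wrapped without a copy; this buffer must later go back via release(buf).
    return buf;
  }
  // No-copy conversion was not possible: allocate, copy, and later use free(buf).
  buf = alloc::malloc(nbytes);
  std::memcpy(buf.raw_ptr(), user_ptr, nbytes);  // raw_ptr() assumed to expose the bytes
  return buf;
}

The same wrap-or-copy structure appears in the array constructor removed in the next hunk.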
@@ -82,28 +82,6 @@ array::array(std::initializer_list<int> data, Dtype dtype)
   init(data.begin());
 }

-array::array(
-    void* data,
-    Shape shape,
-    Dtype dtype,
-    const std::function<void(void*)>& deleter)
-    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
-  auto buffer = allocator::make_buffer(data, nbytes());
-  if (buffer.ptr() == nullptr) {
-    set_data(allocator::malloc(nbytes()));
-    auto ptr = static_cast<char*>(data);
-    std::copy(ptr, ptr + nbytes(), this->data<char>());
-    deleter(data);
-  } else {
-    auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
-      auto ptr = buffer.ptr();
-      allocator::release(buffer);
-      return deleter(ptr);
-    };
-    set_data(buffer, std::move(wrapped_deleter));
-  }
-}
-
 /* Build an array from a shared buffer */
 array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
     : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
mlx/array.h: 10 lines changed

@@ -57,16 +57,6 @@ class array {
       Shape shape,
       Dtype dtype = TypeToDtype<T>());

-  /* Build an array from a raw pointer. The constructor will attempt to use the
-   * input data without a copy. The deleter will be called when the array no
-   * longer needs the underlying memory - after the array is destroyed in the
-   * no-copy case and after the copy otherwise. */
-  explicit array(
-      void* data,
-      Shape shape,
-      Dtype dtype,
-      const std::function<void(void*)>& deleter);
-
   /* Build an array from a buffer */
   explicit array(
       allocator::Buffer data,
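The doc comment removed here spells out when the deleter fires. A usage sketch only, mirroring the test case removed at the bottom of this comparison and assuming the mlx/mlx.h umbrella header and the float32 dtype constant from mlx::core:

#include <vector>

#include "mlx/mlx.h"

// Sketch: hand MLX a raw pointer; it adopts the memory without a copy when the
// allocator supports it, and the deleter runs once the data is no longer needed
// (after the array is destroyed in the no-copy case, right after the copy otherwise).
void from_user_buffer() {
  auto* data = new std::vector<float>(1024, 1.0f);
  mlx::core::array a(
      data->data(),
      {1024},                           // Shape
      mlx::core::float32,               // Dtype
      [data](void*) { delete data; });  // deleter owns the vector
  auto b = a + mlx::core::array(1.0f);
  mlx::core::eval(b);
}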
@@ -20,19 +20,6 @@ constexpr int page_size = 16384;
 // Any allocations smaller than this will try to use the small pool
 constexpr int small_block_size = 8;

-#if CUDART_VERSION >= 13000
-inline cudaMemLocation cuda_mem_loc(int i) {
-  cudaMemLocation loc;
-  loc.type = cudaMemLocationTypeDevice;
-  loc.id = i;
-  return loc;
-}
-#else
-inline int cuda_mem_loc(int i) {
-  return i;
-}
-#endif // CUDART_VERSION >= 13000
-
 // The small pool size in bytes. This should be a multiple of the host page
 // size and small_block_size.
 constexpr int small_pool_size = 4 * page_size;
@@ -48,7 +35,13 @@ SmallSizePool::SmallSizePool() {
   int device_count = 0;
   CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
   for (int i = 0; i < device_count; ++i) {
-    auto loc = cuda_mem_loc(i);
+#if CUDART_VERSION >= 13000
+    cudaMemLocation loc;
+    loc.type = cudaMemLocationTypeDevice;
+    loc.id = i;
+#else
+    int loc = i;
+#endif // CUDART_VERSION >= 13000
     CHECK_CUDA_ERROR(
         cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
   }
@@ -97,10 +90,9 @@ CudaAllocator::CudaAllocator()
           page_size,
           [](CudaBuffer* buf) { return buf->size; },
           [this](CudaBuffer* buf) { cuda_free(buf); }) {
-  size_t free;
-  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
-  memory_limit_ = total_memory_ * 0.95;
-  free_limit_ = total_memory_ - memory_limit_;
+  size_t free, total;
+  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
+  memory_limit_ = total * 0.9;
   max_pool_size_ = memory_limit_;

   int device_count = 0;
@@ -112,10 +104,6 @@ CudaAllocator::CudaAllocator()
     cudaStream_t s;
     CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
     free_streams_.push_back(s);
-
-    cudaMemPool_t mem_pool;
-    CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pool, i));
-    mem_pools_.push_back(mem_pool);
   }
   CHECK_CUDA_ERROR(cudaSetDevice(curr));
 }
@@ -166,35 +154,23 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
   }
   lock.unlock();
   if (!buf) {
+    cudaError_t err;
     void* data = nullptr;
     if (device == -1) {
-      CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
+      err = cudaMallocManaged(&data, size);
     } else {
-      CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
+      err = cudaMallocAsync(&data, size, stream);
+    }
+    if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
+      throw std::runtime_error(fmt::format(
+          "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
     }
     if (!data) {
-      std::ostringstream msg;
-      msg << "[malloc] Unable to allocate " << size << " bytes.";
-      throw std::runtime_error(msg.str());
+      return Buffer{nullptr};
     }
     buf = new CudaBuffer{data, size, device};
   }
   lock.lock();
-
-  // If any cuda memory pool has too much reserved memory, clear some
-  // memory from the cache. This prevents graph / kernel execution failing
-  // from OOM
-  if (get_cache_memory() > 0) {
-    for (auto p : mem_pools_) {
-      size_t used = 0;
-      CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
-          p, cudaMemPoolAttrReservedMemCurrent, &used));
-      if (used > (total_memory_ - free_limit_)) {
-        buffer_cache_.release_cached_buffers(free_limit_);
-        break;
-      }
-    }
-  }
   }
   active_memory_ += buf->size;
   peak_memory_ = std::max(active_memory_, peak_memory_);
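On the plus side of this comparison the backend allocator stops throwing on ordinary allocation failure and instead returns a null Buffer; the throwing check lives in the malloc wrapper added in mlx/allocator.cpp earlier in this diff. A simplified sketch of that two-layer split, using stand-in types rather than the real MLX classes:

#include <cstdlib>
#include <sstream>
#include <stdexcept>

// Stand-in type: the real code uses mlx::core::allocator::Buffer.
struct Buffer { void* ptr; };

// Backend layer (cf. CudaAllocator / MetalAllocator): report failure with a
// null buffer instead of throwing.
Buffer backend_malloc(size_t size) {
  void* p = std::malloc(size);  // placeholder for cudaMallocAsync / newBuffer
  return Buffer{p};
}

// Front-end layer (cf. the new mlx/allocator.cpp): turn a null buffer into an error.
Buffer checked_malloc(size_t size) {
  auto buffer = backend_malloc(size);
  if (size && !buffer.ptr) {
    std::ostringstream msg;
    msg << "[malloc] Unable to allocate " << size << " bytes.";
    throw std::runtime_error(msg.str());
  }
  return buffer;
}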
@@ -71,14 +71,11 @@ class CudaAllocator : public allocator::Allocator {

   std::mutex mutex_;
   size_t memory_limit_;
-  size_t free_limit_;
-  size_t total_memory_;
   size_t max_pool_size_;
   BufferCache<CudaBuffer> buffer_cache_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
   std::vector<cudaStream_t> free_streams_;
-  std::vector<cudaMemPool_t> mem_pools_;
   SmallSizePool scalar_pool_;
 };

@@ -95,14 +95,11 @@ void copy_general_input(
   const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
   OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
   int ndim = shape.size();
-  int work_per_thread = 8;
+  int work_per_thread = 1;
   auto dim0 = ndim > 0 ? shape.back() : 1;
   auto rest = out.size() / dim0;
-  if (dim0 >= 4 && dim0 < 8) {
+  if (dim0 >= 4) {
     work_per_thread = 4;
-  } else if (dim0 < 4) {
-    work_per_thread = 1;
   }
   dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
   auto block_dims = get_block_dims(dim0, rest, 1);
@@ -113,10 +110,7 @@ void copy_general_input(
     dispatch_1_2_3(ndim, [&](auto dims_constant) {
       auto kernel =
           cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
-      if (work_per_thread == 8) {
-        kernel =
-            cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
-      } else if (work_per_thread == 4) {
+      if (work_per_thread == 4) {
         kernel =
             cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
       }
@@ -133,9 +127,7 @@ void copy_general_input(
     });
   } else { // ndim >= 4
     auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
-    if (work_per_thread == 8) {
-      kernel = cu::copy_g<InType, OutType, IdxT, 8>;
-    } else if (work_per_thread == 4) {
+    if (work_per_thread == 4) {
       kernel = cu::copy_g<InType, OutType, IdxT, 4>;
     }
     encoder.add_kernel_node(
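The edit above keeps the launch geometry and the dispatched kernel in agreement: dim0 is divided by work_per_thread before get_block_dims, so the kernel's element-per-thread template parameter must match the chosen work_per_thread. A small stand-alone sketch of that arithmetic (the function name here is illustrative, not from the diff):

#include <cstdint>

// Each thread handles `work_per_thread` consecutive elements of the innermost
// dimension, so the number of threads needed along that dimension shrinks.
inline int64_t threads_for_dim0(int64_t dim0, int work_per_thread) {
  return (dim0 + work_per_thread - 1) / work_per_thread;
}

// Example: dim0 = 1000 with work_per_thread = 4 needs 250 threads along dim0;
// with work_per_thread = 1 it needs 1000. Dispatching a kernel compiled for a
// different per-thread width than the one used for the grid would read out of
// bounds or leave elements uncopied.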
@@ -318,52 +318,46 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
   insert_graph_dependencies(GraphNode{node, "K"});
 }

-std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
-  // Constructs a key representing the nodes of a sub-graph.
-  // Also checks if the sub-graph is updatable as CUDA graphs do not get
-  // updated correctly if a kernel node getting updated has a different cluster
-  // shape than the node it's being updated with.
-  std::string key = "(";
+bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
+  // CUDA graphs do not get updated correctly if a kernel node getting updated
+  // has a different cluster shape than the node it's being updated with.
   size_t num_nodes = 0;
   CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
   if (num_nodes == 0) {
-    return {key + ")", true};
+    return true;
   }
-  bool is_updatable = true;
   std::vector<cudaGraphNode_t> nodes(num_nodes);
   CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
   for (const auto& node : nodes) {
-    if (!is_updatable) {
-      break;
-    }
     cudaGraphNodeType type;
     CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
     if (type == cudaGraphNodeTypeGraph) {
       // Try to be updatable for a structure like graph -> graph -> kernel
+      if (num_nodes > 1) {
+        return false;
+      }
       cudaGraph_t child;
       CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
-      auto [subkey, sub_is_updatable] = subgraph_to_key(child);
-      is_updatable &= sub_is_updatable;
-      key += subkey;
-    } else if (type == cudaGraphNodeTypeMemset) {
-      key += "M";
+      return is_graph_updatable(child, cluster_dim_x);
     } else if (type != cudaGraphNodeTypeKernel) {
-      is_updatable = false;
+      return false;
     } else {
       cudaLaunchAttributeValue cluster_dim;
       CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
           node, cudaLaunchAttributeClusterDimension, &cluster_dim));
-      // Only allow dim.x to be greater than 1
+      // Only dim.x can be greater than 1
       if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
-        is_updatable = false;
-      } else {
-        key += "K";
-        key += std::to_string(cluster_dim.clusterDim.x);
+        return false;
       }
+      // Only one child node allowed when subgraph uses clusters
+      if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
+        return false;
+      }
+      cluster_dim_x = cluster_dim.clusterDim.x;
     }
   }
-  key += ")";
-  return {key, is_updatable};
+  return true;
 }

 void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -376,10 +370,11 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
     return;
   }
   cudaGraphNode_t node;
-  auto [sub_graph_key, is_updatable] = subgraph_to_key(child);
-  is_graph_updatable_ &= is_updatable;
+  int cluster_dim_x = 0;
+  is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
   CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
-  insert_graph_dependencies(GraphNode{node, sub_graph_key});
+  insert_graph_dependencies(
+      GraphNode{node, "G" + std::to_string(cluster_dim_x)});
 }

 bool CommandEncoder::needs_commit() {
@@ -106,7 +106,7 @@ class CommandEncoder {
     cudaGraphNode_t node;
     // K = kernel
     // E = empty
-    // () = subgraph (with metadata)
+    // G* = subgraph (with metadata)
     // Symbols ':', '-' are reserved as separators
     std::string node_type;
     std::string id;
@@ -89,13 +89,9 @@ template <
     int NDIM,
     int BM,
     int BN,
-    int N_READS = 4,
-    int BLOCKS = 1>
-__global__ void col_reduce_looped(
-    T* in,
-    U* out,
-    const __grid_constant__ ColReduceArgs args,
-    int64_t out_size) {
+    int N_READS = 4>
+__global__ void
+col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -106,8 +102,6 @@ __global__ void col_reduce_looped(
   size_t tile_idx = grid.block_rank();
   size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
   size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
-  size_t tile_out = tile_y / out_size;
-  tile_y = tile_y % out_size;

   // Compute the indices for the thread within the tile
   short thread_x = block.thread_rank() % threads_per_row;
@@ -124,23 +118,12 @@ __global__ void col_reduce_looped(
     totals[i] = ReduceInit<Op, T>::value();
   }

-  size_t total = args.non_col_reductions * args.reduction_size;
-  size_t per_block, start, end;
-  if constexpr (BLOCKS > 1) {
-    per_block = (total + BLOCKS - 1) / BLOCKS;
-    start = tile_out * per_block + thread_y;
-    end = min((tile_out + 1) * per_block, total);
-  } else {
-    per_block = total;
-    start = thread_y;
-    end = total;
-  }
-
   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-  loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
+  loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
+  size_t total = args.non_col_reductions * args.reduction_size;
   if (tile_x * BN + BN <= args.reduction_stride) {
     if (args.reduction_stride % N_READS == 0) {
-      for (size_t r = start; r < end; r += BM) {
+      for (size_t r = thread_y; r < total; r += BM) {
         T vals[N_READS];
         cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
         for (int i = 0; i < N_READS; i++) {
@@ -149,7 +132,7 @@ __global__ void col_reduce_looped(
         loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
       }
     } else {
-      for (size_t r = start; r < end; r += BM) {
+      for (size_t r = thread_y; r < total; r += BM) {
         T vals[N_READS];
         cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
         for (int i = 0; i < N_READS; i++) {
@@ -159,7 +142,7 @@ __global__ void col_reduce_looped(
       }
     }
   } else {
-    for (size_t r = start; r < end; r += BM) {
+    for (size_t r = thread_y; r < total; r += BM) {
       T vals[N_READS];
       cub::LoadDirectBlocked(
           thread_x,
@@ -190,9 +173,6 @@ __global__ void col_reduce_looped(

   // Write result.
   if (warp.thread_rank() == 0) {
-    if (BLOCKS > 1) {
-      out += tile_out * out_size * args.reduction_stride;
-    }
     cub::StoreDirectBlocked(
         warp.meta_group_rank(),
         out + tile_y * args.reduction_stride + tile_x * BN,
@@ -247,12 +227,11 @@ __global__ void col_reduce_small(
 inline auto output_grid_for_col_reduce(
     const array& out,
     const cu::ColReduceArgs& args,
-    int bn,
-    int outer = 1) {
+    int bn) {
   int gx, gy = 1;
   size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
   size_t n_outer_blocks = out.size() / args.reduction_stride;
-  size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
+  size_t n_blocks = n_outer_blocks * n_inner_blocks;
   while (n_blocks / gy > INT32_MAX) {
     gy *= 2;
   }
@@ -298,8 +277,7 @@ void col_reduce_looped(
             0,
             indata,
             gpu_ptr<U>(out),
-            static_cast<cu::ColReduceArgs>(args),
-            out.size() / args.reduction_stride);
+            static_cast<cu::ColReduceArgs>(args));
       });
     });
   });
@@ -342,117 +320,6 @@ void col_reduce_small(
   });
 }

-void col_reduce_two_pass(
-    cu::CommandEncoder& encoder,
-    const array& in,
-    array& out,
-    Reduce::ReduceType reduce_type,
-    const std::vector<int>& axes,
-    const ReductionPlan& plan,
-    const cu::ColReduceArgs& args) {
-  // Allocate data for the output using in's layout to access them as
-  // contiguously as possible.
-  allocate_same_layout(out, in, axes, encoder);
-
-  // Allocate an intermediate array to hold the 1st pass result
-  constexpr int outer = 32;
-
-  Shape intermediate_shape;
-  intermediate_shape.push_back(outer);
-  intermediate_shape.insert(
-      intermediate_shape.end(), out.shape().begin(), out.shape().end());
-
-  Strides intermediate_strides;
-  intermediate_strides.push_back(out.size());
-  intermediate_strides.insert(
-      intermediate_strides.end(), out.strides().begin(), out.strides().end());
-
-  array intermediate(intermediate_shape, out.dtype(), nullptr, {});
-  auto [data_size, rc, cc] =
-      check_contiguity(intermediate_shape, intermediate_strides);
-  auto fl = out.flags();
-  fl.row_contiguous = rc;
-  fl.col_contiguous = cc;
-  fl.contiguous = true;
-  intermediate.set_data(
-      cu::malloc_async(intermediate.nbytes(), encoder),
-      data_size,
-      intermediate_strides,
-      fl,
-      allocator::free);
-
-  encoder.add_temporary(intermediate);
-  encoder.set_input_array(in);
-  encoder.set_output_array(intermediate);
-  dispatch_all_types(in.dtype(), [&](auto type_tag) {
-    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
-      dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
-        using OP = MLX_GET_TYPE(reduce_type_tag);
-        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-        using U = typename cu::ReduceResult<OP, T>::type;
-        // Cub doesn't like const pointers for vectorized loads. (sigh)
-        T* indata = const_cast<T*>(gpu_ptr<T>(in));
-
-        constexpr int N_READS = 4;
-        constexpr int BM = 32;
-        constexpr int BN = 32;
-        dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
-        int blocks = BM * BN / N_READS;
-        auto kernel = cu::
-            col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
-        encoder.add_kernel_node(
-            kernel,
-            grid,
-            blocks,
-            0,
-            indata,
-            gpu_ptr<U>(intermediate),
-            static_cast<cu::ColReduceArgs>(args),
-            out.size() / args.reduction_stride);
-      });
-    });
-  });
-
-  // Prepare the reduction arguments for the 2nd pass
-  cu::ColReduceArgs second_args = args;
-  second_args.reduction_size = outer;
-  second_args.reduction_stride = out.size();
-  second_args.ndim = 0;
-  second_args.reduce_shape[0] = outer;
-  second_args.reduce_strides[0] = out.size();
-  second_args.reduce_ndim = 1;
-  second_args.non_col_reductions = 1;
-
-  encoder.set_input_array(intermediate);
-  encoder.set_output_array(out);
-  dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
-    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
-      dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
-        using OP = MLX_GET_TYPE(reduce_type_tag);
-        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-        using U = typename cu::ReduceResult<OP, T>::type;
-
-        constexpr int N_READS = 4;
-        constexpr int BM = 32;
-        constexpr int BN = 32;
-        dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
-        int blocks = BM * BN / N_READS;
-        auto kernel =
-            cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
-        encoder.add_kernel_node(
-            kernel,
-            grid,
-            blocks,
-            0,
-            gpu_ptr<T>(intermediate),
-            gpu_ptr<U>(out),
-            second_args,
-            second_args.reduction_stride);
-      });
-    });
-  });
-}
-
 void col_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
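The removed col_reduce_two_pass implements the strategy described in the comment removed from col_reduce in the next hunk: for long columns, several independent slices of the reduced axis are first reduced into an intermediate with an extra leading outer axis, and a second launch reduces over that axis. A host-side sketch of the same idea for a plain sum (illustrative only; the real kernels operate on strided GPU tiles, one slice per threadblock):

#include <algorithm>
#include <cstddef>
#include <vector>

// Pass 1: `outer` independent workers each sum a slice of rows into `partial`.
// Pass 2: sum the `outer` partial rows into the final per-column result.
std::vector<float> two_pass_col_sum(
    const std::vector<float>& in, size_t rows, size_t cols, size_t outer = 32) {
  std::vector<float> partial(outer * cols, 0.0f);  // intermediate: (outer, cols)
  size_t rows_per_slice = (rows + outer - 1) / outer;
  for (size_t s = 0; s < outer; ++s) {
    size_t begin = s * rows_per_slice;
    size_t end = std::min(rows, begin + rows_per_slice);
    for (size_t r = begin; r < end; ++r) {
      for (size_t c = 0; c < cols; ++c) {
        partial[s * cols + c] += in[r * cols + c];
      }
    }
  }
  std::vector<float> out(cols, 0.0f);
  for (size_t s = 0; s < outer; ++s) {
    for (size_t c = 0; c < cols; ++c) {
      out[c] += partial[s * cols + c];
    }
  }
  return out;
}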
@@ -467,18 +334,6 @@ void col_reduce(
   // It is a general strided reduce. Each threadblock computes the output for
   // a subrow of the fast moving axis. For instance 32 elements.
   //
-  // - col_reduce_small
-  //
-  //   It is a column reduce for small columns. Each thread loops over the whole
-  //   column without communicating with any other thread.
-  //
-  // - col_reduce_two_pass
-  //
-  //   It is a reduce for long columns. To increase parallelism, we split the
-  //   reduction in two passes. First we do a column reduce where many
-  //   threadblocks operate on different parts of the reduced axis. Then we
-  //   perform a final column reduce.
-  //
   // Notes: As in row reduce we opt to read as much in order as possible and
   // leave transpositions as they are (contrary to our Metal backend).
   //
@@ -494,14 +349,6 @@ void col_reduce(
     return;
   }

-  // Long column with smallish row
-  size_t total_sums = args.non_col_reductions * args.reduction_size;
-  size_t approx_threads = out.size();
-  if (total_sums / approx_threads > 32) {
-    col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
-    return;
-  }
-
   // Fallback col reduce
   col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
 }
@@ -80,6 +80,7 @@ CudaGraph::CudaGraph(cu::Device& device) {
 }

 void CudaGraph::end_capture(cudaStream_t stream) {
+  assert(handle_ == nullptr);
   CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
 }

@@ -7,6 +7,8 @@

 namespace mlx::core {

+void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
+
 void copy_gpu(const array& in, array& out, CopyType ctype) {
   copy_gpu(in, out, ctype, out.primitive().stream());
 }
@@ -149,9 +149,7 @@ Buffer MetalAllocator::malloc(size_t size) {
     buf = device_->newBuffer(size, resource_options);
   }
   if (!buf) {
-    std::ostringstream msg;
-    msg << "[malloc] Unable to allocate " << size << " bytes.";
-    throw std::runtime_error(msg.str());
+    return Buffer{nullptr};
   }
   lk.lock();
   num_resources_++;
@@ -203,32 +201,6 @@ size_t MetalAllocator::size(Buffer buffer) const {
   return static_cast<MTL::Buffer*>(buffer.ptr())->length();
 }

-Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
-  auto buf = device_->newBuffer(ptr, size, resource_options, nullptr);
-  if (!buf) {
-    return Buffer{nullptr};
-  }
-  std::unique_lock lk(mutex_);
-  residency_set_.insert(buf);
-  active_memory_ += buf->length();
-  peak_memory_ = std::max(peak_memory_, active_memory_);
-  num_resources_++;
-  return Buffer{static_cast<void*>(buf)};
-}
-
-void MetalAllocator::release(Buffer buffer) {
-  auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
-  if (buf == nullptr) {
-    return;
-  }
-  std::unique_lock lk(mutex_);
-  active_memory_ -= buf->length();
-  num_resources_--;
-  lk.unlock();
-  auto pool = metal::new_scoped_memory_pool();
-  buf->release();
-}
-
 MetalAllocator& allocator() {
   // By creating the |allocator_| on heap, the destructor of MetalAllocator
   // will not be called on exit and buffers in the cache will be leaked. This
@@ -21,9 +21,6 @@ class MetalAllocator : public allocator::Allocator {
   virtual Buffer malloc(size_t size) override;
   virtual void free(Buffer buffer) override;
   virtual size_t size(Buffer buffer) const override;
-  virtual Buffer make_buffer(void* ptr, size_t size) override;
-  virtual void release(Buffer buffer) override;
-
   size_t get_active_memory() {
     return active_memory_;
   };
@@ -25,7 +25,6 @@ class CommonAllocator : public Allocator {
   virtual Buffer malloc(size_t size) override;
   virtual void free(Buffer buffer) override;
   virtual size_t size(Buffer buffer) const override;
-
   size_t get_active_memory() const {
     return active_memory_;
   };
@@ -1,7 +1,7 @@
 #!/bin/bash

 auditwheel repair dist/* \
-  --plat manylinux_2_35_${1} \
+  --plat manylinux_2_35_x86_64 \
   --exclude libcublas* \
   --exclude libnvrtc* \
   --exclude libcuda* \
@@ -210,14 +210,6 @@ class TestReduce(mlx_tests.MLXTestCase):
                 ref = getattr(np, op)(np_arr, axis=axis)
                 self.assertTrue(np.array_equal(out, ref, equal_nan=True))

-    def test_long_column(self):
-        a = (np.random.randn(8192, 64) * 32).astype(np.int32)
-        b = mx.array(a)
-
-        c1 = a.sum(0)
-        c2 = b.sum(0)
-        self.assertTrue(np.all(c1 == c2))
-

 if __name__ == "__main__":
     mlx_tests.MLXTestRunner(failfast=True)
setup.py: 36 lines changed

@@ -7,21 +7,13 @@ import re
 import subprocess
 from functools import partial
 from pathlib import Path
+from subprocess import run

 from setuptools import Command, Extension, find_namespace_packages, setup
 from setuptools.command.bdist_wheel import bdist_wheel
 from setuptools.command.build_ext import build_ext


-def cuda_toolkit_major_version():
-    out = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
-    text = out.decode()
-    m = re.search(r"release (\d+)", text)
-    if m:
-        return int(m.group(1))
-    return None
-
-
 def get_version():
     with open("mlx/version.h", "r") as fid:
         for l in fid:
@@ -39,7 +31,7 @@ def get_version():
             version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
         if not pypi_release and not dev_release:
             git_hash = (
-                subprocess.run(
+                run(
                     "git rev-parse --short HEAD".split(),
                     capture_output=True,
                     check=True,
@@ -292,11 +284,7 @@ if __name__ == "__main__":
         install_requires.append(
             f'mlx-metal=={version}; platform_system == "Darwin"'
         )
-        extras["cuda"] = [f'mlx-cuda-12=={version}; platform_system == "Linux"']
-        for toolkit in [12, 13]:
-            extras[f"cuda{toolkit}"] = [
-                f'mlx-cuda-{toolkit}=={version}; platform_system == "Linux"'
-            ]
+        extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
         extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']

         _setup(
@@ -311,25 +299,13 @@ if __name__ == "__main__":
     if build_macos:
         name = "mlx-metal"
     elif build_cuda:
-        toolkit = cuda_toolkit_major_version()
-        name = f"mlx-cuda-{toolkit}"
-        if toolkit == 12:
+        name = "mlx-cuda"
         install_requires += [
             "nvidia-cublas-cu12==12.9.*",
             "nvidia-cuda-nvrtc-cu12==12.9.*",
+            "nvidia-cudnn-cu12==9.*",
+            "nvidia-nccl-cu12",
         ]
-        elif toolkit == 13:
-            install_requires += [
-                "nvidia-cublas-cu13",
-                "nvidia-cuda-nvrtc-cu13",
-            ]
-        else:
-            raise ValueError(f"Unknown toolkit {toolkit}")
-        install_requires += [
-            f"nvidia-cudnn-cu{toolkit}==9.*",
-            f"nvidia-nccl-cu{toolkit}",
-        ]
-
     else:
         name = "mlx-cpu"
     _setup(
@@ -1,4 +1,5 @@
 // Copyright © 2023 Apple Inc.

 #include <climits>
+
 #include "doctest/doctest.h"
@@ -607,24 +608,3 @@ TEST_CASE("test make empty array") {
   CHECK_EQ(a.size(), 0);
   CHECK_EQ(a.dtype(), bool_);
 }
-
-TEST_CASE("test make array from user buffer") {
-  int size = 4096;
-  std::vector<int> buffer(size, 0);
-
-  int count = 0;
-  auto deleter = [&count](void*) { count++; };
-
-  {
-    auto a = array(buffer.data(), Shape{size}, int32, deleter);
-    if (metal::is_available()) {
-      CHECK_EQ(buffer.data(), a.data<int>());
-    }
-    auto b = a + array(1);
-    eval(b);
-    auto expected = ones({4096});
-    CHECK(array_equal(b, expected).item<bool>());
-  }
-  // deleter should always get called
-  CHECK_EQ(count, 1);
-}