mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
21 Commits
997cfc7699
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2764d1073 | ||
|
|
093a62d2ed | ||
|
|
1b591ec736 | ||
|
|
47d2505ea9 | ||
|
|
bedefed784 | ||
|
|
ccaaa7d6df | ||
|
|
f3e5ca5414 | ||
|
|
81dfe5f137 | ||
|
|
012fb220a1 | ||
|
|
e1fee0074b | ||
|
|
3c8ce9b00e | ||
|
|
937ce79660 | ||
|
|
208f5441a7 | ||
|
|
b862d842e1 | ||
|
|
f7a400951a | ||
|
|
27232db1ba | ||
|
|
a4b3bc969b | ||
|
|
667c0f3bb9 | ||
|
|
6245824d42 | ||
|
|
39289ef025 | ||
|
|
aefc9bd3f6 |
2
.github/actions/build-macos/action.yml
vendored
2
.github/actions/build-macos/action.yml
vendored
@@ -11,7 +11,7 @@ runs:
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install cmake setuptools nanobind==2.4.0
|
||||
pip install cmake setuptools nanobind==2.10.2
|
||||
pip install -e . -v
|
||||
|
||||
- name: Generate package stubs
|
||||
|
||||
22
.github/actions/setup-linux/action.yml
vendored
22
.github/actions/setup-linux/action.yml
vendored
@@ -10,23 +10,29 @@ inputs:
|
||||
description: 'Version of python to set up'
|
||||
required: false
|
||||
default: '3.10'
|
||||
use-ccache:
|
||||
description: 'Whether to enable ccache'
|
||||
required: false
|
||||
default: 'true'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Use ccache
|
||||
if: ${{ runner.arch == 'x86_64' }}
|
||||
uses: hendrikmuhs/ccache-action@v1.2
|
||||
with:
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
|
||||
max-size: 1GB
|
||||
|
||||
- name: Install common dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
|
||||
|
||||
- name: Use ccache
|
||||
if: ${{ inputs.use-ccache == 'true' }}
|
||||
uses: hendrikmuhs/ccache-action@v1.2
|
||||
with:
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}
|
||||
max-size: 1GB
|
||||
# ccache-action bug: running "apt-get update" fails on large arm runner.
|
||||
update-package-index: false
|
||||
|
||||
- uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
@@ -36,7 +42,7 @@ runs:
|
||||
run: |
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install setuptools cmake nanobind==2.4.0
|
||||
pip install setuptools cmake nanobind==2.10.2
|
||||
echo PATH=$PATH >> $GITHUB_ENV
|
||||
# Make cmake search .venv for nanobind
|
||||
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
|
||||
|
||||
6
.github/workflows/nightly.yml
vendored
6
.github/workflows/nightly.yml
vendored
@@ -23,14 +23,14 @@ jobs:
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
arch: "x86_64"
|
||||
- name: Upload mlx artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: linux-wheels-${{ matrix.python_version }}
|
||||
path: wheelhouse/mlx-*.whl
|
||||
retention-days: 7
|
||||
- name: Upload mlx-cpu artifacts
|
||||
if: matrix.python_version == '3.10'
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: mlx-cpu
|
||||
path: wheelhouse/mlx_cpu-*.whl
|
||||
@@ -89,7 +89,7 @@ jobs:
|
||||
with:
|
||||
toolkit: 'cuda-12.9'
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: mlx-cuda
|
||||
path: wheelhouse/mlx_cuda-*.whl
|
||||
|
||||
27
.github/workflows/release.yml
vendored
27
.github/workflows/release.yml
vendored
@@ -57,19 +57,20 @@ jobs:
|
||||
- uses: ./.github/actions/setup-linux
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
use-ccache: false
|
||||
- uses: ./.github/actions/build-linux-release
|
||||
with:
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
arch: ${{ matrix.arch }}
|
||||
- name: Upload MLX artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
overwrite: true
|
||||
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
|
||||
path: wheelhouse/mlx-*.whl
|
||||
- name: Upload CPU artifacts
|
||||
if: matrix.python_version == '3.10'
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
overwrite: true
|
||||
name: mlx-cpu-${{ matrix.arch }}
|
||||
@@ -95,7 +96,7 @@ jobs:
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install cmake setuptools nanobind==2.4.0
|
||||
pip install cmake setuptools nanobind==2.10.2
|
||||
pip install -e . -v
|
||||
- name: Generate package stubs
|
||||
shell: bash -l {0}
|
||||
@@ -113,14 +114,14 @@ jobs:
|
||||
macos-target: 15.0
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
- name: Upload MLX artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
overwrite: true
|
||||
name: mac-wheels-${{ matrix.python-version }}
|
||||
path: dist/mlx-*.whl
|
||||
- name: Upload Metal artifacts
|
||||
if: matrix.python-version == '3.10'
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
overwrite: true
|
||||
name: mlx-metal
|
||||
@@ -131,6 +132,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: ['x86_64', 'aarch64']
|
||||
toolkit: ['cuda-12.9', 'cuda-13.0']
|
||||
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
|
||||
env:
|
||||
PYPI_RELEASE: 1
|
||||
@@ -139,13 +141,14 @@ jobs:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: ./.github/actions/setup-linux
|
||||
with:
|
||||
toolkit: 'cuda-12.9'
|
||||
toolkit: ${{ matrix.toolkit }}
|
||||
use-ccache: false
|
||||
- name: Build Python package
|
||||
uses: ./.github/actions/build-cuda-release
|
||||
with:
|
||||
arch: ${{ matrix.arch }}
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
overwrite: true
|
||||
name: mlx-cuda
|
||||
@@ -161,12 +164,12 @@ jobs:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/mlx
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
- uses: actions/download-artifact@v7
|
||||
with:
|
||||
pattern: linux-wheels-*
|
||||
merge-multiple: true
|
||||
path: dist
|
||||
- uses: actions/download-artifact@v6
|
||||
- uses: actions/download-artifact@v7
|
||||
with:
|
||||
pattern: mac-wheels-*
|
||||
merge-multiple: true
|
||||
@@ -188,7 +191,7 @@ jobs:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/mlx-cuda
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
- uses: actions/download-artifact@v7
|
||||
with:
|
||||
name: mlx-cuda
|
||||
path: dist
|
||||
@@ -209,7 +212,7 @@ jobs:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/mlx-cpu
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
- uses: actions/download-artifact@v7
|
||||
with:
|
||||
pattern: mlx-cpu-*
|
||||
merge-multiple: true
|
||||
@@ -231,7 +234,7 @@ jobs:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/mlx-metal
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
- uses: actions/download-artifact@v7
|
||||
with:
|
||||
name: mlx-metal
|
||||
path: dist
|
||||
|
||||
@@ -273,7 +273,7 @@ target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||
message(STATUS "Building Python bindings.")
|
||||
find_package(
|
||||
Python 3.8
|
||||
Python 3.10
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
|
||||
@@ -29,17 +29,20 @@ MLX has a CUDA backend which you can install with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install mlx[cuda]
|
||||
pip install mlx[cuda12]
|
||||
|
||||
|
||||
To install the CUDA package from PyPi your system must meet the following
|
||||
requirements:
|
||||
|
||||
- Nvidia architecture >= SM 7.0 (Volta)
|
||||
- Nvidia architecture >= SM 7.5
|
||||
- Nvidia driver >= 550.54.14
|
||||
- CUDA toolkit >= 12.0
|
||||
- Linux distribution with glibc >= 2.35
|
||||
- Python >= 3.10
|
||||
|
||||
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
|
||||
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
|
||||
|
||||
CPU-only (Linux)
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
@@ -186,7 +186,7 @@ Boolean masks follow NumPy semantics:
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.arange(1000).reshape(10, 10, 10)
|
||||
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
|
||||
>>> a[mx.random.normal((10, 10)) > 0.0] = 0 # valid: mask covers axes 0 and 1
|
||||
|
||||
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
|
||||
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
|
||||
|
||||
@@ -3,6 +3,6 @@ requires = [
|
||||
"setuptools>=42",
|
||||
"cmake>=3.25",
|
||||
"mlx>=0.18.0",
|
||||
"nanobind==2.4.0",
|
||||
"nanobind==2.10.2",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
setuptools>=42
|
||||
cmake>=3.25
|
||||
mlx>=0.21.0
|
||||
nanobind==2.4.0
|
||||
nanobind==2.10.2
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
|
||||
namespace mlx::core::allocator {
|
||||
|
||||
Buffer malloc(size_t size) {
|
||||
auto buffer = allocator().malloc(size);
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
void free(Buffer buffer) {
|
||||
allocator().free(buffer);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
@@ -28,16 +28,16 @@ class Buffer {
|
||||
};
|
||||
};
|
||||
|
||||
Buffer malloc(size_t size);
|
||||
|
||||
void free(Buffer buffer);
|
||||
|
||||
class Allocator {
|
||||
/** Abstract base class for a memory allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size) = 0;
|
||||
virtual void free(Buffer buffer) = 0;
|
||||
virtual size_t size(Buffer buffer) const = 0;
|
||||
virtual Buffer make_buffer(void* ptr, size_t size) {
|
||||
return Buffer{nullptr};
|
||||
};
|
||||
virtual void release(Buffer buffer) {}
|
||||
|
||||
Allocator() = default;
|
||||
Allocator(const Allocator& other) = delete;
|
||||
@@ -49,4 +49,25 @@ class Allocator {
|
||||
|
||||
Allocator& allocator();
|
||||
|
||||
inline Buffer malloc(size_t size) {
|
||||
return allocator().malloc(size);
|
||||
}
|
||||
|
||||
inline void free(Buffer buffer) {
|
||||
allocator().free(buffer);
|
||||
}
|
||||
|
||||
// Make a Buffer from a raw pointer of the given size without a copy. If a
|
||||
// no-copy conversion is not possible then the returned buffer.ptr() will be
|
||||
// nullptr. Any buffer created with this function must be released with
|
||||
// release(buffer)
|
||||
inline Buffer make_buffer(void* ptr, size_t size) {
|
||||
return allocator().make_buffer(ptr, size);
|
||||
};
|
||||
|
||||
// Release a buffer from the allocator made with make_buffer
|
||||
inline void release(Buffer buffer) {
|
||||
allocator().release(buffer);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
|
||||
@@ -82,6 +82,28 @@ array::array(std::initializer_list<int> data, Dtype dtype)
|
||||
init(data.begin());
|
||||
}
|
||||
|
||||
array::array(
|
||||
void* data,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
const std::function<void(void*)>& deleter)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
auto buffer = allocator::make_buffer(data, nbytes());
|
||||
if (buffer.ptr() == nullptr) {
|
||||
set_data(allocator::malloc(nbytes()));
|
||||
auto ptr = static_cast<char*>(data);
|
||||
std::copy(ptr, ptr + nbytes(), this->data<char>());
|
||||
deleter(data);
|
||||
} else {
|
||||
auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
|
||||
auto ptr = buffer.ptr();
|
||||
allocator::release(buffer);
|
||||
return deleter(ptr);
|
||||
};
|
||||
set_data(buffer, std::move(wrapped_deleter));
|
||||
}
|
||||
}
|
||||
|
||||
/* Build an array from a shared buffer */
|
||||
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
|
||||
10
mlx/array.h
10
mlx/array.h
@@ -57,6 +57,16 @@ class array {
|
||||
Shape shape,
|
||||
Dtype dtype = TypeToDtype<T>());
|
||||
|
||||
/* Build an array from a raw pointer. The constructor will attempt to use the
|
||||
* input data without a copy. The deleter will be called when the array no
|
||||
* longer needs the underlying memory - after the array is destroyed in the
|
||||
* no-copy case and after the copy otherwise. */
|
||||
explicit array(
|
||||
void* data,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
const std::function<void(void*)>& deleter);
|
||||
|
||||
/* Build an array from a buffer */
|
||||
explicit array(
|
||||
allocator::Buffer data,
|
||||
|
||||
@@ -130,7 +130,7 @@ void compiled_allocate_outputs(
|
||||
// - Donatable
|
||||
// - Not a constant
|
||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||
in.is_donatable() && is_constant(i)) {
|
||||
in.is_donatable() && !is_constant(i)) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
@@ -158,7 +158,7 @@ void compiled_allocate_outputs(
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||
is_constant(i)) {
|
||||
!is_constant(i)) {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
|
||||
@@ -291,6 +291,17 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
num_keys,
|
||||
kshape = keys.shape(),
|
||||
kstrides = keys.strides()]() mutable {
|
||||
auto copy_remaining = [&](char* cptr, size_t loc, uint32_t v) {
|
||||
if (4 * loc + 4 <= bytes_per_key) {
|
||||
reinterpret_cast<uint32_t*>(cptr)[loc] = v;
|
||||
} else {
|
||||
std::copy(
|
||||
reinterpret_cast<char*>(&v),
|
||||
reinterpret_cast<char*>(&v) + bytes_per_key - 4 * loc,
|
||||
cptr + 4 * loc);
|
||||
}
|
||||
};
|
||||
|
||||
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
|
||||
auto half_size = out_skip / 2;
|
||||
bool even = out_skip % 2 == 0;
|
||||
@@ -310,18 +321,12 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (count.first < half_size) {
|
||||
auto rb = random::threefry2x32_hash(key, count);
|
||||
ptr[count.first++] = rb.first;
|
||||
if (bytes_per_key % 4 > 0) {
|
||||
std::copy(
|
||||
reinterpret_cast<char*>(&rb.second),
|
||||
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
|
||||
cptr + 4 * count.second);
|
||||
} else {
|
||||
ptr[count.second] = rb.second;
|
||||
}
|
||||
copy_remaining(cptr, count.second, rb.second);
|
||||
}
|
||||
if (!even) {
|
||||
count.second = 0;
|
||||
ptr[half_size] = random::threefry2x32_hash(key, count).first;
|
||||
copy_remaining(
|
||||
cptr, half_size, random::threefry2x32_hash(key, count).first);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -3,5 +3,9 @@
|
||||
#include "mlx/backend/cpu/simd/base_simd.h"
|
||||
|
||||
#ifdef MLX_USE_ACCELERATE
|
||||
#if defined(__x86_64__)
|
||||
// the accelerate_simd implementation require neon -- use base implementation
|
||||
#else
|
||||
#include "mlx/backend/cpu/simd/accelerate_simd.h"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -20,6 +20,19 @@ constexpr int page_size = 16384;
|
||||
// Any allocations smaller than this will try to use the small pool
|
||||
constexpr int small_block_size = 8;
|
||||
|
||||
#if CUDART_VERSION >= 13000
|
||||
inline cudaMemLocation cuda_mem_loc(int i) {
|
||||
cudaMemLocation loc;
|
||||
loc.type = cudaMemLocationTypeDevice;
|
||||
loc.id = i;
|
||||
return loc;
|
||||
}
|
||||
#else
|
||||
inline int cuda_mem_loc(int i) {
|
||||
return i;
|
||||
}
|
||||
#endif // CUDART_VERSION >= 13000
|
||||
|
||||
// The small pool size in bytes. This should be a multiple of the host page
|
||||
// size and small_block_size.
|
||||
constexpr int small_pool_size = 4 * page_size;
|
||||
@@ -35,13 +48,7 @@ SmallSizePool::SmallSizePool() {
|
||||
int device_count = 0;
|
||||
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
|
||||
for (int i = 0; i < device_count; ++i) {
|
||||
#if CUDART_VERSION >= 13000
|
||||
cudaMemLocation loc;
|
||||
loc.type = cudaMemLocationTypeDevice;
|
||||
loc.id = i;
|
||||
#else
|
||||
int loc = i;
|
||||
#endif // CUDART_VERSION >= 13000
|
||||
auto loc = cuda_mem_loc(i);
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
|
||||
}
|
||||
@@ -90,9 +97,10 @@ CudaAllocator::CudaAllocator()
|
||||
page_size,
|
||||
[](CudaBuffer* buf) { return buf->size; },
|
||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.9;
|
||||
size_t free;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
|
||||
memory_limit_ = total_memory_ * 0.95;
|
||||
free_limit_ = total_memory_ - memory_limit_;
|
||||
max_pool_size_ = memory_limit_;
|
||||
|
||||
int device_count = 0;
|
||||
@@ -104,6 +112,10 @@ CudaAllocator::CudaAllocator()
|
||||
cudaStream_t s;
|
||||
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
|
||||
free_streams_.push_back(s);
|
||||
|
||||
cudaMemPool_t mem_pool;
|
||||
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pool, i));
|
||||
mem_pools_.push_back(mem_pool);
|
||||
}
|
||||
CHECK_CUDA_ERROR(cudaSetDevice(curr));
|
||||
}
|
||||
@@ -154,23 +166,35 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
|
||||
}
|
||||
lock.unlock();
|
||||
if (!buf) {
|
||||
cudaError_t err;
|
||||
void* data = nullptr;
|
||||
if (device == -1) {
|
||||
err = cudaMallocManaged(&data, size);
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
|
||||
} else {
|
||||
err = cudaMallocAsync(&data, size, stream);
|
||||
}
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
|
||||
}
|
||||
if (!data) {
|
||||
return Buffer{nullptr};
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
buf = new CudaBuffer{data, size, device};
|
||||
}
|
||||
lock.lock();
|
||||
|
||||
// If any cuda memory pool has too much reserved memory, clear some
|
||||
// memory from the cache. This prevents graph / kernel execution failing
|
||||
// from OOM
|
||||
if (get_cache_memory() > 0) {
|
||||
for (auto p : mem_pools_) {
|
||||
size_t used = 0;
|
||||
CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
|
||||
p, cudaMemPoolAttrReservedMemCurrent, &used));
|
||||
if (used > (total_memory_ - free_limit_)) {
|
||||
buffer_cache_.release_cached_buffers(free_limit_);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
active_memory_ += buf->size;
|
||||
peak_memory_ = std::max(active_memory_, peak_memory_);
|
||||
|
||||
@@ -71,11 +71,14 @@ class CudaAllocator : public allocator::Allocator {
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t memory_limit_;
|
||||
size_t free_limit_;
|
||||
size_t total_memory_;
|
||||
size_t max_pool_size_;
|
||||
BufferCache<CudaBuffer> buffer_cache_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
std::vector<cudaStream_t> free_streams_;
|
||||
std::vector<cudaMemPool_t> mem_pools_;
|
||||
SmallSizePool scalar_pool_;
|
||||
};
|
||||
|
||||
|
||||
@@ -95,11 +95,14 @@ void copy_general_input(
|
||||
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
|
||||
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
|
||||
int ndim = shape.size();
|
||||
int work_per_thread = 1;
|
||||
|
||||
int work_per_thread = 8;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
auto rest = out.size() / dim0;
|
||||
if (dim0 >= 4) {
|
||||
if (dim0 >= 4 && dim0 < 8) {
|
||||
work_per_thread = 4;
|
||||
} else if (dim0 < 4) {
|
||||
work_per_thread = 1;
|
||||
}
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||
@@ -110,7 +113,10 @@ void copy_general_input(
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
|
||||
if (work_per_thread == 4) {
|
||||
if (work_per_thread == 8) {
|
||||
kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
|
||||
} else if (work_per_thread == 4) {
|
||||
kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
|
||||
}
|
||||
@@ -127,7 +133,9 @@ void copy_general_input(
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
|
||||
if (work_per_thread == 4) {
|
||||
if (work_per_thread == 8) {
|
||||
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
|
||||
} else if (work_per_thread == 4) {
|
||||
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
|
||||
@@ -318,46 +318,67 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
||||
insert_graph_dependencies(GraphNode{node, "K"});
|
||||
}
|
||||
|
||||
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
|
||||
// CUDA graphs do not get updated correctly if a kernel node getting updated
|
||||
// has a different cluster shape than the node it's being updated with.
|
||||
std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
|
||||
// Constructs a key representing the nodes of a sub-graph.
|
||||
// Also checks if the sub-graph is updatable as CUDA graphs do not get
|
||||
// updated correctly if a kernel node getting updated has a different cluster
|
||||
// shape than the node it's being updated with.
|
||||
std::string key = "(";
|
||||
size_t num_nodes = 0;
|
||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
|
||||
if (num_nodes == 0) {
|
||||
return true;
|
||||
return {key + ")", true};
|
||||
}
|
||||
|
||||
bool is_updatable = true;
|
||||
std::vector<cudaGraphNode_t> nodes(num_nodes);
|
||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
|
||||
for (const auto& node : nodes) {
|
||||
if (!is_updatable) {
|
||||
break;
|
||||
}
|
||||
cudaGraphNodeType type;
|
||||
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
||||
if (type == cudaGraphNodeTypeGraph) {
|
||||
// Try to be updatable for a structure like graph -> graph -> kernel
|
||||
if (num_nodes > 1) {
|
||||
return false;
|
||||
switch (type) {
|
||||
case cudaGraphNodeTypeGraph: {
|
||||
// Try to be updatable for a structure like graph -> graph -> kernel
|
||||
cudaGraph_t child;
|
||||
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
||||
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
|
||||
is_updatable &= sub_is_updatable;
|
||||
key += subkey;
|
||||
break;
|
||||
}
|
||||
cudaGraph_t child;
|
||||
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
||||
return is_graph_updatable(child, cluster_dim_x);
|
||||
} else if (type != cudaGraphNodeTypeKernel) {
|
||||
return false;
|
||||
} else {
|
||||
cudaLaunchAttributeValue cluster_dim;
|
||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||
// Only dim.x can be greater than 1
|
||||
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||
return false;
|
||||
case cudaGraphNodeTypeHost:
|
||||
key += "H";
|
||||
break;
|
||||
case cudaGraphNodeTypeMemset:
|
||||
key += "M";
|
||||
break;
|
||||
case cudaGraphNodeTypeKernel: {
|
||||
cudaLaunchAttributeValue cluster_dim;
|
||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||
// Only allow dim.x to be greater than 1
|
||||
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||
is_updatable = false;
|
||||
} else {
|
||||
key += "K";
|
||||
key += std::to_string(cluster_dim.clusterDim.x);
|
||||
}
|
||||
break;
|
||||
}
|
||||
// Only one child node allowed when subgraph uses clusters
|
||||
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
|
||||
return false;
|
||||
}
|
||||
cluster_dim_x = cluster_dim.clusterDim.x;
|
||||
case cudaGraphNodeTypeWaitEvent:
|
||||
key += "W";
|
||||
break;
|
||||
case cudaGraphNodeTypeEventRecord:
|
||||
key += "R";
|
||||
break;
|
||||
default:
|
||||
is_updatable = false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
key += ")";
|
||||
return {key, is_updatable};
|
||||
}
|
||||
|
||||
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
@@ -370,11 +391,10 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
return;
|
||||
}
|
||||
cudaGraphNode_t node;
|
||||
int cluster_dim_x = 0;
|
||||
is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
|
||||
auto [sub_graph_key, is_updatable] = subgraph_to_key(child);
|
||||
is_graph_updatable_ &= is_updatable;
|
||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||
insert_graph_dependencies(
|
||||
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
|
||||
insert_graph_dependencies(GraphNode{node, sub_graph_key});
|
||||
}
|
||||
|
||||
bool CommandEncoder::needs_commit() {
|
||||
|
||||
@@ -106,7 +106,7 @@ class CommandEncoder {
|
||||
cudaGraphNode_t node;
|
||||
// K = kernel
|
||||
// E = empty
|
||||
// G* = subgraph (with metadata)
|
||||
// () = subgraph (with metadata)
|
||||
// Symbols ':', '-' are reserved as separators
|
||||
std::string node_type;
|
||||
std::string id;
|
||||
|
||||
@@ -2,7 +2,11 @@
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh"
|
||||
#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh"
|
||||
#include "mlx/backend/cuda/quantized/quantized.h"
|
||||
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
|
||||
#include "mlx/backend/cuda/vector_types.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
@@ -13,17 +17,6 @@
|
||||
namespace mlx::core {
|
||||
namespace cu {
|
||||
|
||||
template <int bits>
|
||||
struct Quantize {
|
||||
__device__ uint8_t operator()(float x) {
|
||||
if constexpr (bits == 8) {
|
||||
return __nv_fp8_e4m3(x).__x;
|
||||
} else {
|
||||
return __nv_fp4_e2m1(x).__x;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int bits>
|
||||
struct Dequantize {
|
||||
__device__ float operator()(uint8_t x) {
|
||||
@@ -37,29 +30,40 @@ struct Dequantize {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, int group_size, int bits, bool use_mx_scale>
|
||||
__global__ void
|
||||
fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
|
||||
template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
|
||||
__global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) {
|
||||
using Tx2 = Vector2_t<T>;
|
||||
using Tx4 = Vector4_t<T>;
|
||||
uint32_t rbits = 0; // reserved bits for future use
|
||||
auto block_size = cg::this_thread_block().dim_threads();
|
||||
auto block_idx = cg::this_thread_block().group_index();
|
||||
auto idx_in_block = cg::this_thread_block().thread_index();
|
||||
|
||||
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
|
||||
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
|
||||
auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
|
||||
|
||||
auto grid_dim_x =
|
||||
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
|
||||
size_t index = tidx + grid_dim_x * size_t(tidy);
|
||||
if (index >= size) {
|
||||
size_t thread_idx = tidx + grid_dim_x * size_t(tidy);
|
||||
size_t base_idx = thread_idx * group_size;
|
||||
|
||||
if (base_idx >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
float w_thread = w[index];
|
||||
auto w_tile = load_vector<group_size, T>(w, thread_idx);
|
||||
float scale = 0.0f;
|
||||
|
||||
cg::greater<float> max_op;
|
||||
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
|
||||
Tx2 amax_2x = Tx2{0.0f, 0.0f};
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < group_size; i += 2) {
|
||||
auto pair = Tx2{w_tile[i], w_tile[i + 1]};
|
||||
abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
|
||||
}
|
||||
|
||||
scale = static_cast<float>(
|
||||
max(fabsf(static_cast<float>(amax_2x.x)),
|
||||
fabsf(static_cast<float>(amax_2x.y))));
|
||||
|
||||
float scale = cg::reduce(warp, abs(w_thread), max_op);
|
||||
scale /= bits == 4 ? 6.0f : 448.0f;
|
||||
// Convert to mx scale or nv scale
|
||||
using ScaleType =
|
||||
@@ -68,21 +72,24 @@ fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
|
||||
uint8_t q_scale = s.__x;
|
||||
scale = float(s);
|
||||
|
||||
// Write out the scales
|
||||
size_t gindex = index / group_size;
|
||||
if (index % group_size == 0) {
|
||||
scales[gindex] = q_scale;
|
||||
}
|
||||
scales[thread_idx] = q_scale;
|
||||
constexpr int elem_per_byte = bits == 8 ? 1 : 2;
|
||||
AlignedVector<uint8_t, group_size / elem_per_byte> quantized;
|
||||
|
||||
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
|
||||
if (bits == 4) {
|
||||
uint8_t sval = warp.shfl_down(output, 1);
|
||||
output |= sval << bits;
|
||||
}
|
||||
constexpr int pack_factor = bits == 8 ? 1 : 2;
|
||||
if (index % pack_factor == 0) {
|
||||
out[index / pack_factor] = output;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < group_size / 4; i++) {
|
||||
Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);
|
||||
if constexpr (bits == 8) {
|
||||
uint32_t quantized_val =
|
||||
scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
|
||||
*reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;
|
||||
} else {
|
||||
uint16_t quantized_val =
|
||||
scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
|
||||
*reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;
|
||||
}
|
||||
}
|
||||
store_vector<group_size / elem_per_byte>(out, thread_idx, quantized);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits, bool use_mx_scale>
|
||||
@@ -142,15 +149,16 @@ void fp_quantize(
|
||||
dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
if constexpr (!std::is_same_v<T, double>) {
|
||||
auto kernel = cu::fp_quantize<T, 32, 4, true>;
|
||||
auto kernel = cu::fp_quantize<T, 32, 4, true, false>;
|
||||
if (bits == 8) {
|
||||
kernel = cu::fp_quantize<T, 32, 8, true>;
|
||||
kernel = cu::fp_quantize<T, 32, 8, true, false>;
|
||||
} else if (group_size == 16) {
|
||||
kernel = cu::fp_quantize<T, 16, 4, false>;
|
||||
kernel = cu::fp_quantize<T, 16, 4, false, false>;
|
||||
}
|
||||
bool large = w.size() > UINT_MAX;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(w.size(), w.shape(), w.strides(), large);
|
||||
get_launch_args(w.size(), w.shape(), w.strides(), large, group_size);
|
||||
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
|
||||
32
mlx/backend/cuda/quantized/mxfp8_quantize.cuh
Normal file
32
mlx/backend/cuda/quantized/mxfp8_quantize.cuh
Normal file
@@ -0,0 +1,32 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include "mlx/backend/cuda/vector_types.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// TODO implement fast path
|
||||
template <typename T>
|
||||
__device__ __forceinline__ uint32_t
|
||||
scale_cvt_Tx4_to_fp8x4_fallback(const Vector4_t<T> input, const float scale) {
|
||||
uint32_t out_fp8x4 = 0;
|
||||
float4 scaled;
|
||||
scaled.x = static_cast<float>(input.x) * scale;
|
||||
scaled.y = static_cast<float>(input.y) * scale;
|
||||
scaled.z = static_cast<float>(input.z) * scale;
|
||||
scaled.w = static_cast<float>(input.w) * scale;
|
||||
out_fp8x4 = __nv_fp8x4_e4m3(scaled).__x;
|
||||
return out_fp8x4;
|
||||
}
|
||||
|
||||
// Place holder for future fast path implementation
|
||||
template <typename T, bool USE_SR>
|
||||
__device__ __forceinline__ uint32_t scale_cvt_Tx4_to_fp8x4(
|
||||
const Vector4_t<T> input,
|
||||
const float scale,
|
||||
uint32_t rbits) {
|
||||
return scale_cvt_Tx4_to_fp8x4_fallback(input, scale);
|
||||
}
|
||||
} // namespace mlx::core::cu
|
||||
334
mlx/backend/cuda/quantized/nvfp4_quantize.cuh
Normal file
334
mlx/backend/cuda/quantized/nvfp4_quantize.cuh
Normal file
@@ -0,0 +1,334 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp4.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include "mlx/backend/cuda/vector_types.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
using bf16x4 = Vector4_t<__nv_bfloat16>;
|
||||
using fp16x4 = Vector4_t<__half>;
|
||||
using f32x4 = Vector4_t<float>;
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ uint16_t
|
||||
scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t<T> input, const float scale) {
|
||||
// Fallback implementation for architectures that do not support cvt
|
||||
// instructions or for cuda versions with no fp4 support (< 12.8) -> scalar
|
||||
uint16_t out_fp4x4 = 0;
|
||||
fp32x4 scaled;
|
||||
scaled.x = static_cast<float>(input.x) * scale;
|
||||
scaled.y = static_cast<float>(input.y) * scale;
|
||||
scaled.z = static_cast<float>(input.z) * scale;
|
||||
scaled.w = static_cast<float>(input.w) * scale;
|
||||
uint8_t q0 = __nv_fp4_e2m1(scaled.x).__x;
|
||||
uint8_t q1 = __nv_fp4_e2m1(scaled.y).__x;
|
||||
uint8_t q2 = __nv_fp4_e2m1(scaled.z).__x;
|
||||
uint8_t q3 = __nv_fp4_e2m1(scaled.w).__x;
|
||||
out_fp4x4 = (static_cast<uint16_t>(q3) << 12) |
|
||||
(static_cast<uint16_t>(q2) << 8) | (static_cast<uint16_t>(q1) << 4) |
|
||||
static_cast<uint16_t>(q0);
|
||||
return out_fp4x4;
|
||||
}
|
||||
|
||||
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
|
||||
defined(__CUDA_ARCH_SPECIFIC__)
|
||||
|
||||
// Converts four bf16 values to four packed e2m1 values using inline PTX:
// unpack -> widen to f32 -> multiply by `scale` two lanes at a time ->
// cvt.rn.satfinite to fp4 pairs -> pack both bytes into the 16-bit result.
__device__ __forceinline__ uint16_t
scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {
  // An asm operand must have scalar type, so the 4 x bf16 vector and the
  // float2 scale are each viewed as one 64-bit word.
  const uint64_t& input_bits = reinterpret_cast<const uint64_t&>(input_bf16x4);
  const uint64_t& scale_bits = reinterpret_cast<const uint64_t&>(scale);

  uint16_t packed = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_bf16; \n\t" // first bf16
      ".reg.b16 x1_bf16; \n\t" // second bf16
      ".reg.b16 x2_bf16; \n\t" // third bf16
      ".reg.b16 x3_bf16; \n\t" // fourth bf16
      ".reg.b32 x0; \n\t" // scaled first value
      ".reg.b32 x1; \n\t" // scaled second value
      ".reg.b32 x2; \n\t" // scaled third value
      ".reg.b32 x3; \n\t" // scaled fourth value
      ".reg.b64 x01; \n\t" // operands of the paired multiplies
      ".reg.b64 x23; \n\t"
      ".reg.b8 q0; \n\t" // output byte fp4x2 (first pair)
      ".reg.b8 q1; \n\t" // output byte fp4x2 (second pair)
      "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" // unpack bf16
      "cvt.f32.bf16 x0, x0_bf16; \n\t" // widen to f32
      "cvt.f32.bf16 x1, x1_bf16; \n\t"
      "cvt.f32.bf16 x2, x2_bf16; \n\t"
      "cvt.f32.bf16 x3, x3_bf16; \n\t"
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t" // scale first pair
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t" // scale second pair
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // first pair to fp4x2
      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // second pair to fp4x2
      "mov.b16 %0, {q0, q1}; \n\t" // pack to output
      "}"
      : "=h"(packed)
      : "l"(input_bits), "l"(scale_bits));
  return packed;
}
|
||||
|
||||
// Stochastic-rounding variant of the bf16x4 -> fp4x4 conversion. The four
// bf16 lanes are widened to f32, multiplied by `scale` two lanes at a time and
// converted with a single cvt.rs.satfinite.e2m1x4 that consumes `rbits`.
//
// Fix: the conversion result was left in the PTX-local register q0 and never
// copied to output operand %0, so the returned value was undefined. The final
// mov.b16 now writes it to the output.
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(
    const bf16x4 input_bf16x4,
    const float2 scale,
    uint32_t rbits) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_bf16; \n\t"
      ".reg.b16 x1_bf16; \n\t"
      ".reg.b16 x2_bf16; \n\t"
      ".reg.b16 x3_bf16; \n\t"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b16 q0; \n\t"
      "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" // unpack bf16
      "cvt.f32.bf16 x0, x0_bf16; \n\t" // widen to f32
      "cvt.f32.bf16 x1, x1_bf16; \n\t"
      "cvt.f32.bf16 x2, x2_bf16; \n\t"
      "cvt.f32.bf16 x3, x3_bf16; \n\t"
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t" // scale first pair
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t" // scale second pair
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
      "mov.b16 %0, q0; \n\t" // publish the packed result to the output operand
      "}"
      : "=h"(out_fp4x4)
      // The casts are needed because an asm operand must have scalar type.
      : "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
        "l"(reinterpret_cast<const uint64_t&>(scale)),
        "r"(rbits));
  return out_fp4x4;
}
|
||||
|
||||
// Converts four f32 values (given as two float2 halves) to four packed e2m1
// values: multiply by `scale` two lanes at a time, cvt.rn.satfinite each pair
// to one fp4x2 byte, then pack both bytes into the 16-bit result.
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(
    const float2 input_fp32x2_0,
    const float2 input_fp32x2_1,
    const float2 scale) {
  // An asm operand must have scalar type, so the float2 scale is viewed as a
  // single 64-bit word.
  const uint64_t& scale_bits = reinterpret_cast<const uint64_t&>(scale);

  uint16_t packed = 0;
  asm volatile(
      "{\n"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b8 q0; \n\t"
      ".reg.b8 q1; \n\t"
      "mov.b64 x01, {%1, %2}; \n\t" // pair up the first two inputs
      "mul.f32x2 x01, x01, %5; \n\t" // scale first pair
      "mov.b64 x23, {%3, %4}; \n\t" // pair up the last two inputs
      "mul.f32x2 x23, x23, %5; \n\t" // scale second pair
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // first pair to fp4x2
      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // second pair to fp4x2
      "mov.b16 %0, {q0, q1}; \n\t" // pack to output
      "}"
      : "=h"(packed)
      : "f"(input_fp32x2_0.x),
        "f"(input_fp32x2_0.y),
        "f"(input_fp32x2_1.x),
        "f"(input_fp32x2_1.y),
        "l"(scale_bits));
  return packed;
}
|
||||
|
||||
// Stochastic-rounding variant of the f32x4 -> fp4x4 conversion. The inputs are
// multiplied by `scale` two lanes at a time and converted with a single
// cvt.rs.satfinite.e2m1x4 that consumes `rbits`.
//
// Fix: the conversion result was left in the PTX-local register q0 and never
// copied to output operand %0, so the returned value was undefined. The final
// mov.b16 now writes it to the output.
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(
    const float2 input_fp32x2_0,
    const float2 input_fp32x2_1,
    const float2 scale,
    uint32_t rbits) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b16 q0; \n\t"
      "mov.b64 x01, {%1, %2}; \n\t" // pair up the first two inputs
      "mul.f32x2 x01, x01, %5; \n\t" // scale first pair
      "mov.b64 x23, {%3, %4}; \n\t" // pair up the last two inputs
      "mul.f32x2 x23, x23, %5; \n\t" // scale second pair
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t"
      "mov.b16 %0, q0; \n\t" // publish the packed result to the output operand
      "}"
      : "=h"(out_fp4x4)
      : "f"(input_fp32x2_0.x),
        "f"(input_fp32x2_0.y),
        "f"(input_fp32x2_1.x),
        "f"(input_fp32x2_1.y),
        // The cast is needed because an asm operand must have scalar type.
        "l"(reinterpret_cast<const uint64_t&>(scale)),
        "r"(rbits));
  return out_fp4x4;
}
|
||||
|
||||
// Converts four fp16 values to four packed e2m1 values using inline PTX:
// unpack -> widen to f32 -> multiply by `scale` two lanes at a time ->
// cvt.rn.satfinite to fp4 pairs -> pack both bytes into the 16-bit result.
__device__ __forceinline__ uint16_t
scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {
  // An asm operand must have scalar type, so the 4 x fp16 vector and the
  // float2 scale are each viewed as one 64-bit word.
  const uint64_t& input_bits = reinterpret_cast<const uint64_t&>(input_fp16x4);
  const uint64_t& scale_bits = reinterpret_cast<const uint64_t&>(scale);

  uint16_t packed = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_fp16; \n\t"
      ".reg.b16 x1_fp16; \n\t"
      ".reg.b16 x2_fp16; \n\t"
      ".reg.b16 x3_fp16; \n\t"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b8 q0; \n\t"
      ".reg.b8 q1; \n\t"
      "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t" // unpack fp16
      "cvt.f32.f16 x0, x0_fp16; \n\t" // widen to f32
      "cvt.f32.f16 x1, x1_fp16; \n\t"
      "cvt.f32.f16 x2, x2_fp16; \n\t"
      "cvt.f32.f16 x3, x3_fp16; \n\t"
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t" // scale first pair
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t" // scale second pair
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // first pair to fp4x2
      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // second pair to fp4x2
      "mov.b16 %0, {q0, q1}; \n\t" // pack to output
      "}"
      : "=h"(packed)
      : "l"(input_bits), "l"(scale_bits));
  return packed;
}
|
||||
|
||||
// Stochastic-rounding variant of the fp16x4 -> fp4x4 conversion. The four
// fp16 lanes are widened to f32, multiplied by `scale` two lanes at a time and
// converted with a single cvt.rs.satfinite.e2m1x4 that consumes `rbits`.
//
// Fix: the conversion result was left in the PTX-local register q0 and never
// copied to output operand %0, so the returned value was undefined. The final
// mov.b16 now writes it to the output.
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(
    const fp16x4 input_fp16x4,
    const float2 scale,
    uint32_t rbits) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_fp16; \n\t"
      ".reg.b16 x1_fp16; \n\t"
      ".reg.b16 x2_fp16; \n\t"
      ".reg.b16 x3_fp16; \n\t"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b16 q0; \n\t"
      "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t" // unpack fp16
      "cvt.f32.f16 x0, x0_fp16; \n\t" // widen to f32
      "cvt.f32.f16 x1, x1_fp16; \n\t"
      "cvt.f32.f16 x2, x2_fp16; \n\t"
      "cvt.f32.f16 x3, x3_fp16; \n\t"
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t" // scale first pair
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t" // scale second pair
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
      "mov.b16 %0, q0; \n\t" // publish the packed result to the output operand
      "}"
      : "=h"(out_fp4x4)
      // The casts are needed because an asm operand must have scalar type.
      : "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
        "l"(reinterpret_cast<const uint64_t&>(scale)),
        "r"(rbits));
  return out_fp4x4;
}
|
||||
|
||||
template <bool USE_SR>
|
||||
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(
|
||||
const bf16x4 input,
|
||||
const float scale,
|
||||
uint32_t rbits) {
|
||||
float2 scale_fp32x2 = make_float2(scale, scale);
|
||||
if constexpr (USE_SR) {
|
||||
return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
|
||||
} else {
|
||||
return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool USE_SR>
|
||||
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(
|
||||
const fp16x4 input,
|
||||
const float scale,
|
||||
uint32_t rbits) {
|
||||
float2 scale_fp32x2 = make_float2(scale, scale);
|
||||
if constexpr (USE_SR) {
|
||||
return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
|
||||
} else {
|
||||
return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool USE_SR>
|
||||
__device__ __forceinline__ uint16_t
|
||||
scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {
|
||||
float2 scale_fp32x2 = make_float2(scale, scale);
|
||||
float2 input_fp32x2_0 = make_float2(input.x, input.y);
|
||||
float2 input_fp32x2_1 = make_float2(input.z, input.w);
|
||||
|
||||
if constexpr (USE_SR) {
|
||||
return scale_cvt_fp32x4_to_fp4x4_rs(
|
||||
input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);
|
||||
} else {
|
||||
return scale_cvt_fp32x4_to_fp4x4_rn(
|
||||
input_fp32x2_0, input_fp32x2_1, scale_fp32x2);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool USE_SR>
|
||||
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(
|
||||
const Vector4_t<T> input,
|
||||
const float scale,
|
||||
uint32_t rbits) {
|
||||
if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
return scale_cvt_bf16x4_to_fp4x4<USE_SR>(input, scale, rbits);
|
||||
} else if constexpr (std::is_same<T, __half>::value) {
|
||||
return scale_cvt_fp16x4_to_fp4x4<USE_SR>(input, scale, rbits);
|
||||
} else {
|
||||
return scale_cvt_f32x4_to_fp4x4<USE_SR>(input, scale, rbits);
|
||||
}
|
||||
}
|
||||
#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
|
||||
// defined(__CUDA_ARCH_SPECIFIC__)
|
||||
|
||||
template <typename T, bool USE_SR>
|
||||
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(
|
||||
const Vector4_t<T> input,
|
||||
const float scale,
|
||||
uint32_t rbits) {
|
||||
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
|
||||
(__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
|
||||
return scale_cvt_Tx4_to_fp4x4_fast<T, USE_SR>(input, scale, rbits);
|
||||
#else
|
||||
static_assert(
|
||||
!USE_SR,
|
||||
"Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.");
|
||||
return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);
|
||||
#endif
|
||||
}
|
||||
} // namespace mlx::core::cu
|
||||
@@ -15,6 +15,22 @@ inline constexpr __device__ short get_bytes_per_pack() {
|
||||
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) {
|
||||
if constexpr (
|
||||
(std::is_same<T, __nv_bfloat162>::value) ||
|
||||
(std::is_same<T, __half2>::value)) {
|
||||
T a = x1;
|
||||
T b = x2;
|
||||
out = __hmax2(__habs2(a), __habs2(b));
|
||||
} else if constexpr (std::is_same<T, float2>::value) {
|
||||
float2 a = x1;
|
||||
float2 b = x2;
|
||||
out.x = fmaxf(fabsf(a.x), fabsf(b.x));
|
||||
out.y = fmaxf(fabsf(a.y), fabsf(b.y));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename F>
|
||||
|
||||
@@ -3,31 +3,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/steel/utils.cuh"
|
||||
#include "mlx/backend/cuda/vector_types.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Map types to their vector of 2 type float -> float2, double -> double2 etc
|
||||
template <typename T>
|
||||
struct Vector2;
|
||||
template <>
|
||||
struct Vector2<double> {
|
||||
using type = double2;
|
||||
};
|
||||
template <>
|
||||
struct Vector2<float> {
|
||||
using type = float2;
|
||||
};
|
||||
template <>
|
||||
struct Vector2<__half> {
|
||||
using type = __half2;
|
||||
};
|
||||
template <>
|
||||
struct Vector2<__nv_bfloat16> {
|
||||
using type = __nv_bfloat162;
|
||||
};
|
||||
template <typename T>
|
||||
using Vector2_t = typename Vector2<T>::type;
|
||||
|
||||
/**
|
||||
* The basic building block for Ampere mmas. A 16x16 tile distributed across
|
||||
* the warp.
|
||||
|
||||
@@ -80,7 +80,6 @@ CudaGraph::CudaGraph(cu::Device& device) {
|
||||
}
|
||||
|
||||
void CudaGraph::end_capture(cudaStream_t stream) {
|
||||
assert(handle_ == nullptr);
|
||||
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
|
||||
}
|
||||
|
||||
|
||||
48
mlx/backend/cuda/vector_types.cuh
Normal file
48
mlx/backend/cuda/vector_types.cuh
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Map types to their vector of 2 type float -> float2, double -> double2 etc.
// The primary template is only declared; element types without a
// specialization below cannot be used with Vector2_t.
template <typename T>
|
||||
struct Vector2;
|
||||
|
||||
template <>
|
||||
struct Vector2<double> {
|
||||
using type = double2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Vector2<float> {
|
||||
using type = float2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Vector2<__half> {
|
||||
using type = __half2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Vector2<__nv_bfloat16> {
|
||||
using type = __nv_bfloat162;
|
||||
};
|
||||
|
||||
// Shorthand for the mapped 2-wide vector type of T.
template <typename T>
|
||||
using Vector2_t = typename Vector2<T>::type;
|
||||
|
||||
// Generic 4-wide vector: a plain aggregate of four T members named x, y, z, w.
template <typename T>
|
||||
struct Vector4 {
|
||||
T x, y, z, w;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using Vector4_t = Vector4<T>;
|
||||
|
||||
using bf16x4 = Vector4_t<__nv_bfloat16>;
|
||||
using fp16x4 = Vector4_t<__half>;
|
||||
using fp32x4 = Vector4_t<float>;
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
@@ -7,8 +7,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
@@ -149,7 +149,9 @@ Buffer MetalAllocator::malloc(size_t size) {
|
||||
buf = device_->newBuffer(size, resource_options);
|
||||
}
|
||||
if (!buf) {
|
||||
return Buffer{nullptr};
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
lk.lock();
|
||||
num_resources_++;
|
||||
@@ -201,6 +203,32 @@ size_t MetalAllocator::size(Buffer buffer) const {
|
||||
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
|
||||
}
|
||||
|
||||
// Asks the device for a buffer over caller-provided memory (`ptr`, `size`)
// and registers it with the allocator's bookkeeping. Returns a null Buffer
// when the device does not hand back a buffer.
Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
  auto mtl_buf = device_->newBuffer(ptr, size, resource_options, nullptr);
  if (mtl_buf == nullptr) {
    return Buffer{nullptr};
  }

  // All accounting below is shared state guarded by mutex_.
  std::unique_lock guard(mutex_);
  residency_set_.insert(mtl_buf);
  active_memory_ += mtl_buf->length();
  peak_memory_ = std::max(peak_memory_, active_memory_);
  ++num_resources_;
  return Buffer{static_cast<void*>(mtl_buf)};
}
|
||||
|
||||
// Drops the allocator's accounting for `buffer` and releases the underlying
// MTL::Buffer. A null buffer is ignored.
void MetalAllocator::release(Buffer buffer) {
  auto mtl_buf = static_cast<MTL::Buffer*>(buffer.ptr());
  if (mtl_buf == nullptr) {
    return;
  }

  {
    // Hold mutex_ only while the counters are updated; the release call
    // below runs without the lock.
    std::unique_lock guard(mutex_);
    active_memory_ -= mtl_buf->length();
    --num_resources_;
  }

  auto pool = metal::new_scoped_memory_pool();
  mtl_buf->release();
}
|
||||
|
||||
MetalAllocator& allocator() {
|
||||
// By creating the |allocator_| on heap, the destructor of MetalAllocator
|
||||
// will not be called on exit and buffers in the cache will be leaked. This
|
||||
|
||||
@@ -21,6 +21,9 @@ class MetalAllocator : public allocator::Allocator {
|
||||
virtual Buffer malloc(size_t size) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
virtual Buffer make_buffer(void* ptr, size_t size) override;
|
||||
virtual void release(Buffer buffer) override;
|
||||
|
||||
size_t get_active_memory() {
|
||||
return active_memory_;
|
||||
};
|
||||
|
||||
@@ -347,7 +347,7 @@ template <
|
||||
MMAFrag_mask_t::load_safe(
|
||||
mfrag,
|
||||
mask,
|
||||
int(mask_params->M_strides[2]),
|
||||
int64_t(mask_params->M_strides[2]),
|
||||
Int<1>{},
|
||||
params->qL,
|
||||
params->kL,
|
||||
|
||||
@@ -346,7 +346,7 @@ template <
|
||||
MSubTile mfrag;
|
||||
mfrag.load_safe(
|
||||
mask,
|
||||
int(mask_params->M_strides[2]),
|
||||
int64_t(mask_params->M_strides[2]),
|
||||
Int<1>{},
|
||||
params->qL,
|
||||
params->kL,
|
||||
|
||||
@@ -105,17 +105,20 @@ struct BaseMMAFrag<T, 8, 8> {
|
||||
LimY lim_y,
|
||||
OffX off_x = Int<0>{},
|
||||
OffY off_y = Int<0>{}) {
|
||||
src += off_x * str_x + off_y * str_y;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
||||
dst[i * kElemCols + j] =
|
||||
static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
|
||||
dst[i * kElemCols + j] = static_cast<T>(src[0]);
|
||||
} else {
|
||||
dst[i * kElemCols + j] = T(0);
|
||||
}
|
||||
src += str_y;
|
||||
}
|
||||
src -= kElemCols * str_y;
|
||||
src += str_x;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ class CommonAllocator : public Allocator {
|
||||
virtual Buffer malloc(size_t size) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
|
||||
size_t get_active_memory() const {
|
||||
return active_memory_;
|
||||
};
|
||||
|
||||
@@ -880,6 +880,11 @@ std::vector<array> ScaledDotProductAttention::vjp(
|
||||
|
||||
std::vector<array> returned_vjps;
|
||||
for (int arg : argnums) {
|
||||
if (arg >= 3) {
|
||||
throw std::invalid_argument(
|
||||
"[scale_dot_product_attention] Does not support VJP with respect "
|
||||
" to mask or attention sinks.");
|
||||
}
|
||||
returned_vjps.push_back(std::move(vjps[arg]));
|
||||
}
|
||||
return returned_vjps;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[build-system]
|
||||
requires = [
|
||||
"setuptools>=80",
|
||||
"nanobind==2.4.0",
|
||||
"nanobind==2.10.2",
|
||||
"cmake>=3.25",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
@@ -89,7 +89,8 @@ static PyType_Spec gc_func_spec = {
|
||||
/* .name = */ "mlx.gc_func",
|
||||
/* .basicsize = */ (int)sizeof(gc_func),
|
||||
/* .itemsize = */ 0,
|
||||
/* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL,
|
||||
/* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
|
||||
Py_TPFLAGS_HAVE_VECTORCALL,
|
||||
/* .slots = */ gc_func_slots};
|
||||
|
||||
static PyTypeObject* gc_func_tp = nullptr;
|
||||
|
||||
@@ -16,8 +16,7 @@ struct type_caster<mlx::core::SmallVector<Type, Size, Alloc>> {
|
||||
|
||||
NB_TYPE_CASTER(
|
||||
List,
|
||||
const_name(NB_TYPING_TUPLE "[") + make_caster<Type>::Name +
|
||||
const_name(", ...]"))
|
||||
const_name("tuple[") + make_caster<Type>::Name + const_name(", ...]"))
|
||||
|
||||
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
|
||||
size_t size;
|
||||
|
||||
@@ -124,37 +124,53 @@ auto py_value_and_grad(
|
||||
|
||||
// Collect the arrays
|
||||
std::vector<mx::array> arrays;
|
||||
std::vector<nb::object> array_objects;
|
||||
auto flatten_with_objects = [&arrays, &array_objects](
|
||||
auto tree, bool strict) {
|
||||
tree_visit(tree, [&](nb::handle obj) {
|
||||
if (nb::isinstance<mx::array>(obj)) {
|
||||
arrays.push_back(nb::cast<mx::array>(obj));
|
||||
array_objects.push_back(nb::borrow<nb::object>(obj));
|
||||
} else if (strict) {
|
||||
throw std::invalid_argument(
|
||||
"[tree_flatten] The argument should contain only arrays");
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
std::vector<int> counts(1, 0);
|
||||
std::vector<int> gradient_indices;
|
||||
for (int i = 0, j = 0; i < args.size(); ++i) {
|
||||
bool needs_grad = (j < argnums.size() && argnums[j] == i);
|
||||
auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
|
||||
auto pre_size = arrays.size();
|
||||
flatten_with_objects(args[i], /* strict = */ needs_grad);
|
||||
if (needs_grad) {
|
||||
auto old_size = gradient_indices.size();
|
||||
gradient_indices.resize(old_size + argsi.size());
|
||||
auto delta_size = arrays.size() - pre_size;
|
||||
gradient_indices.resize(old_size + delta_size);
|
||||
std::iota(
|
||||
gradient_indices.begin() + old_size,
|
||||
gradient_indices.end(),
|
||||
arrays.size());
|
||||
pre_size);
|
||||
j++;
|
||||
counts.push_back(argsi.size());
|
||||
counts.push_back(delta_size);
|
||||
}
|
||||
arrays.insert(arrays.end(), argsi.begin(), argsi.end());
|
||||
}
|
||||
for (auto item : kwargs) {
|
||||
bool needs_grad =
|
||||
(argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
|
||||
auto argsk = tree_flatten(item.second, /* strict = */ needs_grad);
|
||||
auto pre_size = arrays.size();
|
||||
flatten_with_objects(item.second, /* strict = */ needs_grad);
|
||||
if (needs_grad) {
|
||||
auto old_size = gradient_indices.size();
|
||||
gradient_indices.resize(old_size + argsk.size());
|
||||
auto delta_size = arrays.size() - pre_size;
|
||||
gradient_indices.resize(old_size + delta_size);
|
||||
std::iota(
|
||||
gradient_indices.begin() + old_size,
|
||||
gradient_indices.end(),
|
||||
arrays.size());
|
||||
counts.push_back(argsk.size());
|
||||
pre_size);
|
||||
counts.push_back(delta_size);
|
||||
}
|
||||
arrays.insert(arrays.end(), argsk.begin(), argsk.end());
|
||||
}
|
||||
std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
|
||||
|
||||
@@ -163,7 +179,7 @@ auto py_value_and_grad(
|
||||
nb::object py_value_out;
|
||||
auto value_and_grads = mx::value_and_grad(
|
||||
[&fun,
|
||||
&arrays,
|
||||
&array_objects,
|
||||
&args,
|
||||
&kwargs,
|
||||
&py_value_out,
|
||||
@@ -183,8 +199,9 @@ auto py_value_and_grad(
|
||||
tree_visit_update(tree, [&](nb::handle node) {
|
||||
auto replace_arr = nb::cast<mx::array>(node);
|
||||
if (replace_arr.id() == a[index].id()) {
|
||||
return nb::cast(arrays[index++]);
|
||||
return array_objects[index++];
|
||||
} else {
|
||||
index++;
|
||||
return nb::cast(replace_arr);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -780,9 +780,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
return arrs[0]
|
||||
|
||||
arrs = [mx.array(1.0)]
|
||||
init_id = id(arrs[0])
|
||||
arr = arrs[0]
|
||||
mx.grad(fun)(arrs)
|
||||
self.assertEqual(init_id, id(arrs[0]))
|
||||
self.assertEqual(id(arr), id(arrs[0]))
|
||||
|
||||
def fun(arrs):
|
||||
arrs[1] = sum(arrs)
|
||||
return arrs[1]
|
||||
|
||||
arrs = [mx.array(1.0), mx.array(1.0), mx.array(1.0)]
|
||||
a_0, a_1, a_2 = arrs
|
||||
|
||||
mx.grad(fun)(arrs)
|
||||
self.assertEqual(id(a_0), id(arrs[0]))
|
||||
self.assertNotEqual(id(a_1), id(arrs[1]))
|
||||
self.assertEqual(id(a_2), id(arrs[2]))
|
||||
|
||||
def test_grad_with_inplace_update(self):
|
||||
def loss_fn(model):
|
||||
|
||||
@@ -4,12 +4,12 @@ import gc
|
||||
import inspect
|
||||
import io
|
||||
import math
|
||||
import unittest
|
||||
from functools import partial, wraps
|
||||
from io import StringIO
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestCompile(mlx_tests.MLXTestCase):
|
||||
@@ -1252,6 +1252,26 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
loss, grads = step(emb, w, x)
|
||||
mx.eval(loss, grads)
|
||||
|
||||
def test_compile_donates_input_buffer(self):
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
def fun(x):
|
||||
return mx.sin(x) + 1
|
||||
|
||||
compiled_fn = mx.compile(fun)
|
||||
|
||||
input = mx.arange(16, dtype=mx.float32)
|
||||
mx.eval(input)
|
||||
in_ptr = np.asarray(input, copy=False).__array_interface__["data"][0]
|
||||
|
||||
out = compiled_fn(input)
|
||||
del input # Ensure the reference is dropped
|
||||
mx.eval(out)
|
||||
|
||||
self.assertEqual(
|
||||
np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
@@ -744,7 +744,6 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
return Vector([t[0] + 10, t[1] * 10])
|
||||
|
||||
x = State(mx.array(1), mx.array(2))
|
||||
print(f"{transform(x)=}")
|
||||
|
||||
vmap_transform = mx.vmap(transform)
|
||||
vmap_transform_tuple = mx.vmap(transform_tuple)
|
||||
|
||||
42
setup.py
42
setup.py
@@ -7,13 +7,21 @@ import re
|
||||
import subprocess
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from subprocess import run
|
||||
|
||||
from setuptools import Command, Extension, find_namespace_packages, setup
|
||||
from setuptools.command.bdist_wheel import bdist_wheel
|
||||
from setuptools.command.build_ext import build_ext
|
||||
|
||||
|
||||
def cuda_toolkit_major_version():
    """Return the major CUDA toolkit version reported by ``nvcc --version``.

    The combined stdout/stderr of ``nvcc --version`` is searched for the first
    ``release <digits>`` occurrence. Returns that number as an ``int``, or
    ``None`` when the output contains no such text. Errors raised by
    ``subprocess.check_output`` are not caught.
    """
    banner = subprocess.check_output(
        ["nvcc", "--version"], stderr=subprocess.STDOUT
    ).decode()
    found = re.search(r"release (\d+)", banner)
    if found is None:
        return None
    return int(found.group(1))
|
||||
|
||||
|
||||
def get_version():
|
||||
with open("mlx/version.h", "r") as fid:
|
||||
for l in fid:
|
||||
@@ -31,7 +39,7 @@ def get_version():
|
||||
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
|
||||
if not pypi_release and not dev_release:
|
||||
git_hash = (
|
||||
run(
|
||||
subprocess.run(
|
||||
"git rev-parse --short HEAD".split(),
|
||||
capture_output=True,
|
||||
check=True,
|
||||
@@ -247,7 +255,7 @@ if __name__ == "__main__":
|
||||
|
||||
extras = {
|
||||
"dev": [
|
||||
"nanobind==2.4.0",
|
||||
"nanobind==2.10.2",
|
||||
"numpy",
|
||||
"pre-commit",
|
||||
"setuptools>=80",
|
||||
@@ -284,7 +292,11 @@ if __name__ == "__main__":
|
||||
install_requires.append(
|
||||
f'mlx-metal=={version}; platform_system == "Darwin"'
|
||||
)
|
||||
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
|
||||
extras["cuda"] = [f'mlx-cuda-12=={version}; platform_system == "Linux"']
|
||||
for toolkit in [12, 13]:
|
||||
extras[f"cuda{toolkit}"] = [
|
||||
f'mlx-cuda-{toolkit}=={version}; platform_system == "Linux"'
|
||||
]
|
||||
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
|
||||
|
||||
_setup(
|
||||
@@ -299,13 +311,25 @@ if __name__ == "__main__":
|
||||
if build_macos:
|
||||
name = "mlx-metal"
|
||||
elif build_cuda:
|
||||
name = "mlx-cuda"
|
||||
toolkit = cuda_toolkit_major_version()
|
||||
name = f"mlx-cuda-{toolkit}"
|
||||
if toolkit == 12:
|
||||
install_requires += [
|
||||
"nvidia-cublas-cu12==12.9.*",
|
||||
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
||||
]
|
||||
elif toolkit == 13:
|
||||
install_requires += [
|
||||
"nvidia-cublas-cu13",
|
||||
"nvidia-cuda-nvrtc-cu13",
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unknown toolkit {toolkit}")
|
||||
install_requires += [
|
||||
"nvidia-cublas-cu12==12.9.*",
|
||||
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
||||
"nvidia-cudnn-cu12==9.*",
|
||||
"nvidia-nccl-cu12",
|
||||
f"nvidia-cudnn-cu{toolkit}==9.*",
|
||||
f"nvidia-nccl-cu{toolkit}",
|
||||
]
|
||||
|
||||
else:
|
||||
name = "mlx-cpu"
|
||||
_setup(
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <climits>
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
@@ -608,3 +607,24 @@ TEST_CASE("test make empty array") {
|
||||
CHECK_EQ(a.size(), 0);
|
||||
CHECK_EQ(a.dtype(), bool_);
|
||||
}
|
||||
|
||||
TEST_CASE("test make array from user buffer") {
|
||||
int size = 4096;
|
||||
std::vector<int> buffer(size, 0);
|
||||
|
||||
int count = 0;
|
||||
auto deleter = [&count](void*) { count++; };
|
||||
|
||||
{
|
||||
auto a = array(buffer.data(), Shape{size}, int32, deleter);
|
||||
if (metal::is_available()) {
|
||||
CHECK_EQ(buffer.data(), a.data<int>());
|
||||
}
|
||||
auto b = a + array(1);
|
||||
eval(b);
|
||||
auto expected = ones({4096});
|
||||
CHECK(array_equal(b, expected).item<bool>());
|
||||
}
|
||||
// deleter should always get called
|
||||
CHECK_EQ(count, 1);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user