Compare commits

...

5 Commits

Author SHA1 Message Date
Cheng
6245824d42 Make allocator::malloc throw on allocation failure (#2874)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-05 17:44:38 +09:00
Awni Hannun
39289ef025 [CUDA] Release build for cuda 13 (#2872) 2025-12-04 21:42:26 -08:00
Awni Hannun
aefc9bd3f6 [CUDA] Faster general copy (#2873) 2025-12-04 21:42:15 -08:00
Angelos Katharopoulos
997cfc7699 Add a 2-pass col reduce for CUDA (#2863)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-04 15:53:59 -08:00
Awni Hannun
1fa8dc5797 Do a PyPi release for cuda on arm (#2866) 2025-12-04 15:28:29 -08:00
15 changed files with 258 additions and 69 deletions

View File

@@ -1,6 +1,15 @@
name: 'Build CUDA wheel' name: 'Build CUDA wheel'
description: 'Build CUDA wheel' description: 'Build CUDA wheel'
inputs:
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs: runs:
using: "composite" using: "composite"
steps: steps:
@@ -12,4 +21,4 @@ runs:
pip install auditwheel build patchelf setuptools pip install auditwheel build patchelf setuptools
python setup.py clean --all python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh bash python/scripts/repair_cuda.sh ${{ inputs.arch }}

View File

@@ -15,6 +15,7 @@ runs:
using: "composite" using: "composite"
steps: steps:
- name: Use ccache - name: Use ccache
if: ${{ runner.arch == 'x86_64' }}
uses: hendrikmuhs/ccache-action@v1.2 uses: hendrikmuhs/ccache-action@v1.2
with: with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }} key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}

View File

@@ -128,7 +128,11 @@ jobs:
build_cuda_release: build_cuda_release:
if: github.repository == 'ml-explore/mlx' if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large strategy:
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.9', 'cuda-13.0']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
env: env:
PYPI_RELEASE: 1 PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }} DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
@@ -136,9 +140,11 @@ jobs:
- uses: actions/checkout@v6 - uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux - uses: ./.github/actions/setup-linux
with: with:
toolkit: 'cuda-12.9' toolkit: ${{ matrix.toolkit }}
- name: Build Python package - name: Build Python package
uses: ./.github/actions/build-cuda-release uses: ./.github/actions/build-cuda-release
with:
arch: ${{ matrix.arch }}
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v5
with: with:

View File

@@ -29,17 +29,20 @@ MLX has a CUDA backend which you can install with:
.. code-block:: shell .. code-block:: shell
pip install mlx[cuda] pip install mlx[cuda12]
To install the CUDA package from PyPi your system must meet the following To install the CUDA package from PyPi your system must meet the following
requirements: requirements:
- Nvidia architecture >= SM 7.0 (Volta) - Nvidia architecture >= SM 7.5
- Nvidia driver >= 550.54.14 - Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0 - CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35 - Linux distribution with glibc >= 2.35
- Python >= 3.10 - Python >= 3.10
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
CPU-only (Linux) CPU-only (Linux)
^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^

View File

@@ -1,7 +1,6 @@
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp

View File

@@ -1,24 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cstdlib>
#include <sstream>
#include "mlx/allocator.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
void free(Buffer buffer) {
allocator().free(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -28,10 +28,6 @@ class Buffer {
}; };
}; };
Buffer malloc(size_t size);
void free(Buffer buffer);
class Allocator { class Allocator {
/** Abstract base class for a memory allocator. */ /** Abstract base class for a memory allocator. */
public: public:
@@ -49,4 +45,12 @@ class Allocator {
Allocator& allocator(); Allocator& allocator();
inline Buffer malloc(size_t size) {
return allocator().malloc(size);
}
inline void free(Buffer buffer) {
allocator().free(buffer);
}
} // namespace mlx::core::allocator } // namespace mlx::core::allocator

View File

@@ -157,16 +157,14 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
cudaError_t err; cudaError_t err;
void* data = nullptr; void* data = nullptr;
if (device == -1) { if (device == -1) {
err = cudaMallocManaged(&data, size); CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
} else { } else {
err = cudaMallocAsync(&data, size, stream); CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
} }
if (!data) { if (!data) {
return Buffer{nullptr}; std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
} }
buf = new CudaBuffer{data, size, device}; buf = new CudaBuffer{data, size, device};
} }

View File

@@ -95,11 +95,14 @@ void copy_general_input(
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in; const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out; OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size(); int ndim = shape.size();
int work_per_thread = 1;
int work_per_thread = 8;
auto dim0 = ndim > 0 ? shape.back() : 1; auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0; auto rest = out.size() / dim0;
if (dim0 >= 4) { if (dim0 >= 4 && dim0 < 8) {
work_per_thread = 4; work_per_thread = 4;
} else if (dim0 < 4) {
work_per_thread = 1;
} }
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1); auto block_dims = get_block_dims(dim0, rest, 1);
@@ -110,7 +113,10 @@ void copy_general_input(
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>; cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
if (work_per_thread == 4) { if (work_per_thread == 8) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
} else if (work_per_thread == 4) {
kernel = kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>; cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
} }
@@ -127,7 +133,9 @@ void copy_general_input(
}); });
} else { // ndim >= 4 } else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>; auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
if (work_per_thread == 4) { if (work_per_thread == 8) {
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
} else if (work_per_thread == 4) {
kernel = cu::copy_g<InType, OutType, IdxT, 4>; kernel = cu::copy_g<InType, OutType, IdxT, 4>;
} }
encoder.add_kernel_node( encoder.add_kernel_node(

View File

@@ -89,9 +89,13 @@ template <
int NDIM, int NDIM,
int BM, int BM,
int BN, int BN,
int N_READS = 4> int N_READS = 4,
__global__ void int BLOCKS = 1>
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) { __global__ void col_reduce_looped(
T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
int64_t out_size) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block); auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -102,6 +106,8 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
size_t tile_idx = grid.block_rank(); size_t tile_idx = grid.block_rank();
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN); size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN); size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
size_t tile_out = tile_y / out_size;
tile_y = tile_y % out_size;
// Compute the indices for the thread within the tile // Compute the indices for the thread within the tile
short thread_x = block.thread_rank() % threads_per_row; short thread_x = block.thread_rank() % threads_per_row;
@@ -118,12 +124,23 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
totals[i] = ReduceInit<Op, T>::value(); totals[i] = ReduceInit<Op, T>::value();
} }
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
size_t total = args.non_col_reductions * args.reduction_size; size_t total = args.non_col_reductions * args.reduction_size;
size_t per_block, start, end;
if constexpr (BLOCKS > 1) {
per_block = (total + BLOCKS - 1) / BLOCKS;
start = tile_out * per_block + thread_y;
end = min((tile_out + 1) * per_block, total);
} else {
per_block = total;
start = thread_y;
end = total;
}
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
if (tile_x * BN + BN <= args.reduction_stride) { if (tile_x * BN + BN <= args.reduction_stride) {
if (args.reduction_stride % N_READS == 0) { if (args.reduction_stride % N_READS == 0) {
for (size_t r = thread_y; r < total; r += BM) { for (size_t r = start; r < end; r += BM) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
@@ -132,7 +149,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }
} else { } else {
for (size_t r = thread_y; r < total; r += BM) { for (size_t r = start; r < end; r += BM) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
@@ -142,7 +159,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
} }
} }
} else { } else {
for (size_t r = thread_y; r < total; r += BM) { for (size_t r = start; r < end; r += BM) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
thread_x, thread_x,
@@ -173,6 +190,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
// Write result. // Write result.
if (warp.thread_rank() == 0) { if (warp.thread_rank() == 0) {
if (BLOCKS > 1) {
out += tile_out * out_size * args.reduction_stride;
}
cub::StoreDirectBlocked( cub::StoreDirectBlocked(
warp.meta_group_rank(), warp.meta_group_rank(),
out + tile_y * args.reduction_stride + tile_x * BN, out + tile_y * args.reduction_stride + tile_x * BN,
@@ -227,11 +247,12 @@ __global__ void col_reduce_small(
inline auto output_grid_for_col_reduce( inline auto output_grid_for_col_reduce(
const array& out, const array& out,
const cu::ColReduceArgs& args, const cu::ColReduceArgs& args,
int bn) { int bn,
int outer = 1) {
int gx, gy = 1; int gx, gy = 1;
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn); size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
size_t n_outer_blocks = out.size() / args.reduction_stride; size_t n_outer_blocks = out.size() / args.reduction_stride;
size_t n_blocks = n_outer_blocks * n_inner_blocks; size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
while (n_blocks / gy > INT32_MAX) { while (n_blocks / gy > INT32_MAX) {
gy *= 2; gy *= 2;
} }
@@ -277,7 +298,8 @@ void col_reduce_looped(
0, 0,
indata, indata,
gpu_ptr<U>(out), gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args)); static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
}); });
}); });
}); });
@@ -320,6 +342,117 @@ void col_reduce_small(
}); });
} }
void col_reduce_two_pass(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes, encoder);
// Allocate an intermediate array to hold the 1st pass result
constexpr int outer = 32;
Shape intermediate_shape;
intermediate_shape.push_back(outer);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
Strides intermediate_strides;
intermediate_strides.push_back(out.size());
intermediate_strides.insert(
intermediate_strides.end(), out.strides().begin(), out.strides().end());
array intermediate(intermediate_shape, out.dtype(), nullptr, {});
auto [data_size, rc, cc] =
check_contiguity(intermediate_shape, intermediate_strides);
auto fl = out.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = true;
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder),
data_size,
intermediate_strides,
fl,
allocator::free);
encoder.add_temporary(intermediate);
encoder.set_input_array(in);
encoder.set_output_array(intermediate);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(gpu_ptr<T>(in));
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
int blocks = BM * BN / N_READS;
auto kernel = cu::
col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
indata,
gpu_ptr<U>(intermediate),
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
});
});
});
// Prepare the reduction arguments for the 2nd pass
cu::ColReduceArgs second_args = args;
second_args.reduction_size = outer;
second_args.reduction_stride = out.size();
second_args.ndim = 0;
second_args.reduce_shape[0] = outer;
second_args.reduce_strides[0] = out.size();
second_args.reduce_ndim = 1;
second_args.non_col_reductions = 1;
encoder.set_input_array(intermediate);
encoder.set_output_array(out);
dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
int blocks = BM * BN / N_READS;
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
gpu_ptr<T>(intermediate),
gpu_ptr<U>(out),
second_args,
second_args.reduction_stride);
});
});
});
}
void col_reduce( void col_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
const array& in, const array& in,
@@ -334,6 +467,18 @@ void col_reduce(
// It is a general strided reduce. Each threadblock computes the output for // It is a general strided reduce. Each threadblock computes the output for
// a subrow of the fast moving axis. For instance 32 elements. // a subrow of the fast moving axis. For instance 32 elements.
// //
// - col_reduce_small
//
// It is a column reduce for small columns. Each thread loops over the whole
// column without communicating with any other thread.
//
// - col_reduce_two_pass
//
// It is a reduce for long columns. To increase parallelism, we split the
// reduction in two passes. First we do a column reduce where many
// threadblocks operate on different parts of the reduced axis. Then we
// perform a final column reduce.
//
// Notes: As in row reduce we opt to read as much in order as possible and // Notes: As in row reduce we opt to read as much in order as possible and
// leave transpositions as they are (contrary to our Metal backend). // leave transpositions as they are (contrary to our Metal backend).
// //
@@ -349,6 +494,14 @@ void col_reduce(
return; return;
} }
// Long column with smallish row
size_t total_sums = args.non_col_reductions * args.reduction_size;
size_t approx_threads = out.size();
if (total_sums / approx_threads > 32) {
col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce // Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args); col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
} }

View File

@@ -7,8 +7,6 @@
namespace mlx::core { namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& in, array& out, CopyType ctype) { void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream()); copy_gpu(in, out, ctype, out.primitive().stream());
} }

View File

@@ -149,7 +149,9 @@ Buffer MetalAllocator::malloc(size_t size) {
buf = device_->newBuffer(size, resource_options); buf = device_->newBuffer(size, resource_options);
} }
if (!buf) { if (!buf) {
return Buffer{nullptr}; std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
} }
lk.lock(); lk.lock();
num_resources_++; num_resources_++;

View File

@@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
auditwheel repair dist/* \ auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \ --plat manylinux_2_35_${1} \
--exclude libcublas* \ --exclude libcublas* \
--exclude libnvrtc* \ --exclude libnvrtc* \
--exclude libcuda* \ --exclude libcuda* \

View File

@@ -210,6 +210,14 @@ class TestReduce(mlx_tests.MLXTestCase):
ref = getattr(np, op)(np_arr, axis=axis) ref = getattr(np, op)(np_arr, axis=axis)
self.assertTrue(np.array_equal(out, ref, equal_nan=True)) self.assertTrue(np.array_equal(out, ref, equal_nan=True))
def test_long_column(self):
a = (np.random.randn(8192, 64) * 32).astype(np.int32)
b = mx.array(a)
c1 = a.sum(0)
c2 = b.sum(0)
self.assertTrue(np.all(c1 == c2))
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner(failfast=True) mlx_tests.MLXTestRunner(failfast=True)

View File

@@ -7,13 +7,21 @@ import re
import subprocess import subprocess
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from subprocess import run
from setuptools import Command, Extension, find_namespace_packages, setup from setuptools import Command, Extension, find_namespace_packages, setup
from setuptools.command.bdist_wheel import bdist_wheel from setuptools.command.bdist_wheel import bdist_wheel
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
def cuda_toolkit_major_version():
out = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
text = out.decode()
m = re.search(r"release (\d+)", text)
if m:
return int(m.group(1))
return None
def get_version(): def get_version():
with open("mlx/version.h", "r") as fid: with open("mlx/version.h", "r") as fid:
for l in fid: for l in fid:
@@ -31,7 +39,7 @@ def get_version():
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}" version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
if not pypi_release and not dev_release: if not pypi_release and not dev_release:
git_hash = ( git_hash = (
run( subprocess.run(
"git rev-parse --short HEAD".split(), "git rev-parse --short HEAD".split(),
capture_output=True, capture_output=True,
check=True, check=True,
@@ -284,7 +292,11 @@ if __name__ == "__main__":
install_requires.append( install_requires.append(
f'mlx-metal=={version}; platform_system == "Darwin"' f'mlx-metal=={version}; platform_system == "Darwin"'
) )
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"'] extras["cuda"] = [f'mlx-cuda-12=={version}; platform_system == "Linux"']
for toolkit in [12, 13]:
extras[f"cuda{toolkit}"] = [
f'mlx-cuda-{toolkit}=={version}; platform_system == "Linux"'
]
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"'] extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
_setup( _setup(
@@ -299,13 +311,25 @@ if __name__ == "__main__":
if build_macos: if build_macos:
name = "mlx-metal" name = "mlx-metal"
elif build_cuda: elif build_cuda:
name = "mlx-cuda" toolkit = cuda_toolkit_major_version()
name = f"mlx-cuda-{toolkit}"
if toolkit == 12:
install_requires += [ install_requires += [
"nvidia-cublas-cu12==12.9.*", "nvidia-cublas-cu12==12.9.*",
"nvidia-cuda-nvrtc-cu12==12.9.*", "nvidia-cuda-nvrtc-cu12==12.9.*",
"nvidia-cudnn-cu12==9.*",
"nvidia-nccl-cu12",
] ]
elif toolkit == 13:
install_requires += [
"nvidia-cublas-cu13",
"nvidia-cuda-nvrtc-cu13",
]
else:
raise ValueError(f"Unknown toolkit {toolkit}")
install_requires += [
f"nvidia-cudnn-cu{toolkit}==9.*",
f"nvidia-nccl-cu{toolkit}",
]
else: else:
name = "mlx-cpu" name = "mlx-cpu"
_setup( _setup(