mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
5 Commits
ibv-backen
...
6245824d42
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6245824d42 | ||
|
|
39289ef025 | ||
|
|
aefc9bd3f6 | ||
|
|
997cfc7699 | ||
|
|
1fa8dc5797 |
11
.github/actions/build-cuda-release/action.yml
vendored
11
.github/actions/build-cuda-release/action.yml
vendored
@@ -1,6 +1,15 @@
|
||||
name: 'Build CUDA wheel'
|
||||
description: 'Build CUDA wheel'
|
||||
|
||||
inputs:
|
||||
arch:
|
||||
description: 'Platform architecture tag'
|
||||
required: true
|
||||
type: choice
|
||||
options:
|
||||
- x86_64
|
||||
- aarch64
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
@@ -12,4 +21,4 @@ runs:
|
||||
pip install auditwheel build patchelf setuptools
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=2 python -m build -w
|
||||
bash python/scripts/repair_cuda.sh
|
||||
bash python/scripts/repair_cuda.sh ${{ inputs.arch }}
|
||||
|
||||
1
.github/actions/setup-linux/action.yml
vendored
1
.github/actions/setup-linux/action.yml
vendored
@@ -15,6 +15,7 @@ runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Use ccache
|
||||
if: ${{ runner.arch == 'x86_64' }}
|
||||
uses: hendrikmuhs/ccache-action@v1.2
|
||||
with:
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
|
||||
|
||||
10
.github/workflows/release.yml
vendored
10
.github/workflows/release.yml
vendored
@@ -128,7 +128,11 @@ jobs:
|
||||
|
||||
build_cuda_release:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: ubuntu-22-large
|
||||
strategy:
|
||||
matrix:
|
||||
arch: ['x86_64', 'aarch64']
|
||||
toolkit: ['cuda-12.9', 'cuda-13.0']
|
||||
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
|
||||
env:
|
||||
PYPI_RELEASE: 1
|
||||
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||
@@ -136,9 +140,11 @@ jobs:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: ./.github/actions/setup-linux
|
||||
with:
|
||||
toolkit: 'cuda-12.9'
|
||||
toolkit: ${{ matrix.toolkit }}
|
||||
- name: Build Python package
|
||||
uses: ./.github/actions/build-cuda-release
|
||||
with:
|
||||
arch: ${{ matrix.arch }}
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
|
||||
@@ -29,17 +29,20 @@ MLX has a CUDA backend which you can install with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install mlx[cuda]
|
||||
pip install mlx[cuda12]
|
||||
|
||||
|
||||
To install the CUDA package from PyPi your system must meet the following
|
||||
requirements:
|
||||
|
||||
- Nvidia architecture >= SM 7.0 (Volta)
|
||||
- Nvidia architecture >= SM 7.5
|
||||
- Nvidia driver >= 550.54.14
|
||||
- CUDA toolkit >= 12.0
|
||||
- Linux distribution with glibc >= 2.35
|
||||
- Python >= 3.10
|
||||
|
||||
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
|
||||
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
|
||||
|
||||
CPU-only (Linux)
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
|
||||
namespace mlx::core::allocator {
|
||||
|
||||
Buffer malloc(size_t size) {
|
||||
auto buffer = allocator().malloc(size);
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
void free(Buffer buffer) {
|
||||
allocator().free(buffer);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
@@ -28,10 +28,6 @@ class Buffer {
|
||||
};
|
||||
};
|
||||
|
||||
Buffer malloc(size_t size);
|
||||
|
||||
void free(Buffer buffer);
|
||||
|
||||
class Allocator {
|
||||
/** Abstract base class for a memory allocator. */
|
||||
public:
|
||||
@@ -49,4 +45,12 @@ class Allocator {
|
||||
|
||||
Allocator& allocator();
|
||||
|
||||
inline Buffer malloc(size_t size) {
|
||||
return allocator().malloc(size);
|
||||
}
|
||||
|
||||
inline void free(Buffer buffer) {
|
||||
allocator().free(buffer);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
|
||||
@@ -157,16 +157,14 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
|
||||
cudaError_t err;
|
||||
void* data = nullptr;
|
||||
if (device == -1) {
|
||||
err = cudaMallocManaged(&data, size);
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
|
||||
} else {
|
||||
err = cudaMallocAsync(&data, size, stream);
|
||||
}
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
|
||||
}
|
||||
if (!data) {
|
||||
return Buffer{nullptr};
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
buf = new CudaBuffer{data, size, device};
|
||||
}
|
||||
|
||||
@@ -95,11 +95,14 @@ void copy_general_input(
|
||||
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
|
||||
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
|
||||
int ndim = shape.size();
|
||||
int work_per_thread = 1;
|
||||
|
||||
int work_per_thread = 8;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
auto rest = out.size() / dim0;
|
||||
if (dim0 >= 4) {
|
||||
if (dim0 >= 4 && dim0 < 8) {
|
||||
work_per_thread = 4;
|
||||
} else if (dim0 < 4) {
|
||||
work_per_thread = 1;
|
||||
}
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||
@@ -110,7 +113,10 @@ void copy_general_input(
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
|
||||
if (work_per_thread == 4) {
|
||||
if (work_per_thread == 8) {
|
||||
kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
|
||||
} else if (work_per_thread == 4) {
|
||||
kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
|
||||
}
|
||||
@@ -127,7 +133,9 @@ void copy_general_input(
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
|
||||
if (work_per_thread == 4) {
|
||||
if (work_per_thread == 8) {
|
||||
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
|
||||
} else if (work_per_thread == 4) {
|
||||
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
|
||||
@@ -89,9 +89,13 @@ template <
|
||||
int NDIM,
|
||||
int BM,
|
||||
int BN,
|
||||
int N_READS = 4>
|
||||
__global__ void
|
||||
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
int N_READS = 4,
|
||||
int BLOCKS = 1>
|
||||
__global__ void col_reduce_looped(
|
||||
T* in,
|
||||
U* out,
|
||||
const __grid_constant__ ColReduceArgs args,
|
||||
int64_t out_size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
@@ -102,6 +106,8 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
size_t tile_idx = grid.block_rank();
|
||||
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
|
||||
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
|
||||
size_t tile_out = tile_y / out_size;
|
||||
tile_y = tile_y % out_size;
|
||||
|
||||
// Compute the indices for the thread within the tile
|
||||
short thread_x = block.thread_rank() % threads_per_row;
|
||||
@@ -118,12 +124,23 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
size_t total = args.non_col_reductions * args.reduction_size;
|
||||
size_t per_block, start, end;
|
||||
if constexpr (BLOCKS > 1) {
|
||||
per_block = (total + BLOCKS - 1) / BLOCKS;
|
||||
start = tile_out * per_block + thread_y;
|
||||
end = min((tile_out + 1) * per_block, total);
|
||||
} else {
|
||||
per_block = total;
|
||||
start = thread_y;
|
||||
end = total;
|
||||
}
|
||||
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
if (tile_x * BN + BN <= args.reduction_stride) {
|
||||
if (args.reduction_stride % N_READS == 0) {
|
||||
for (size_t r = thread_y; r < total; r += BM) {
|
||||
for (size_t r = start; r < end; r += BM) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
@@ -132,7 +149,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
} else {
|
||||
for (size_t r = thread_y; r < total; r += BM) {
|
||||
for (size_t r = start; r < end; r += BM) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
@@ -142,7 +159,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t r = thread_y; r < total; r += BM) {
|
||||
for (size_t r = start; r < end; r += BM) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
thread_x,
|
||||
@@ -173,6 +190,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
|
||||
// Write result.
|
||||
if (warp.thread_rank() == 0) {
|
||||
if (BLOCKS > 1) {
|
||||
out += tile_out * out_size * args.reduction_stride;
|
||||
}
|
||||
cub::StoreDirectBlocked(
|
||||
warp.meta_group_rank(),
|
||||
out + tile_y * args.reduction_stride + tile_x * BN,
|
||||
@@ -227,11 +247,12 @@ __global__ void col_reduce_small(
|
||||
inline auto output_grid_for_col_reduce(
|
||||
const array& out,
|
||||
const cu::ColReduceArgs& args,
|
||||
int bn) {
|
||||
int bn,
|
||||
int outer = 1) {
|
||||
int gx, gy = 1;
|
||||
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
|
||||
size_t n_outer_blocks = out.size() / args.reduction_stride;
|
||||
size_t n_blocks = n_outer_blocks * n_inner_blocks;
|
||||
size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
|
||||
while (n_blocks / gy > INT32_MAX) {
|
||||
gy *= 2;
|
||||
}
|
||||
@@ -277,7 +298,8 @@ void col_reduce_looped(
|
||||
0,
|
||||
indata,
|
||||
gpu_ptr<U>(out),
|
||||
static_cast<cu::ColReduceArgs>(args));
|
||||
static_cast<cu::ColReduceArgs>(args),
|
||||
out.size() / args.reduction_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -320,6 +342,117 @@ void col_reduce_small(
|
||||
});
|
||||
}
|
||||
|
||||
void col_reduce_two_pass(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan,
|
||||
const cu::ColReduceArgs& args) {
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes, encoder);
|
||||
|
||||
// Allocate an intermediate array to hold the 1st pass result
|
||||
constexpr int outer = 32;
|
||||
|
||||
Shape intermediate_shape;
|
||||
intermediate_shape.push_back(outer);
|
||||
intermediate_shape.insert(
|
||||
intermediate_shape.end(), out.shape().begin(), out.shape().end());
|
||||
|
||||
Strides intermediate_strides;
|
||||
intermediate_strides.push_back(out.size());
|
||||
intermediate_strides.insert(
|
||||
intermediate_strides.end(), out.strides().begin(), out.strides().end());
|
||||
|
||||
array intermediate(intermediate_shape, out.dtype(), nullptr, {});
|
||||
auto [data_size, rc, cc] =
|
||||
check_contiguity(intermediate_shape, intermediate_strides);
|
||||
auto fl = out.flags();
|
||||
fl.row_contiguous = rc;
|
||||
fl.col_contiguous = cc;
|
||||
fl.contiguous = true;
|
||||
intermediate.set_data(
|
||||
cu::malloc_async(intermediate.nbytes(), encoder),
|
||||
data_size,
|
||||
intermediate_strides,
|
||||
fl,
|
||||
allocator::free);
|
||||
|
||||
encoder.add_temporary(intermediate);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(intermediate);
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(gpu_ptr<T>(in));
|
||||
|
||||
constexpr int N_READS = 4;
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
|
||||
int blocks = BM * BN / N_READS;
|
||||
auto kernel = cu::
|
||||
col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
blocks,
|
||||
0,
|
||||
indata,
|
||||
gpu_ptr<U>(intermediate),
|
||||
static_cast<cu::ColReduceArgs>(args),
|
||||
out.size() / args.reduction_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Prepare the reduction arguments for the 2nd pass
|
||||
cu::ColReduceArgs second_args = args;
|
||||
second_args.reduction_size = outer;
|
||||
second_args.reduction_stride = out.size();
|
||||
second_args.ndim = 0;
|
||||
second_args.reduce_shape[0] = outer;
|
||||
second_args.reduce_strides[0] = out.size();
|
||||
second_args.reduce_ndim = 1;
|
||||
second_args.non_col_reductions = 1;
|
||||
|
||||
encoder.set_input_array(intermediate);
|
||||
encoder.set_output_array(out);
|
||||
dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
|
||||
constexpr int N_READS = 4;
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
|
||||
int blocks = BM * BN / N_READS;
|
||||
auto kernel =
|
||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
blocks,
|
||||
0,
|
||||
gpu_ptr<T>(intermediate),
|
||||
gpu_ptr<U>(out),
|
||||
second_args,
|
||||
second_args.reduction_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void col_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
@@ -334,6 +467,18 @@ void col_reduce(
|
||||
// It is a general strided reduce. Each threadblock computes the output for
|
||||
// a subrow of the fast moving axis. For instance 32 elements.
|
||||
//
|
||||
// - col_reduce_small
|
||||
//
|
||||
// It is a column reduce for small columns. Each thread loops over the whole
|
||||
// column without communicating with any other thread.
|
||||
//
|
||||
// - col_reduce_two_pass
|
||||
//
|
||||
// It is a reduce for long columns. To increase parallelism, we split the
|
||||
// reduction in two passes. First we do a column reduce where many
|
||||
// threadblocks operate on different parts of the reduced axis. Then we
|
||||
// perform a final column reduce.
|
||||
//
|
||||
// Notes: As in row reduce we opt to read as much in order as possible and
|
||||
// leave transpositions as they are (contrary to our Metal backend).
|
||||
//
|
||||
@@ -349,6 +494,14 @@ void col_reduce(
|
||||
return;
|
||||
}
|
||||
|
||||
// Long column with smallish row
|
||||
size_t total_sums = args.non_col_reductions * args.reduction_size;
|
||||
size_t approx_threads = out.size();
|
||||
if (total_sums / approx_threads > 32) {
|
||||
col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallback col reduce
|
||||
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
|
||||
}
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
@@ -149,7 +149,9 @@ Buffer MetalAllocator::malloc(size_t size) {
|
||||
buf = device_->newBuffer(size, resource_options);
|
||||
}
|
||||
if (!buf) {
|
||||
return Buffer{nullptr};
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
lk.lock();
|
||||
num_resources_++;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
auditwheel repair dist/* \
|
||||
--plat manylinux_2_35_x86_64 \
|
||||
--plat manylinux_2_35_${1} \
|
||||
--exclude libcublas* \
|
||||
--exclude libnvrtc* \
|
||||
--exclude libcuda* \
|
||||
|
||||
@@ -210,6 +210,14 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
ref = getattr(np, op)(np_arr, axis=axis)
|
||||
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
||||
|
||||
def test_long_column(self):
|
||||
a = (np.random.randn(8192, 64) * 32).astype(np.int32)
|
||||
b = mx.array(a)
|
||||
|
||||
c1 = a.sum(0)
|
||||
c2 = b.sum(0)
|
||||
self.assertTrue(np.all(c1 == c2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner(failfast=True)
|
||||
|
||||
40
setup.py
40
setup.py
@@ -7,13 +7,21 @@ import re
|
||||
import subprocess
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from subprocess import run
|
||||
|
||||
from setuptools import Command, Extension, find_namespace_packages, setup
|
||||
from setuptools.command.bdist_wheel import bdist_wheel
|
||||
from setuptools.command.build_ext import build_ext
|
||||
|
||||
|
||||
def cuda_toolkit_major_version():
|
||||
out = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
|
||||
text = out.decode()
|
||||
m = re.search(r"release (\d+)", text)
|
||||
if m:
|
||||
return int(m.group(1))
|
||||
return None
|
||||
|
||||
|
||||
def get_version():
|
||||
with open("mlx/version.h", "r") as fid:
|
||||
for l in fid:
|
||||
@@ -31,7 +39,7 @@ def get_version():
|
||||
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
|
||||
if not pypi_release and not dev_release:
|
||||
git_hash = (
|
||||
run(
|
||||
subprocess.run(
|
||||
"git rev-parse --short HEAD".split(),
|
||||
capture_output=True,
|
||||
check=True,
|
||||
@@ -284,7 +292,11 @@ if __name__ == "__main__":
|
||||
install_requires.append(
|
||||
f'mlx-metal=={version}; platform_system == "Darwin"'
|
||||
)
|
||||
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
|
||||
extras["cuda"] = [f'mlx-cuda-12=={version}; platform_system == "Linux"']
|
||||
for toolkit in [12, 13]:
|
||||
extras[f"cuda{toolkit}"] = [
|
||||
f'mlx-cuda-{toolkit}=={version}; platform_system == "Linux"'
|
||||
]
|
||||
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
|
||||
|
||||
_setup(
|
||||
@@ -299,13 +311,25 @@ if __name__ == "__main__":
|
||||
if build_macos:
|
||||
name = "mlx-metal"
|
||||
elif build_cuda:
|
||||
name = "mlx-cuda"
|
||||
toolkit = cuda_toolkit_major_version()
|
||||
name = f"mlx-cuda-{toolkit}"
|
||||
if toolkit == 12:
|
||||
install_requires += [
|
||||
"nvidia-cublas-cu12==12.9.*",
|
||||
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
||||
]
|
||||
elif toolkit == 13:
|
||||
install_requires += [
|
||||
"nvidia-cublas-cu13",
|
||||
"nvidia-cuda-nvrtc-cu13",
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unknown toolkit {toolkit}")
|
||||
install_requires += [
|
||||
"nvidia-cublas-cu12==12.9.*",
|
||||
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
||||
"nvidia-cudnn-cu12==9.*",
|
||||
"nvidia-nccl-cu12",
|
||||
f"nvidia-cudnn-cu{toolkit}==9.*",
|
||||
f"nvidia-nccl-cu{toolkit}",
|
||||
]
|
||||
|
||||
else:
|
||||
name = "mlx-cpu"
|
||||
_setup(
|
||||
|
||||
Reference in New Issue
Block a user