Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-29 05:31:18 +08:00)

Compare commits (1 commit): 5feed6cb77 ... 70db65c6be

@@ -16,9 +16,6 @@ parameters:
   linux_release:
     type: boolean
     default: false
-  cuda_release:
-    type: boolean
-    default: false

 jobs:
   build_documentation:

@@ -226,6 +223,7 @@ jobs:
           command: |
             sudo apt-get update
             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
             python -m venv env
             source env/bin/activate
             CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \

@@ -358,48 +356,6 @@ jobs:
       - store_artifacts:
           path: wheelhouse/

-  build_cuda_release:
-    parameters:
-      python_version:
-        type: string
-        default: "3.9"
-      extra_env:
-        type: string
-        default: "DEV_RELEASE=1"
-    machine:
-      image: linux-cuda-12:default
-    resource_class: gpu.nvidia.small.gen2
-    steps:
-      - checkout
-      - run:
-          name: Build wheel
-          command: |
-            sudo apt-get update
-            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
-            python -m venv env
-            source env/bin/activate
-            pip install auditwheel
-            pip install patchelf
-            pip install build
-            pip install twine
-            << parameters.extra_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
-            pip install ".[dev]" -v
-            python setup.py generate_stubs
-            << parameters.extra_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
-            python -m build --wheel
-            bash python/scripts/repair_cuda.sh
-      - run:
-          name: Upload package
-          command: |
-            source env/bin/activate
-            twine upload wheelhouse/*.whl
-      - store_artifacts:
-          path: wheelhouse/
-
 workflows:
   build_and_test:
     when:

@@ -669,14 +625,3 @@ workflows:
             parameters:
               python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
               extra_env: ["PYPI_RELEASE=1"]
-  cuda_test_release:
-    when:
-      and:
-        - equal: [ main, << pipeline.git.branch >> ]
-        - << pipeline.parameters.cuda_release >>
-    jobs:
-      - build_cuda_release:
-          matrix:
-            parameters:
-              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-              extra_env: ["PYPI_RELEASE=1"]

@@ -30,16 +30,6 @@ MLX is also available on conda-forge. To install MLX with conda do:

   conda install conda-forge::mlx

-CUDA
-^^^^
-
-MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
-and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
-
-.. code-block:: shell
-
-  pip install mlx-cuda
-

 Troubleshooting
 ^^^^^^^^^^^^^^^

@@ -75,8 +65,6 @@ Build Requirements
 Python API
 ^^^^^^^^^^

-.. _python install:
-
 To build and install the MLX python library from source, first, clone MLX from
 `its GitHub repo <https://github.com/ml-explore/mlx>`_:

@@ -119,8 +107,6 @@ IDE:
 C++ API
 ^^^^^^^

-.. _cpp install:
-
 Currently, MLX must be built and installed from source.

 Similarly to the python library, to build and install the MLX C++ library start

@@ -199,7 +185,6 @@ should point to the path to the built metal library.

   xcrun -sdk macosx --show-sdk-version

-
 Binary Size Minimization
 ~~~~~~~~~~~~~~~~~~~~~~~~

@@ -228,50 +213,6 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
 application. Once a kernel is compiled, it will be cached by the system. The
 Metal kernel cache persists across reboots.

-Linux
-^^^^^
-
-To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
-For example on Ubuntu, run the following:
-
-.. code-block:: shell
-
-  apt-get update -y
-  apt-get install libblas-dev liblapack-dev liblapacke-dev -y
-
-From here follow the instructions to install either the :ref:`Python <python
-install>` or :ref:`C++ <cpp install>` APIs.
-
-CUDA
-^^^^
-
-To build from source on Linux with CUDA, install the BLAS and LAPACK headers
-and the CUDA toolkit. For example on Ubuntu, run the following:
-
-.. code-block:: shell
-
-  wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
-  dpkg -i cuda-keyring_1.1-1_all.deb
-  apt-get update -y
-  apt-get -y install cuda-toolkit-12-9
-  apt-get install libblas-dev liblapack-dev liblapacke-dev -y
-
-
-When building either the Python or C++ APIs make sure to pass the cmake flag
-``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
-
-.. code-block:: shell
-
-  CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
-
-To build the C++ package run:
-
-.. code-block:: shell
-
-  mkdir -p build && cd build
-  cmake .. -DMLX_BUILD_CUDA=ON && make -j
-
-
 Troubleshooting
 ^^^^^^^^^^^^^^^

@@ -3,7 +3,6 @@
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
 #include "mlx/backend/cuda/worker.h"
-#include "mlx/utils.h"

 #include <cuda_runtime.h>
 #include <fmt/format.h>

@@ -15,11 +14,9 @@ namespace mlx::core {

 namespace cu {

-constexpr int page_size = 16384;
-
 CudaAllocator::CudaAllocator()
     : buffer_cache_(
-          page_size,
+          getpagesize(),
           [](CudaBuffer* buf) { return buf->size; },
           [this](CudaBuffer* buf) {
             cuda_free(buf->data);

@@ -34,14 +31,7 @@ CudaAllocator::CudaAllocator()

 Buffer CudaAllocator::malloc(size_t size) {
   // Find available buffer from cache.
-  auto orig_size = size;
   std::unique_lock lock(mutex_);
-  if (size < page_size) {
-    size = next_power_of_2(size);
-  } else {
-    size = page_size * ((size + page_size - 1) / page_size);
-  }
-
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
   if (!buf) {
     // If we have a lot of memory pressure or are over the maximum cache size,

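For context on the removed size-class logic: rounding a request up to a power of two below the 16 KB page, and up to a page multiple above it, collapses nearby request sizes into a few recurring buckets, so `buffer_cache_.reuse_from_cache(size)` can hit more often. A minimal C++ sketch of that policy; `bucket_size` is a hypothetical helper name (the real code inlines this in `CudaAllocator::malloc`, with `next_power_of_2` coming from the `mlx/utils.h` include removed above):

```cpp
#include <cstddef>

constexpr size_t page_size = 16384; // value of the removed constant

// Round n up to the next power of two (e.g. 3000 -> 4096).
size_t next_power_of_2(size_t n) {
  size_t p = 1;
  while (p < n) {
    p <<= 1;
  }
  return p;
}

// Hypothetical helper mirroring the removed branch: small requests share
// power-of-two buckets, large ones share page-multiple buckets, so the
// buffer cache sees a handful of recurring sizes instead of arbitrary ones.
size_t bucket_size(size_t size) {
  return size < page_size ? next_power_of_2(size)
                          : page_size * ((size + page_size - 1) / page_size);
}
```
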
@@ -24,6 +24,7 @@ void copy_gpu_inplace(
   auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
+
   if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
     copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
     return;

@@ -114,7 +114,7 @@ void CommandEncoder::synchronize() {
   std::future<void> f = p->get_future();
   add_completed_handler([p = std::move(p)]() { p->set_value(); });
   worker_.end_batch();
-  commit();
+  worker_.commit();
   f.wait();
 }

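The synchronize pattern above is the same on both sides apart from which object commits: publish a promise, register a completion handler that fulfills it after everything already enqueued, flush the batch, and block on the future. A self-contained sketch of that pattern under simplified assumptions; `Queue` is a stand-in for the encoder/worker pair, and its `commit` runs handlers synchronously here, whereas the real backend completes them asynchronously after the GPU work:

```cpp
#include <functional>
#include <future>
#include <memory>
#include <vector>

// Stand-in for the command encoder + worker: handlers registered here run
// once commit() has finished the pending batch.
struct Queue {
  void add_completed_handler(std::function<void()> f) {
    handlers_.push_back(std::move(f));
  }
  void commit() {
    for (auto& h : handlers_) h(); // real backend: runs after the GPU work
    handlers_.clear();
  }

 private:
  std::vector<std::function<void()>> handlers_;
};

// Mirrors the shape of CommandEncoder::synchronize(): the future becomes
// ready only after the completion handler, hence after all earlier work.
void synchronize(Queue& q) {
  auto p = std::make_shared<std::promise<void>>();
  std::future<void> f = p->get_future();
  q.add_completed_handler([p = std::move(p)]() { p->set_value(); });
  q.commit();
  f.wait();
}
```
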
@@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
 #pragma unroll
   for (int i = NDIM - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * IdxT(a_strides[i]);
-    b_loc += dim_idx * IdxT(b_strides[i]);
+    a_loc += dim_idx * a_strides[i];
+    b_loc += dim_idx * b_strides[i];
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc);

@@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
 #pragma unroll
   for (int i = NDIM - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * IdxT(a_strides[i]);
-    b_loc += dim_idx * IdxT(b_strides[i]);
-    c_loc += dim_idx * IdxT(c_strides[i]);
+    a_loc += dim_idx * a_strides[i];
+    b_loc += dim_idx * b_strides[i];
+    c_loc += dim_idx * c_strides[i];
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc, c_loc);

@@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
   IdxT b_loc = 0;
   for (int i = ndim - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * IdxT(a_strides[i]);
-    b_loc += dim_idx * IdxT(b_strides[i]);
+    a_loc += dim_idx * a_strides[i];
+    b_loc += dim_idx * b_strides[i];
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc);

@@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
   IdxT c_loc = 0;
   for (int i = ndim - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * IdxT(a_strides[i]);
-    b_loc += dim_idx * IdxT(b_strides[i]);
-    c_loc += dim_idx * IdxT(c_strides[i]);
+    a_loc += dim_idx * a_strides[i];
+    b_loc += dim_idx * b_strides[i];
+    c_loc += dim_idx * c_strides[i];
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc, c_loc);

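The only difference across these four hunks is whether each stride is widened to `IdxT` before the multiply. That matters when `IdxT` is 64-bit but the stride elements are 32-bit: without the cast, `dim_idx * stride` is a 32-bit multiply whose result can overflow before it is widened for the `+=`. A small worked example of the failure mode (values are illustrative):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  using IdxT = int64_t;
  int dim_idx = 70000;
  int32_t stride = 70000; // hypothetical large stride

  // Multiply in 32 bits, widen afterwards: 70000 * 70000 = 4.9e9 does not
  // fit in int32, so this overflows (undefined behavior; wraps in practice).
  IdxT narrow = IdxT(dim_idx * stride);

  // Widen one operand first: the multiply itself happens in 64 bits.
  IdxT wide = dim_idx * IdxT(stride);

  std::cout << narrow << " vs " << wide << "\n"; // e.g. 605032704 vs 4900000000
}
```
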
@@ -162,15 +162,11 @@ class MatMul {
       }
     }

-    void* workspace_ptr = nullptr;
-    if (heuristic_.workspaceSize > 0) {
     array workspace(
         allocator::malloc(heuristic_.workspaceSize),
         {static_cast<int>(heuristic_.workspaceSize)},
         int8);
     encoder.add_temporary(workspace);
-      workspace_ptr = workspace.data<void>();
-    }

     encoder.launch_kernel([&](cudaStream_t stream) {
       CHECK_CUBLAS_ERROR(cublasLtMatmul(

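The removed lines guarded the scratch allocation: a workspace was only created when `heuristic_.workspaceSize > 0`, and `cublasLtMatmul` received a null pointer with size 0 otherwise; after the change the temporary `workspace` array is always materialized, even when zero bytes are requested. A compact sketch of the removed guard; `Heuristic` and `prepare_workspace` are illustrative stand-ins, not the backend's API:

```cpp
#include <cstddef>
#include <cstdlib>

// Illustrative stand-in for the cuBLASLt heuristic result; only the
// requested workspace size matters here.
struct Heuristic {
  size_t workspaceSize; // 0 means the chosen algorithm needs no scratch
};

// Removed pattern: skip the allocation when no scratch space is requested,
// so the matmul call receives (nullptr, 0) rather than a zero-byte buffer.
void* prepare_workspace(const Heuristic& h) {
  return h.workspaceSize > 0 ? std::malloc(h.workspaceSize) : nullptr;
}
```
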
@@ -187,8 +183,8 @@ class MatMul {
         out,
         out_desc_,
         &heuristic_.algo,
-        workspace_ptr,
-        heuristic_.workspaceSize,
+        workspace.data<void>(),
+        workspace.nbytes(),
         stream));
   });
 }

@@ -362,18 +358,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       a_batch_strides.back(),
       b_batch_strides.back());

-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_output_array(out);
-  auto nbatch = batch_count / batch_shape.back();
-  if (nbatch == 1) {
-    matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
-    return;
-  }
-
   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
-  for (size_t i = 0; i < nbatch; ++i) {
+  for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
     matmul.run(
         encoder,
         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,

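The removed lines hoisted the batch count into `nbatch` and returned early in the single-batch case, skipping the `ContiguousIterator` setup entirely; the surviving code recomputes `batch_count / batch_shape.back()` in the loop header and always builds the iterators. A minimal sketch of the removed shape, with `run_one` as a stand-in for `matmul.run`:

```cpp
#include <cstddef>
#include <functional>

// Hoist the loop bound and fast-path the common single-batch case so the
// per-batch iterator setup does no work when there is nothing to iterate.
void run_batched(size_t batch_count, size_t last_batch_dim,
                 const std::function<void(size_t)>& run_one) {
  auto nbatch = batch_count / last_batch_dim;
  if (nbatch == 1) {
    run_one(0); // no iterator setup needed
    return;
  }
  // ... set up strided batch iterators here ...
  for (size_t i = 0; i < nbatch; ++i) {
    run_one(i);
  }
}
```
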
@@ -457,28 +444,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       b_batch_strides.back(),
       c_batch_strides.back());

-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_input_array(c);
-  encoder.set_output_array(out);
-
-  auto nbatch = batch_count / batch_shape.back();
-  if (nbatch == 1) {
-    matmul.run(
-        encoder,
-        out.data<int8_t>(),
-        a.data<int8_t>(),
-        b.data<int8_t>(),
-        c.data<int8_t>(),
-        alpha_,
-        beta_);
-    return;
-  }
-
   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
   ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
-  for (size_t i = 0; i < nbatch; ++i) {
+  for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
     matmul.run(
         encoder,
         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,

@@ -79,6 +79,9 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
 void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
   array out = out_;
   auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+
   if (axis < 0) {
     axis += in.ndim();
   }

@@ -103,8 +106,6 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         in.flags());
   }

-  encoder.set_input_array(in);
-  encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
     MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
       if constexpr (!std::is_same_v<CTYPE, complex64_t>) {

@@ -413,7 +413,7 @@ class Module(dict):
                     f'Module does not have sub-module named "{k}".'
                 )
             elif isinstance(modules, list):
-                for i in range(len(modules)):
+                for i in range(len(dst)):
                     current_value = dst[i]
                     new_value = modules[i]
                     if self.is_module(current_value) and self.is_module(new_value):

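The one-line change swaps which list bounds the update loop. Bounding by the replacement (`modules`) lets a shorter replacement update just a prefix of `dst`, which is what the "Allow updating a strict subset" test removed further down exercised; bounding by the destination (`dst`) assumes equal lengths and indexes past the end of a shorter replacement. A C++ transcription of the two bounds (the original is Python; names mirror the diff):

```cpp
#include <cstddef>
#include <vector>

// Bound by the replacement: updates a (possibly strict) prefix of dst.
// Safe whenever repl.size() <= dst.size().
void update_prefix(std::vector<int>& dst, const std::vector<int>& repl) {
  for (size_t i = 0; i < repl.size(); ++i) {
    dst[i] = repl[i];
  }
}

// Bound by the destination: if repl is shorter, repl[i] goes out of range
// (the Python analogue raises IndexError).
void update_all(std::vector<int>& dst, const std::vector<int>& repl) {
  for (size_t i = 0; i < dst.size(); ++i) {
    dst[i] = repl[i];
  }
}
```
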
@@ -1,17 +0,0 @@
-#!/bin/bash
-
-auditwheel repair dist/* \
-  --plat manylinux_2_35_x86_64 \
-  --exclude libcublas* \
-  --exclude libnvrtc*
-
-cd wheelhouse
-repaired_wheel=$(find . -name "*.whl" -print -quit)
-unzip -q "${repaired_wheel}"
-core_so=$(find mlx -name "core*.so" -print -quit)
-rpath=$(patchelf --print-rpath "${core_so}")
-rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib
-patchelf --force-rpath --set-rpath "$rpath" "$core_so"
-
-# Re-zip the repaired wheel
-zip -r -q "${repaired_wheel}" .

@@ -259,11 +259,6 @@ class TestBase(mlx_tests.MLXTestCase):
         with self.assertRaises(ValueError):
             m = m.update_modules({"list": ["hi"]})

-        # Allow updating a strict subset
-        m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
-        m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
-        self.assertEqual(m.layers[1].weight.shape, (4, 3))
-

 class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):

setup.py

@@ -174,26 +174,20 @@ if __name__ == "__main__":
     )
     package_dir = {"": "python"}
     package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
-    install_requires = []
-    build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
-    if build_cuda:
-        install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]

     setup(
-        name="mlx-cuda" if build_cuda else "mlx",
+        name="mlx",
         version=get_version(),
         author="MLX Contributors",
         author_email="mlx@group.apple.com",
         description="A framework for machine learning on Apple silicon.",
         long_description=long_description,
         long_description_content_type="text/markdown",
-        license="MIT",
         url="https://github.com/ml-explore/mlx",
         packages=packages,
         package_dir=package_dir,
         package_data=package_data,
         include_package_data=True,
-        install_requires=install_requires,
         extras_require={
             "dev": [
                 "nanobind==2.4.0",

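The removed block selected the package name and runtime dependencies by inspecting the `CMAKE_ARGS` environment variable. A C++ transcription of that branching, purely to illustrate the removed logic (the original is Python and runs inside setup.py):

```cpp
#include <cstdlib>
#include <string>
#include <vector>

int main() {
  // Mirror of: build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
  const char* args = std::getenv("CMAKE_ARGS");
  bool build_cuda =
      args && std::string(args).find("MLX_BUILD_CUDA=ON") != std::string::npos;

  // Mirror of: name="mlx-cuda" if build_cuda else "mlx"
  std::string name = build_cuda ? "mlx-cuda" : "mlx";

  // Mirror of the conditional install_requires list.
  std::vector<std::string> install_requires;
  if (build_cuda) {
    install_requires = {"nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"};
  }
  // ... hand name and install_requires to the packaging step ...
}
```
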