mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Compare commits
3 Commits
970a5d0a5e
...
755fb4f970
Author | SHA1 | Date | |
---|---|---|---|
![]() |
755fb4f970 | ||
![]() |
76831ed83d | ||
![]() |
7c99acb799 |
@ -16,6 +16,9 @@ parameters:
|
||||
linux_release:
|
||||
type: boolean
|
||||
default: false
|
||||
cuda_release:
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
build_documentation:
|
||||
@ -104,7 +107,7 @@ jobs:
|
||||
command: |
|
||||
echo "stubs"
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
@ -162,7 +165,7 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
@ -223,7 +226,6 @@ jobs:
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||
python -m venv env
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
@ -283,7 +285,7 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
@ -342,7 +344,7 @@ jobs:
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
pip install . -v
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python -m build --wheel
|
||||
@ -356,6 +358,48 @@ jobs:
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
build_cuda_release:
|
||||
parameters:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.9"
|
||||
extra_env:
|
||||
type: string
|
||||
default: "DEV_RELEASE=1"
|
||||
machine:
|
||||
image: linux-cuda-12:default
|
||||
resource_class: gpu.nvidia.small.gen2
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Build wheel
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
python -m venv env
|
||||
source env/bin/activate
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
pip install twine
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
pip install ".[dev]" -v
|
||||
python setup.py generate_stubs
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
python -m build --wheel
|
||||
bash python/scripts/repair_cuda.sh
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
twine upload wheelhouse/*.whl
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
workflows:
|
||||
build_and_test:
|
||||
when:
|
||||
@ -625,3 +669,14 @@ workflows:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
extra_env: ["PYPI_RELEASE=1"]
|
||||
cuda_test_release:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.cuda_release >>
|
||||
jobs:
|
||||
- build_cuda_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
extra_env: ["PYPI_RELEASE=1"]
|
||||
|
@ -30,6 +30,16 @@ MLX is also available on conda-forge. To install MLX with conda do:
|
||||
|
||||
conda install conda-forge::mlx
|
||||
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
|
||||
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install mlx-cuda
|
||||
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
@ -65,6 +75,8 @@ Build Requirements
|
||||
Python API
|
||||
^^^^^^^^^^
|
||||
|
||||
.. _python install:
|
||||
|
||||
To build and install the MLX python library from source, first, clone MLX from
|
||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||
|
||||
@ -107,6 +119,8 @@ IDE:
|
||||
C++ API
|
||||
^^^^^^^
|
||||
|
||||
.. _cpp install:
|
||||
|
||||
Currently, MLX must be built and installed from source.
|
||||
|
||||
Similarly to the python library, to build and install the MLX C++ library start
|
||||
@ -185,6 +199,7 @@ should point to the path to the built metal library.
|
||||
|
||||
xcrun -sdk macosx --show-sdk-version
|
||||
|
||||
|
||||
Binary Size Minimization
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -213,6 +228,50 @@ be anywhere from a few hundred milliseconds to a few seconds depending on the
|
||||
application. Once a kernel is compiled, it will be cached by the system. The
|
||||
Metal kernel cache persists across reboots.
|
||||
|
||||
Linux
|
||||
^^^^^
|
||||
|
||||
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
||||
For example on Ubuntu, run the following:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
apt-get update -y
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||
|
||||
From here follow the instructions to install either the :ref:`Python <python
|
||||
install>` or :ref:`C++ <cpp install>` APIs.
|
||||
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
||||
and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
apt-get update -y
|
||||
apt-get -y install cuda-toolkit-12-9
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||
|
||||
|
||||
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||
|
||||
To build the C++ package run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
||||
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
|
@ -114,7 +114,7 @@ void CommandEncoder::synchronize() {
|
||||
std::future<void> f = p->get_future();
|
||||
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||
worker_.end_batch();
|
||||
worker_.commit();
|
||||
commit();
|
||||
f.wait();
|
||||
}
|
||||
|
||||
|
@ -5,28 +5,33 @@ template <typename T, typename AccT = float, int N_READS = 4>
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint2 gid [[threadgroup_position_in_grid]],
|
||||
uint2 tid [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]],
|
||||
uint2 _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
int lid = _lid;
|
||||
int lid = _lid.x;
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
constexpr int elem_per_group = SIMD_SIZE * 32 * N_READS;
|
||||
|
||||
threadgroup AccT local_max[SIMD_SIZE];
|
||||
threadgroup AccT local_normalizer[SIMD_SIZE];
|
||||
|
||||
AccT ld[N_READS];
|
||||
|
||||
in += gid * size_t(axis_size) + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
const int axis_offset = tid.y * elem_per_group;
|
||||
in += gid.x * size_t(axis_size) + lid * N_READS + axis_offset;
|
||||
if (axis_offset + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] = AccT(in[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] =
|
||||
((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
|
||||
ld[i] = ((axis_offset + lid * N_READS + i) < axis_size)
|
||||
? AccT(in[i])
|
||||
: Limits<AccT>::min;
|
||||
}
|
||||
}
|
||||
if (simd_group_id == 0) {
|
||||
@ -55,6 +60,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
|
||||
maxval = local_max[0];
|
||||
|
||||
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
|
||||
out += gid.x * grid_dim.y + tid.y;
|
||||
AccT normalizer = 0;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer += fast::exp(ld[i] - maxval);
|
||||
@ -67,7 +73,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
|
||||
if (simd_group_id == 0) {
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
|
||||
out[0] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -62,15 +62,37 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
const int n_reads = 4;
|
||||
const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;
|
||||
|
||||
std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
|
||||
bool split = n_rows < 4 && axis_size > 4 * looped_limit;
|
||||
bool looped = !split && axis_size > looped_limit;
|
||||
std::string kernel_name = looped ? "looped_" : "block_";
|
||||
kernel_name += "logsumexp_";
|
||||
kernel_name += type_to_name(out);
|
||||
|
||||
auto kernel = get_logsumexp_kernel(d, kernel_name, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
if (split) {
|
||||
auto tmp_size = ceildiv(axis_size, looped_limit);
|
||||
auto tmp_shape = Shape{n_rows, static_cast<int>(tmp_size)};
|
||||
array tmp(tmp_shape, in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc(tmp.nbytes()));
|
||||
size_t threadgroup_size = 1024;
|
||||
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
auto grid_dims = MTL::Size(n_threads, tmp_size, 1);
|
||||
auto group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(tmp, 1);
|
||||
compute_encoder.set_bytes(axis_size, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
d.add_temporary(tmp, s.index);
|
||||
in = tmp;
|
||||
axis_size = tmp_size;
|
||||
}
|
||||
|
||||
{
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
if (!looped) {
|
||||
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
|
||||
size_t threadgroup_size = simd_size * simds_needed;
|
||||
|
17
python/scripts/repair_cuda.sh
Normal file
17
python/scripts/repair_cuda.sh
Normal file
@ -0,0 +1,17 @@
|
||||
#!/bin/bash
#
# Repair the wheel(s) found in dist/ with auditwheel, then extend the rpath of
# the mlx core extension inside the repaired wheel so that it also searches the
# nvidia/cublas/lib and nvidia/cuda_nvrtc/lib directories relative to the
# extension, and finally write the patched files back into the wheel.
#
# Run from the directory that contains dist/; the repaired wheel is looked up
# in wheelhouse/.

# Stop at the first failing step so a failed repair or `cd` cannot lead to
# unzip/patchelf/zip running against the wrong directory.
set -euo pipefail

# The exclude patterns are quoted so auditwheel receives them verbatim instead
# of the shell expanding them against files in the current directory.
auditwheel repair dist/* \
  --plat manylinux_2_35_x86_64 \
  --exclude 'libcublas*' \
  --exclude 'libnvrtc*'

cd wheelhouse

# Pick the first wheel found and unpack it in place.
repaired_wheel=$(find . -name "*.whl" -print -quit)
if [ -z "${repaired_wheel}" ]; then
  echo "repair_cuda.sh: no repaired wheel found in wheelhouse/" >&2
  exit 1
fi
unzip -q "${repaired_wheel}"

# Locate the core extension module inside the unpacked mlx package.
core_so=$(find mlx -name "core*.so" -print -quit)
if [ -z "${core_so}" ]; then
  echo "repair_cuda.sh: no core*.so found in ${repaired_wheel}" >&2
  exit 1
fi

# Append the two library directories to the existing rpath. The `:` separator
# is only emitted when there is an existing rpath, so no empty entry is added.
# `\$ORIGIN` keeps the literal token for the dynamic loader to resolve.
rpath=$(patchelf --print-rpath "${core_so}")
rpath="${rpath:+${rpath}:}\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib"
patchelf --force-rpath --set-rpath "${rpath}" "${core_so}"

# Re-zip the repaired wheel
zip -r -q "${repaired_wheel}" .
|
@ -760,6 +760,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
|
||||
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
|
||||
|
||||
# Even larger
|
||||
x = mx.random.uniform(shape=(4 * 4096 + 3,))
|
||||
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
|
||||
|
||||
def test_mean(self):
|
||||
x = mx.array(
|
||||
[
|
||||
|
8
setup.py
8
setup.py
@ -174,20 +174,26 @@ if __name__ == "__main__":
|
||||
)
|
||||
package_dir = {"": "python"}
|
||||
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
|
||||
install_requires = []
|
||||
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
|
||||
if build_cuda:
|
||||
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
|
||||
|
||||
setup(
|
||||
name="mlx",
|
||||
name="mlx-cuda" if build_cuda else "mlx",
|
||||
version=get_version(),
|
||||
author="MLX Contributors",
|
||||
author_email="mlx@group.apple.com",
|
||||
description="A framework for machine learning on Apple silicon.",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
license="MIT",
|
||||
url="https://github.com/ml-explore/mlx",
|
||||
packages=packages,
|
||||
package_dir=package_dir,
|
||||
package_data=package_data,
|
||||
include_package_data=True,
|
||||
install_requires=install_requires,
|
||||
extras_require={
|
||||
"dev": [
|
||||
"nanobind==2.4.0",
|
||||
|
Loading…
Reference in New Issue
Block a user