From 1fa8dc579762bf926ecbacea20a9bbd29905b093 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 4 Dec 2025 15:28:29 -0800
Subject: [PATCH] Do a PyPi release for cuda on arm (#2866)

---
Review notes (this area is not part of the commit message):
* setup-linux: the runner.arch context reports X86, X64, ARM or ARM64, so
  the ccache guard compares against 'X64'. A comparison with 'x86_64' never
  matches and would skip the ccache step on every runner.
* repair_cuda.sh: the architecture argument falls back to x86_64 when it is
  omitted, so invoking the script without arguments still yields the
  manylinux_2_35_x86_64 tag used before this change.
* build-cuda-release: `type` / `options` are not listed in the documented
  input schema for action metadata; presumably the runner ignores them, so
  the arch value is not validated there -- TODO confirm.
* release.yml: both matrix legs reach "Upload artifacts"; its `name:` lies
  outside the visible hunk -- verify the two uploads do not collide.
* The post-image blob hashes on the `index` lines below were not recomputed
  after the edits above.

 .github/actions/build-cuda-release/action.yml | 11 ++++++++++-
 .github/actions/setup-linux/action.yml        |  1 +
 .github/workflows/release.yml                 |  7 ++++++-
 python/scripts/repair_cuda.sh                 |  2 +-
 4 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/.github/actions/build-cuda-release/action.yml b/.github/actions/build-cuda-release/action.yml
index d3fe4c301..1f5ab515c 100644
--- a/.github/actions/build-cuda-release/action.yml
+++ b/.github/actions/build-cuda-release/action.yml
@@ -1,6 +1,15 @@
 name: 'Build CUDA wheel'
 description: 'Build CUDA wheel'
 
+inputs:
+  arch:
+    description: 'Platform architecture tag'
+    required: true
+    type: choice
+    options:
+      - x86_64
+      - aarch64
+
 runs:
   using: "composite"
   steps:
@@ -12,4 +21,4 @@ runs:
         pip install auditwheel build patchelf setuptools
         python setup.py clean --all
         MLX_BUILD_STAGE=2 python -m build -w
-        bash python/scripts/repair_cuda.sh
+        bash python/scripts/repair_cuda.sh "${{ inputs.arch }}"
diff --git a/.github/actions/setup-linux/action.yml b/.github/actions/setup-linux/action.yml
index 721a097a3..969ac99e9 100644
--- a/.github/actions/setup-linux/action.yml
+++ b/.github/actions/setup-linux/action.yml
@@ -15,6 +15,7 @@ runs:
   using: "composite"
   steps:
     - name: Use ccache
+      if: ${{ runner.arch == 'X64' }}
       uses: hendrikmuhs/ccache-action@v1.2
       with:
         key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 5cc99dac2..6087eed0b 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -128,7 +128,10 @@ jobs:
 
   build_cuda_release:
     if: github.repository == 'ml-explore/mlx'
-    runs-on: ubuntu-22-large
+    strategy:
+      matrix:
+        arch: ['x86_64', 'aarch64']
+    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
     env:
       PYPI_RELEASE: 1
       DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
@@ -139,6 +142,8 @@ jobs:
           toolkit: 'cuda-12.9'
       - name: Build Python package
         uses: ./.github/actions/build-cuda-release
+        with:
+          arch: ${{ matrix.arch }}
       - name: Upload artifacts
         uses: actions/upload-artifact@v5
         with:
diff --git a/python/scripts/repair_cuda.sh b/python/scripts/repair_cuda.sh
index 9f8cd9b0d..187b0b15f 100644
--- a/python/scripts/repair_cuda.sh
+++ b/python/scripts/repair_cuda.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 auditwheel repair dist/* \
-  --plat manylinux_2_35_x86_64 \
+  --plat "manylinux_2_35_${1:-x86_64}" \
   --exclude libcublas* \
   --exclude libnvrtc* \
   --exclude libcuda* \