From 422ff344cd6908d0a29ea9eb1b0de5843eca2701 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 3 Dec 2025 13:47:50 -0800 Subject: [PATCH] do a PyPi release for cuda on arm --- .github/actions/build-cuda-release/action.yml | 11 ++++++++++- .github/workflows/release.yml | 7 ++++++- python/scripts/repair_cuda.sh | 2 +- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/.github/actions/build-cuda-release/action.yml b/.github/actions/build-cuda-release/action.yml index d3fe4c301..1f5ab515c 100644 --- a/.github/actions/build-cuda-release/action.yml +++ b/.github/actions/build-cuda-release/action.yml @@ -1,6 +1,15 @@ name: 'Build CUDA wheel' description: 'Build CUDA wheel' +inputs: + arch: + description: 'Platform architecture tag' + required: true + type: choice + options: + - x86_64 + - aarch64 + runs: using: "composite" steps: @@ -12,4 +21,4 @@ runs: pip install auditwheel build patchelf setuptools python setup.py clean --all MLX_BUILD_STAGE=2 python -m build -w - bash python/scripts/repair_cuda.sh + bash python/scripts/repair_cuda.sh ${{ inputs.arch }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5cc99dac2..ff5d5d89c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -128,7 +128,10 @@ jobs: build_cuda_release: if: github.repository == 'ml-explore/mlx' - runs-on: ubuntu-22-large + strategy: + matrix: + arch: ['x86_64', 'aarch64'] + runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }} env: PYPI_RELEASE: 1 DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }} @@ -139,6 +142,8 @@ jobs: toolkit: 'cuda-12.9' - name: Build Python package uses: ./.github/actions/build-cuda-release + with: + arch: ${{ matrix.arch }} - name: Upload artifacts uses: actions/upload-artifact@v5 with: diff --git a/python/scripts/repair_cuda.sh b/python/scripts/repair_cuda.sh index 9f8cd9b0d..187b0b15f 100644 --- a/python/scripts/repair_cuda.sh +++ b/python/scripts/repair_cuda.sh @@ -1,7 +1,7 @@ #!/bin/bash auditwheel repair dist/* \ - --plat manylinux_2_35_x86_64 \ + --plat manylinux_2_35_${1} \ --exclude libcublas* \ --exclude libnvrtc* \ --exclude libcuda* \