diff --git a/.github/actions/build-linux-release/action.yml b/.github/actions/build-linux-release/action.yml index 78e18a11c..3ea0e6fc0 100644 --- a/.github/actions/build-linux-release/action.yml +++ b/.github/actions/build-linux-release/action.yml @@ -7,6 +7,13 @@ inputs: type: boolean required: false default: false + arch: + description: 'Platform architecture tag' + required: true + type: choice + options: + - x86_64 + - aarch64 runs: using: "composite" @@ -23,11 +30,11 @@ runs: pip install auditwheel patchelf build python setup.py clean --all MLX_BUILD_STAGE=1 python -m build -w - bash python/scripts/repair_linux.sh + bash python/scripts/repair_linux.sh ${{ inputs.arch }} - name: Build backend package if: ${{ inputs.build-backend }} shell: bash run: | python setup.py clean --all MLX_BUILD_STAGE=2 python -m build -w - auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64 + auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }} diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index b7ff91a73..fc0147546 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -21,6 +21,7 @@ jobs: - uses: ./.github/actions/build-linux-release with: build-backend: ${{ matrix.python-version == '3.10' }} + arch: "x86_64" - name: Upload mlx artifacts uses: actions/upload-artifact@v5 with: @@ -40,7 +41,10 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"] - runs-on: ubuntu-22.04 + runner: + - ubuntu-22.04 + - ubuntu-22.04-arm + runs-on: ${{ matrix.runner }} steps: - uses: actions/checkout@v5 - uses: ./.github/actions/setup-linux diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index aae9edb88..ae4c5b0e4 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -14,7 +14,13 @@ jobs: - uses: pre-commit/action@v3.0.1 linux_build_and_test: - runs-on: ubuntu-22.04 + strategy: + matrix: + runner: + - ubuntu-22.04 + - ubuntu-22.04-arm + fail-fast: false + runs-on: ${{ matrix.runner }} steps: - uses: actions/checkout@v5 - uses: ./.github/actions/setup-linux diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8917ec17b..a2858d071 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -45,7 +45,12 @@ jobs: strategy: matrix: python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"] - runs-on: ubuntu-22.04 + include: + - runner: ubuntu-24.04 + arch: x64 + - runner: ubuntu-24.04-arm64 + arch: arm64 + runs-on: ${{ matrix.runner }} env: PYPI_RELEASE: 1 steps: @@ -56,6 +61,7 @@ jobs: - uses: ./.github/actions/build-linux-release with: build-backend: ${{ matrix.python-version == '3.10' }} + arch: ${{ matrix.arch }} - name: Upload MLX artifacts uses: actions/upload-artifact@v5 with: diff --git a/python/scripts/repair_linux.sh b/python/scripts/repair_linux.sh index 880c43f4b..bbf4dc02a 100644 --- a/python/scripts/repair_linux.sh +++ b/python/scripts/repair_linux.sh @@ -1,7 +1,7 @@ #!/bin/bash auditwheel repair dist/* \ - --plat manylinux_2_35_x86_64 \ + --plat manylinux_2_35_${1} \ --only-plat \ --exclude libmlx* \ -w wheel_tmp diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index 0660a69fe..041aedaa5 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -425,9 +425,11 @@ TEST_CASE("test matrix pseudo-inverse") { const auto A = array({1.0, 2.0, 3.0, 4.0}, {2, 2}); const auto A_pinv = linalg::pinv(A, Device::cpu); const auto A_again = matmul(matmul(A, A_pinv), A); - CHECK(allclose(A_again, A).item()); + CHECK(allclose(A_again, A, /* rtol = */ 1e-5, /* atol = */ 1e-5) + .item()); const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv); - CHECK(allclose(A_pinv_again, A_pinv).item()); + CHECK(allclose(A_pinv_again, A_pinv, /* rtol = */ 1e-5, /* atol = */ 1e-5) + .item()); } { // Rectangular matrix m < n const auto prng_key = random::key(42); @@ -437,9 +439,11 @@ TEST_CASE("test matrix pseudo-inverse") { CHECK_FALSE(allclose(zeros, A_pinv, /* rtol = */ 0, /* atol = */ 1e-6) .item()); const auto A_again = matmul(matmul(A, A_pinv), A); - CHECK(allclose(A_again, A).item()); + CHECK(allclose(A_again, A, /* rtol = */ 1e-5, /* atol = */ 1e-5) + .item()); const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv); - CHECK(allclose(A_pinv_again, A_pinv).item()); + CHECK(allclose(A_pinv_again, A_pinv, /* rtol = */ 1e-5, /* atol = */ 1e-5) + .item()); } { // Rectangular matrix m > n const auto prng_key = random::key(10); @@ -449,9 +453,11 @@ TEST_CASE("test matrix pseudo-inverse") { CHECK_FALSE(allclose(zeros2, A_pinv, /* rtol = */ 0, /* atol = */ 1e-6) .item()); const auto A_again = matmul(matmul(A, A_pinv), A); - CHECK(allclose(A_again, A).item()); + CHECK(allclose(A_again, A, /* rtol = */ 1e-5, /* atol = */ 1e-5) + .item()); const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv); - CHECK(allclose(A_pinv_again, A_pinv).item()); + CHECK(allclose(A_pinv_again, A_pinv, /* rtol = */ 1e-5, /* atol = */ 1e-5) + .item()); } }