# Composite action: builds a Python wheel with the CUDA backend enabled
# (MLX_BUILD_CUDA=ON) and, if the repository provides one, runs a CUDA
# wheel repair script afterwards.
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'

inputs:
  # Forwarded to CMake as CMAKE_CUDA_COMPILER (see CMAKE_ARGS below).
  nvcc-location:
    description: 'Location of nvcc compiler'
    required: true

runs:
  using: "composite"
  steps:
    - name: Build package
      shell: bash
      env:
        # NOTE(review): presumably read by setup.py to select a build stage;
        # confirm what stage 2 means there. Quoted so it is an explicit
        # string, not a YAML integer.
        MLX_BUILD_STAGE: "2"
        # Extra CMake flags: enable the CUDA backend and point CMake at the
        # nvcc given by the `nvcc-location` input.
        CMAKE_ARGS: "-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}"
      run: |
        # `python -m pip` installs into the same interpreter that runs
        # setup.py and `python -m build` below (a bare `pip` may not).
        python -m pip install auditwheel build patchelf setuptools
        # Clear output from any previous build before building again.
        python setup.py clean --all
        # Build only a wheel (-w), no sdist.
        python -m build -w
        # Post-process the wheel only when the repository ships the script.
        if [ -f "python/scripts/repair_cuda.sh" ]; then
          bash python/scripts/repair_cuda.sh
        fi