Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: aad49f932f...32b18d8b66 (5 commits)
Commits:
- 32b18d8b66
- 472c43a0c8
- b7214ff01e
- 76414c8971
- 49e4566df3
@@ -2,8 +2,8 @@ name: 'Build CUDA wheel'
 description: 'Build CUDA wheel'
 
 inputs:
-  nvcc-location:
-    description: 'Location of nvcc compiler'
+  toolkit:
+    description: 'The CUDA toolkit'
     required: true
 
 runs:
@@ -12,7 +12,7 @@ runs:
   - name: Build package
     shell: bash
     env:
-      CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
+      CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
     run: |
       pip install auditwheel build patchelf setuptools
       python setup.py clean --all
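For reference, a rough local sketch of what the updated build step now resolves to, assuming the toolkit input is set to cuda-12.9 and that toolkit is installed under /usr/local (only the lines shown in the hunk are reproduced; the rest of the run block is unchanged and omitted):

    # Derive the nvcc path from the toolkit name, as the action's CMAKE_ARGS now does.
    toolkit=cuda-12.9   # hypothetical choice; the workflows pass values like cuda-12.8 or cuda-12.9
    export CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/${toolkit}/bin/nvcc"
    pip install auditwheel build patchelf setuptools
    python setup.py clean --all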
.github/actions/build-cuda/action.yml (vendored): 9 changes
@@ -2,10 +2,9 @@ name: 'Build and Test with CUDA'
 description: 'Build and test MLX with CUDA'
 
 inputs:
-  nvcc-location:
-    description: 'Location of nvcc compiler'
+  toolkit:
+    description: 'The CUDA toolkit'
     required: true
-    default: '/usr/local/cuda-12.9/bin/nvcc'
 
 runs:
   using: "composite"
@@ -14,7 +13,7 @@ runs:
     shell: bash
     env:
       DEBUG: 1
-      CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
+      CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
     run: pip install --no-build-isolation -e ".[dev]" -v
 
   - name: Build CPP only
@@ -22,6 +21,6 @@ runs:
     run: |
       cmake . -B build \
        -DMLX_BUILD_CUDA=ON \
-       -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} \
+       -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc \
        -DCMAKE_BUILD_TYPE=DEBUG
      cmake --build build -j $(nproc)
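A local sketch of the 'Build CPP only' step with the toolkit input resolved by hand; cuda-12.9 is just an example value, and any toolkit installed under /usr/local works the same way:

    toolkit=cuda-12.9
    cmake . -B build \
      -DMLX_BUILD_CUDA=ON \
      -DCMAKE_CUDA_COMPILER=/usr/local/${toolkit}/bin/nvcc \
      -DCMAKE_BUILD_TYPE=DEBUG
    cmake --build build -j $(nproc)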
.github/actions/build-macos-release/action.yml (vendored): 10 changes
@@ -16,21 +16,15 @@ runs:
   using: "composite"
   steps:
     - name: Build Python package
-      shell: bash
-      env:
-        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
+      shell: bash -l {0}
       run: |
+        conda activate env
         pip install build
         python setup.py clean --all
         MLX_BUILD_STAGE=1 python -m build -w
 
     - name: Build backend package
       if: ${{ inputs.build-backend }}
-      shell: bash
-      env:
-        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
+      shell: bash -l {0}
       run: |
+        conda activate env
         python setup.py clean --all
         MLX_BUILD_STAGE=2 python -m build -w
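A condensed local sketch of the two-stage wheel build this action now runs inside the conda environment; which wheel each stage produces is inferred from the mlx-*.whl and mlx_metal-*.whl artifacts uploaded in release.yml below:

    conda activate env                     # environment activated by the setup-macos action
    pip install build
    python setup.py clean --all
    MLX_BUILD_STAGE=1 python -m build -w   # main package wheel (uploaded as mlx-*.whl)
    python setup.py clean --all
    MLX_BUILD_STAGE=2 python -m build -w   # backend wheel (uploaded as mlx_metal-*.whl on macOS)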
.github/actions/build-macos/action.yml (vendored): 69 changes
@@ -1,73 +1,51 @@
 name: 'Build and Test on macOS'
 description: 'Build and test MLX on macOS'
 
 inputs:
   python-version:
     description: 'Python version to use'
     required: false
     default: '3.10'
   macos-target:
     description: 'macOS target to build and test for'
     required: false
     default: '14.0'
 
 runs:
   using: "composite"
   steps:
-    - name: Setup uv
-      uses: astral-sh/setup-uv@v6
-      with:
-        python-version: ${{ inputs.python-version }}
-        activate-environment: true
 
     - name: Install dependencies
-      shell: sh
-      env:
-        DEBUG: 1
-        CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
-        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
+      shell: bash -l {0}
       run: |
-        uv pip install --upgrade pip
-        uv pip install cmake setuptools nanobind==2.4.0
-        uv pip install -e . -v
+        pip install --upgrade pip
+        pip install cmake setuptools nanobind==2.4.0
+        pip install -e . -v
 
     - name: Generate package stubs
-      shell: bash
+      shell: bash -l {0}
       run: |
-        uv pip install typing_extensions
-        uv run --no-project setup.py generate_stubs
+        pip install typing_extensions
+        python setup.py generate_stubs
 
     - name: Install tests dependencies
-      shell: sh
+      shell: bash -l {0}
       run: |
-        uv pip install numpy torch tensorflow unittest-xml-reporting
+        pip install numpy torch tensorflow unittest-xml-reporting
 
     - name: Run Python tests
-      shell: bash
+      shell: bash -l {0}
       env:
         LOW_MEMORY: 1
         MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
       run: |
-        DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu
-        DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu
+        DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
+        DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
         mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
         mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
         if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
 
     - name: Build example extension
-      shell: bash
-      env:
-        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
+      shell: bash -l {0}
       run: |
        cd examples/extensions
-        uv pip install -r requirements.txt
-        uv run --no-project setup.py build_ext --inplace
-        uv run --no-project test.py
+        pip install -r requirements.txt
+        python setup.py build_ext --inplace
+        python test.py
 
     - name: Build CPP only
-      shell: bash
-      env:
-        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
+      shell: bash -l {0}
       run: |
        mkdir -p build
        cd build
@@ -75,7 +53,7 @@ runs:
        make -j $(sysctl -n hw.ncpu)
 
     - name: Run CPP tests
-      shell: bash
+      shell: bash -l {0}
       env:
        DEVICE: gpu
        METAL_DEVICE_WRAPPER_TYPE: 1
@@ -83,9 +61,7 @@ runs:
       run: ./build/tests/tests
 
     - name: Build small binary with JIT
-      shell: bash
-      env:
-        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
+      shell: bash -l {0}
       run: |
        mkdir -p build
        cd build
@@ -98,16 +74,15 @@ runs:
        make -j $(sysctl -n hw.ncpu)
 
     - name: Run Python tests with JIT
-      shell: bash
+      shell: bash -l {0}
       env:
        LOW_MEMORY: 1
        DEVICE: gpu
        METAL_DEVICE_WRAPPER_TYPE: 1
        METAL_DEBUG_ERROR_MODE: 0
        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
       run: |
        CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
-        uv pip install -e . -v
-        uv run -m xmlrunner discover \
+        pip install -e . -v
+        python -m xmlrunner discover \
          -v python/tests \
          -o test-results/gpu_jit
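For orientation, a minimal local approximation of the rewritten install-and-test flow, assuming an activated conda environment on an Apple silicon Mac; the deployment target is only an example, as the pull_request workflow supplies 14.0 or 15.0 through its matrix:

    export MACOSX_DEPLOYMENT_TARGET=14.0   # example value; the workflow sets this from matrix.macos-target
    pip install --upgrade pip
    pip install cmake setuptools nanobind==2.4.0
    pip install -e . -v                    # builds MLX and installs it editable
    pip install numpy torch tensorflow unittest-xml-reporting
    DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu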
.github/actions/setup-linux/action.yml (vendored): 47 changes
@@ -2,14 +2,10 @@ name: 'Setup Linux Environment'
 description: 'Install dependencies for Linux builds'
 
 inputs:
-  runner-type:
-    description: 'Whether to set this up as a linux or CUDA runner'
+  toolkit:
+    description: 'Which toolkit to install'
     required: false
-    default: 'linux'
-    type: choice
-    options:
-      - linux
-      - cuda
+    default: 'cpu'
   python-version:
     description: 'Version of python to set up'
     required: false
@@ -21,7 +17,7 @@ runs:
   - name: Use ccache
     uses: hendrikmuhs/ccache-action@v1.2
     with:
-      key: ccache-${{ inputs.runner-type }}-${{ runner.arch }}-py${{ inputs.python-version }}
+      key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
       max-size: 1GB
 
   - name: Install common dependencies
@@ -33,7 +29,6 @@ runs:
   - uses: actions/setup-python@v6
     with:
       python-version: ${{ inputs.python-version }}
-      cache: 'pip'
 
   - name: Setup Python venv
     shell: bash
@@ -49,21 +44,33 @@ runs:
     shell: bash
     run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
 
-  - name: Network CUDA installation from packages
-    if: inputs.runner-type == 'cuda'
-    shell: bash ## Specific to Ubuntu 22.04 & Architecture x86_64
+  - name: Install CUDA toolkit
+    if: ${{ startsWith(inputs.toolkit, 'cuda') }}
+    shell: bash
+    env:
+      # Note: the CI machine does not meet CUDA 13's driver requirement.
+      # Compatibility matrix:
+      # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
+      # The `nvcc` is installed into `/usr/local/cuda-VERSION/bin/nvcc` - but
+      # it's *not* on the default toolkit path.
+      PACKAGES: |
+        {
+          "cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
+          "cuda-12.8": "libcudnn9-dev-cuda-12 cuda-toolkit-12-8",
+          "cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
+          "cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
+        }
     run: |
-      wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
+      export ARCH=${{ runner.arch == 'arm64' && 'arm64' || 'x86_64' }}
+      wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
       sudo dpkg -i cuda-keyring_1.1-1_all.deb
       sudo apt-get update
-      sudo apt-get install -y libcudnn9-dev-cuda-12 libnccl2 libnccl-dev cuda-toolkit-12-9
-      # Note: This installs CUDA 12.9, which is the latest supported by cuDNN 9.x and works with the NVidia 570 drivers
-      # cuda-toolkit by itself installs version 13 (+) and requires updated drives (580+), which require a reboot to function properly.
-      # Compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
-      # This also drops `nvcc` into `/usr/local/cuda-12.9/bin/nvcc` - but it's *not* on the default PATH
+      sudo apt-get install -y \
+        libnccl2 libnccl-dev \
+        ${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
 
-  - name: Package and Driver Report
-    if: inputs.runner-type == 'cuda'
+  - name: CUDA packages and driver report
+    if: ${{ startsWith(inputs.toolkit, 'cuda') }}
     shell: bash
     run: |
       sudo apt-get install -y ubuntu-drivers-common dkms
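As an illustration, here is roughly what the templated install expands to when the toolkit input is cuda-12.9 on an x86_64 runner (package names taken from the PACKAGES map above; a sketch of the expansion, not the action itself):

    ARCH=x86_64   # the action selects arm64 or x86_64 from runner.arch
    wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
    sudo dpkg -i cuda-keyring_1.1-1_all.deb
    sudo apt-get update
    # fromJson(env.PACKAGES)['cuda-12.9'] resolves to "libcudnn9-dev-cuda-12 cuda-toolkit-12-9"
    sudo apt-get install -y libnccl2 libnccl-dev libcudnn9-dev-cuda-12 cuda-toolkit-12-9
    # nvcc then lives at /usr/local/cuda-12.9/bin/nvcc, which the build actions pass to CMake.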
.github/actions/setup-macos/action.yml (vendored): 11 changes
@@ -1,6 +1,12 @@
 name: 'Setup macOS Environment'
 description: 'Install dependencies for macOS builds'
 
+inputs:
+  python-version:
+    description: 'Python version to use'
+    required: false
+    default: '3.10'
+
 runs:
   using: "composite"
   steps:
@@ -11,3 +17,8 @@ runs:
   - name: Verify MetalToolchain installed
     shell: bash
     run: xcodebuild -showComponent MetalToolchain
+
+  - uses: conda-incubator/setup-miniconda@v3
+    with:
+      miniconda-version: "latest"
+      python-version: ${{ inputs.python-version }}
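A rough local stand-in for what the extended action now provides; the environment name env is an assumption based on the conda activate env calls in the build actions, and the exact options setup-miniconda applies are not shown in this diff:

    xcodebuild -showComponent MetalToolchain   # same verification the action runs
    conda create -y -n env python=3.10         # hypothetical equivalent of conda-incubator/setup-miniconda@v3
    conda activate env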
.github/workflows/nightly.yml (vendored): 13 changes
@@ -67,7 +67,6 @@ jobs:
       with:
         python-version: ${{ matrix.python-version }}
     - uses: ./.github/actions/build-macos
 
     - name: Build macOS 15 package
       uses: ./.github/actions/build-macos-release
       with:
@@ -81,13 +80,19 @@ jobs:
 
   build_cuda_with_tests:
     if: github.repository == 'ml-explore/mlx'
+    strategy:
+      fail-fast: false
+      matrix:
+        toolkit: ['cuda-12.8', 'cuda-12.9']
     runs-on: gpu-t4-4-core
     steps:
       - uses: actions/checkout@v5
       - uses: ./.github/actions/setup-linux
         with:
-          runner-type: 'cuda'
+          toolkit: ${{ matrix.toolkit }}
       - uses: ./.github/actions/build-cuda
+        with:
+          toolkit: ${{ matrix.toolkit }}
       - uses: ./.github/actions/test-linux
 
   build_cuda_release:
@@ -97,11 +102,11 @@ jobs:
       - uses: actions/checkout@v5
       - uses: ./.github/actions/setup-linux
         with:
-          runner-type: 'cuda'
+          toolkit: 'cuda-12.9'
       - name: Build Python package
         uses: ./.github/actions/build-cuda-release
         with:
-          nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
+          toolkit: 'cuda-12.9'
       - name: Upload artifacts
         uses: actions/upload-artifact@v5
         with:
.github/workflows/pull_request.yml (vendored): 18 changes
@@ -1,10 +1,14 @@
 name: Build and Test
 
-on: pull_request
+on: [pull_request, push]
 
 permissions:
   contents: read
 
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   check_lint:
     runs-on: ubuntu-22.04
@@ -35,24 +39,30 @@ jobs:
       matrix:
         macos-target: ["14.0", "15.0"]
     runs-on: [self-hosted, macos]
     env:
       MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
     needs: check_lint
     steps:
       - uses: actions/checkout@v5
       - uses: ./.github/actions/setup-macos
       - uses: ./.github/actions/build-macos
         with:
           macos-target: ${{ matrix.macos-target }}
 
   cuda_build_and_test:
     if: github.repository == 'ml-explore/mlx'
+    strategy:
+      fail-fast: false
+      matrix:
+        toolkit: ['cuda-12.8', 'cuda-12.9']
     runs-on: gpu-t4-4-core
     needs: check_lint
     steps:
       - uses: actions/checkout@v5
       - uses: ./.github/actions/setup-linux
         with:
-          runner-type: 'cuda'
+          toolkit: ${{ matrix.toolkit }}
       - uses: ./.github/actions/build-cuda
+        with:
+          toolkit: ${{ matrix.toolkit }}
       - uses: ./.github/actions/test-linux
 
   build_documentation:
.github/workflows/release.yml (vendored): 23 changes
@@ -47,12 +47,8 @@ jobs:
   strategy:
     matrix:
       python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
-      include:
-        - runner: ubuntu-24.04
-          arch: x64
-        - runner: ubuntu-24.04-arm64
-          arch: arm64
-  runs-on: ${{ matrix.runner }}
+      arch: ['x86_64', 'aarch64']
+  runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
   env:
     PYPI_RELEASE: 1
     DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
@@ -68,12 +64,14 @@ jobs:
   - name: Upload MLX artifacts
     uses: actions/upload-artifact@v5
     with:
+      overwrite: true
       name: linux-wheels-${{ matrix.python_version }}
       path: wheelhouse/mlx-*.whl
   - name: Upload CPU artifacts
     if: matrix.python_version == '3.10'
     uses: actions/upload-artifact@v5
     with:
+      overwrite: true
       name: mlx-cpu
       path: wheelhouse/mlx_cpu-*.whl
 
@@ -90,19 +88,17 @@ jobs:
   steps:
     - uses: actions/checkout@v5
     - uses: ./.github/actions/setup-macos
-    - uses: conda-incubator/setup-miniconda@v3
-      with:
-        miniconda-version: "latest"
-        python-version: ${{ matrix.python-version }}
 
     - name: Install dependencies
-      shell: sh
+      shell: bash -l {0}
      run: |
        pip install --upgrade pip
        pip install cmake setuptools nanobind==2.4.0
        pip install -e . -v
     - name: Generate package stubs
-      shell: bash
+      shell: bash -l {0}
      run: |
        pip install typing_extensions
        python setup.py generate_stubs
@@ -119,12 +115,14 @@ jobs:
   - name: Upload MLX artifacts
     uses: actions/upload-artifact@v5
     with:
+      overwrite: true
       name: mac-wheels-${{ matrix.python-version }}
       path: dist/mlx-*.whl
   - name: Upload Metal artifacts
     if: matrix.python-version == '3.10'
     uses: actions/upload-artifact@v5
     with:
+      overwrite: true
       name: mlx-metal
       path: dist/mlx_metal-*.whl
 
@@ -138,14 +136,15 @@ jobs:
   - uses: actions/checkout@v5
   - uses: ./.github/actions/setup-linux
     with:
-      runner-type: 'cuda'
+      toolkit: 'cuda-12.9'
   - name: Build Python package
     uses: ./.github/actions/build-cuda-release
     with:
-      nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
+      toolkit: 'cuda-12.9'
   - name: Upload artifacts
     uses: actions/upload-artifact@v5
     with:
+      overwrite: true
      name: mlx-cuda
      path: wheelhouse/mlx_cuda-*.whl
mlx/fast.cpp: 35 changes
@@ -578,7 +578,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::string& mask_mode /* = "" */,
-    const std::vector<array>& mask_arrs /* = {} */,
+    std::optional<array> mask_arr /* = {} */,
     const std::optional<array>& sinks /* = {} */,
     StreamOrDevice s /* = {}*/) {
   for (const auto& tensor : {queries, keys, values}) {
@@ -606,32 +606,22 @@ array scaled_dot_product_attention(
     has_mask = true;
     do_causal = true;
 
-    if (!mask_arrs.empty()) {
+    if (mask_arr) {
       std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
-          << "'casusal'. No array masks supported.";
+      msg << "[scaled_dot_product_attention] Invalid mask_arr for mask_mode "
+          << "'casusal'. No array mask should be passed.";
       throw std::invalid_argument(msg.str());
     }
-  }
 
-  if (mask_mode == "array" || (mask_mode == "" && !mask_arrs.empty())) {
-    if (mask_arrs.size() != 1) {
-      std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
-          << "'" << mask_mode << "'. Only 1 mask array is supported, got "
-          << mask_arrs.size() << "arrays.";
-      throw std::invalid_argument(msg.str());
-    }
-
+  } else if (mask_arr) {
     has_mask = true;
     has_arr_mask = true;
-    has_bool_mask = mask_arrs[0].dtype() == bool_;
+    has_bool_mask = mask_arr->dtype() == bool_;
   }
 
-  if (has_arr_mask && (mask_arrs[0]).ndim() > 4) {
+  if (has_arr_mask && mask_arr->ndim() > 4) {
     std::ostringstream msg;
     msg << "[scaled_dot_product_attention] the mask with shape "
-        << mask_arrs[0].shape() << " expected to have at most rank 4.";
+        << mask_arr->shape() << " expected to have at most rank 4.";
     throw std::invalid_argument(msg.str());
   }
@@ -764,20 +754,19 @@ array scaled_dot_product_attention(
   std::vector<array> inputs = {q, k, v};
   if (has_arr_mask) {
     // Check type
-    auto mask_arr = mask_arrs[0];
-    has_bool_mask = mask_arr.dtype() == bool_;
-    if (promote_types(mask_arr.dtype(), final_type) != final_type) {
+    has_bool_mask = mask_arr->dtype() == bool_;
+    if (promote_types(mask_arr->dtype(), final_type) != final_type) {
       std::ostringstream msg;
       msg << "[scaled_dot_product_attention] Mask type must promote to output type "
           << final_type << ".";
       throw std::invalid_argument(msg.str());
     } else if (!has_bool_mask) {
-      mask_arr = astype(mask_arr, final_type, stream);
+      mask_arr = astype(*mask_arr, final_type, stream);
     }
     // Broadcast mask
     auto mask_shape = queries.shape();
     mask_shape.back() = keys.shape(-2);
-    inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
+    inputs.push_back(broadcast_to(*mask_arr, mask_shape, stream));
   }
   if (has_sinks) {
     if (promote_types(sinks->dtype(), final_type) != final_type) {
@@ -49,7 +49,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::string& mask_mode = "",
-    const std::vector<array>& mask_arrs = {},
+    std::optional<array> mask_arr = {},
     const std::optional<array>& sinks = {},
     StreamOrDevice s = {});
@@ -213,11 +213,11 @@ void init_fast(nb::module_& parent_module) {
             throw std::invalid_argument(msg.str());
           }
           return mx::fast::scaled_dot_product_attention(
-              queries, keys, values, scale, mask_str, {}, sinks, s);
+              queries, keys, values, scale, mask_str, std::nullopt, sinks, s);
         } else {
           auto mask_arr = std::get<mx::array>(mask);
           return mx::fast::scaled_dot_product_attention(
-              queries, keys, values, scale, "", {mask_arr}, sinks, s);
+              queries, keys, values, scale, "", mask_arr, sinks, s);
         }
 
       } else {