mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
4 Commits
0e0d9ac522
...
a4fcc893cd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4fcc893cd | ||
|
|
9d10239af7 | ||
|
|
19facd4b20 | ||
|
|
f5299f72cd |
@@ -41,7 +41,7 @@ jobs:
|
|||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install --upgrade cmake
|
pip install --upgrade cmake
|
||||||
pip install -r docs/requirements.txt
|
pip install -r docs/requirements.txt
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
|
pip install . -v
|
||||||
- when:
|
- when:
|
||||||
condition:
|
condition:
|
||||||
not: << parameters.upload-docs >>
|
not: << parameters.upload-docs >>
|
||||||
@@ -97,10 +97,8 @@ jobs:
|
|||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
|
||||||
python3 setup.py build_ext --inplace
|
python3 setup.py build_ext --inplace
|
||||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
|
||||||
python3 setup.py develop
|
python3 setup.py develop
|
||||||
- run:
|
- run:
|
||||||
name: Generate package stubs
|
name: Generate package stubs
|
||||||
@@ -157,8 +155,7 @@ jobs:
|
|||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||||
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
|
||||||
pip install -e . -v
|
pip install -e . -v
|
||||||
- run:
|
- run:
|
||||||
name: Generate package stubs
|
name: Generate package stubs
|
||||||
@@ -208,8 +205,7 @@ jobs:
|
|||||||
name: Run Python tests with JIT
|
name: Run Python tests with JIT
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
|
||||||
pip install -e . -v
|
pip install -e . -v
|
||||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||||
METAL_DEBUG_ERROR_MODE=0 \
|
METAL_DEBUG_ERROR_MODE=0 \
|
||||||
@@ -228,8 +224,7 @@ jobs:
|
|||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
python -m venv env
|
python -m venv env
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
|
||||||
pip install -e ".[dev]"
|
pip install -e ".[dev]"
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
@@ -278,7 +273,6 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
|
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
|
||||||
pip install . -v
|
pip install . -v
|
||||||
- run:
|
- run:
|
||||||
name: Generate package stubs
|
name: Generate package stubs
|
||||||
@@ -290,9 +284,7 @@ jobs:
|
|||||||
name: Build Python package
|
name: Build Python package
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
<< parameters.build_env >> \
|
<< parameters.build_env >> python -m build -w
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
|
||||||
python -m build -w
|
|
||||||
- when:
|
- when:
|
||||||
condition: << parameters.build_env >>
|
condition: << parameters.build_env >>
|
||||||
steps:
|
steps:
|
||||||
@@ -340,14 +332,10 @@ jobs:
|
|||||||
pip install patchelf
|
pip install patchelf
|
||||||
pip install build
|
pip install build
|
||||||
pip install twine
|
pip install twine
|
||||||
<< parameters.extra_env >> \
|
<< parameters.extra_env >> pip install . -v
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
|
||||||
pip install . -v
|
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
<< parameters.extra_env >> \
|
<< parameters.extra_env >> python -m build --wheel
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
|
||||||
python -m build --wheel
|
|
||||||
auditwheel show dist/*
|
auditwheel show dist/*
|
||||||
auditwheel repair dist/* --plat manylinux_2_31_x86_64
|
auditwheel repair dist/* --plat manylinux_2_31_x86_64
|
||||||
- run:
|
- run:
|
||||||
@@ -383,12 +371,10 @@ jobs:
|
|||||||
pip install build
|
pip install build
|
||||||
pip install twine
|
pip install twine
|
||||||
<< parameters.extra_env >> \
|
<< parameters.extra_env >> \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
|
||||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
pip install ".[dev]" -v
|
pip install ".[dev]" -v
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
<< parameters.extra_env >> \
|
<< parameters.extra_env >> \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
|
||||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
python -m build --wheel
|
python -m build --wheel
|
||||||
bash python/scripts/repair_cuda.sh
|
bash python/scripts/repair_cuda.sh
|
||||||
@@ -506,6 +492,16 @@ workflows:
|
|||||||
branches:
|
branches:
|
||||||
ignore: /.*/
|
ignore: /.*/
|
||||||
upload-docs: true
|
upload-docs: true
|
||||||
|
- build_linux_release:
|
||||||
|
filters:
|
||||||
|
tags:
|
||||||
|
only: /^v.*/
|
||||||
|
branches:
|
||||||
|
ignore: /.*/
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
|
extra_env: ["PYPI_RELEASE=1"]
|
||||||
|
|
||||||
prb:
|
prb:
|
||||||
when:
|
when:
|
||||||
|
|||||||
@@ -88,20 +88,20 @@ Then simply build and install MLX using pip:
|
|||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
|
pip install .
|
||||||
|
|
||||||
For developing, install the package with development dependencies, and use an
|
For developing, install the package with development dependencies, and use an
|
||||||
editable install:
|
editable install:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
|
pip install -e ".[dev]"
|
||||||
|
|
||||||
Once the development dependencies are installed, you can build faster with:
|
Once the development dependencies are installed, you can build faster with:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
|
python setup.py build_ext --inplace
|
||||||
|
|
||||||
Run the tests with:
|
Run the tests with:
|
||||||
|
|
||||||
@@ -262,7 +262,7 @@ When building either the Python or C++ APIs make sure to pass the cmake flag
|
|||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||||
|
|
||||||
To build the C++ package run:
|
To build the C++ package run:
|
||||||
|
|
||||||
|
|||||||
@@ -17,35 +17,106 @@ namespace cu {
|
|||||||
|
|
||||||
namespace cg = cooperative_groups;
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
int remaining = size - index * N_READS;
|
||||||
out[index] = Op{}(a[0], b[0]);
|
if (remaining <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (remaining < N_READS) {
|
||||||
|
for (int i = 0; i < remaining; ++i) {
|
||||||
|
IdxT offset = index * N_READS + i;
|
||||||
|
out[offset] = Op{}(a[0], b[0]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec.val[i] = Op{}(a[0], b[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
int remaining = size - index * N_READS;
|
||||||
out[index] = Op{}(a[0], b[index]);
|
if (remaining <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (remaining < N_READS) {
|
||||||
|
for (int i = 0; i < remaining; ++i) {
|
||||||
|
IdxT offset = index * N_READS + i;
|
||||||
|
out[offset] = Op{}(a[0], b[offset]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto b_vec = load_vector<N_READS>(b, index);
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
int remaining = size - index * N_READS;
|
||||||
out[index] = Op{}(a[index], b[0]);
|
if (remaining <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (remaining < N_READS) {
|
||||||
|
for (int i = 0; i < remaining; ++i) {
|
||||||
|
IdxT offset = index * N_READS + i;
|
||||||
|
out[offset] = Op{}(a[offset], b[0]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto a_vec = load_vector<N_READS>(a, index);
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
int remaining = size - index * N_READS;
|
||||||
out[index] = Op{}(a[index], b[index]);
|
if (remaining <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (remaining < N_READS) {
|
||||||
|
for (int i = 0; i < remaining; ++i) {
|
||||||
|
IdxT offset = index * N_READS + i;
|
||||||
|
out[offset] = Op{}(a[offset], b[offset]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto a_vec = load_vector<N_READS>(a, index);
|
||||||
|
auto b_vec = load_vector<N_READS>(b, index);
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,16 +269,23 @@ void binary_op_gpu_inplace(
|
|||||||
} else {
|
} else {
|
||||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
// TODO: Choose optimized value based on type size.
|
||||||
|
constexpr int N_READS = 4;
|
||||||
|
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
|
||||||
if (bopt == BinaryOpType::ScalarVector) {
|
if (bopt == BinaryOpType::ScalarVector) {
|
||||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
|
||||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;
|
||||||
} else if (bopt == BinaryOpType::VectorVector) {
|
} else if (bopt == BinaryOpType::VectorVector) {
|
||||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
|
||||||
}
|
}
|
||||||
auto [num_blocks, block_dims] = get_launch_args(
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
kernel, out.data_size(), out.shape(), out.strides(), large());
|
kernel,
|
||||||
|
out.data_size(),
|
||||||
|
out.shape(),
|
||||||
|
out.strides(),
|
||||||
|
large(),
|
||||||
|
N_READS);
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
|
|||||||
@@ -28,6 +28,27 @@ namespace mlx::core::cu {
|
|||||||
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
||||||
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
||||||
|
|
||||||
|
// Vectorized load/store.
|
||||||
|
template <typename T, int N>
|
||||||
|
struct alignas(sizeof(T) * N) AlignedVector {
|
||||||
|
T val[N];
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
inline __device__ AlignedVector<T, N> load_vector(
|
||||||
|
const T* ptr,
|
||||||
|
uint32_t offset) {
|
||||||
|
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
|
||||||
|
return from[offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
inline __device__ void
|
||||||
|
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
||||||
|
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
|
||||||
|
to[offset] = vec;
|
||||||
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Type limits utils
|
// Type limits utils
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ inline void threadgroup_sum(
|
|||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
x[i] = simd_sum(x[i]);
|
x[i] = simd_sum(x[i]);
|
||||||
}
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
if (simd_lane_id == 0) {
|
if (simd_lane_id == 0) {
|
||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
xs[N * simd_group_id + i] = x[i];
|
xs[N * simd_group_id + i] = x[i];
|
||||||
|
|||||||
@@ -53,11 +53,7 @@ class CMakeBuild(build_ext):
|
|||||||
# Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
|
# Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
|
||||||
# across all generators.
|
# across all generators.
|
||||||
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
|
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
|
||||||
# self.parallel is a Python 3 only way to set parallel jobs by hand
|
build_args += [f"-j{os.cpu_count()}"]
|
||||||
# using -j in the build_ext call, not supported by pip or PyPA-build.
|
|
||||||
if hasattr(self, "parallel") and self.parallel:
|
|
||||||
# CMake 3.12+ only.
|
|
||||||
build_args += [f"-j{self.parallel}"]
|
|
||||||
|
|
||||||
build_temp = Path(self.build_temp) / ext.name
|
build_temp = Path(self.build_temp) / ext.name
|
||||||
if not build_temp.exists():
|
if not build_temp.exists():
|
||||||
|
|||||||
@@ -175,11 +175,12 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
* `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_
|
* `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_
|
||||||
* `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_
|
* `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_
|
||||||
|
|
||||||
Note: The softmax operation is performed in ``float32`` regardless of
|
.. note::
|
||||||
the input precision.
|
|
||||||
|
|
||||||
Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
|
* The softmax operation is performed in ``float32`` regardless of
|
||||||
and ``v`` inputs should not be pre-tiled to match ``q``.
|
the input precision.
|
||||||
|
* For Grouped Query Attention and Multi-Query Attention, the ``k``
|
||||||
|
and ``v`` inputs should not be pre-tiled to match ``q``.
|
||||||
|
|
||||||
In the following the dimensions are given by:
|
In the following the dimensions are given by:
|
||||||
|
|
||||||
@@ -195,13 +196,30 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
|
k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
|
||||||
v (array): Values with shape ``[B, N_kv, T_kv, D]``.
|
v (array): Values with shape ``[B, N_kv, T_kv, D]``.
|
||||||
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
|
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
|
||||||
mask (Union[None, str, array], optional): A causal, boolean or additive
|
mask (Union[None, str, array], optional): The mask to apply to the
|
||||||
mask to apply to the query-key scores. The mask can have at most 4
|
query-key scores. The mask can be an array or a string indicating
|
||||||
dimensions and must be broadcast-compatible with the shape
|
the mask type. The only supported string type is ``"causal"``. If
|
||||||
``[B, N, T_q, T_kv]``. If an additive mask is given its type must
|
the mask is an array it can be a boolean or additive mask. The mask
|
||||||
promote to the promoted type of ``q``, ``k``, and ``v``.
|
can have at most 4 dimensions and must be broadcast-compatible with
|
||||||
|
the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its
|
||||||
|
type must promote to the promoted type of ``q``, ``k``, and ``v``.
|
||||||
Returns:
|
Returns:
|
||||||
array: The output array.
|
array: The output array.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
B = 2
|
||||||
|
N_q = N_kv = 32
|
||||||
|
T_q = T_kv = 1000
|
||||||
|
D = 128
|
||||||
|
|
||||||
|
q = mx.random.normal(shape=(B, N_q, T_q, D))
|
||||||
|
k = mx.random.normal(shape=(B, N_kv, T_kv, D))
|
||||||
|
v = mx.random.normal(shape=(B, N_kv, T_kv, D))
|
||||||
|
scale = D ** -0.5
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
|
|||||||
6
setup.py
6
setup.py
@@ -97,11 +97,7 @@ class CMakeBuild(build_ext):
|
|||||||
# Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
|
# Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
|
||||||
# across all generators.
|
# across all generators.
|
||||||
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
|
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
|
||||||
# self.parallel is a Python 3 only way to set parallel jobs by hand
|
build_args += [f"-j{os.cpu_count()}"]
|
||||||
# using -j in the build_ext call, not supported by pip or PyPA-build.
|
|
||||||
if hasattr(self, "parallel") and self.parallel:
|
|
||||||
# CMake 3.12+ only.
|
|
||||||
build_args += [f"-j{self.parallel}"]
|
|
||||||
|
|
||||||
build_temp = Path(self.build_temp) / ext.name
|
build_temp = Path(self.build_temp) / ext.name
|
||||||
if not build_temp.exists():
|
if not build_temp.exists():
|
||||||
|
|||||||
Reference in New Issue
Block a user