Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00

Compare commits: v0.29.0...c1e3340b23 (7 commits)

| SHA1 |
|---|
| c1e3340b23 |
| 8f163a367d |
| 89a3df9014 |
| c5d2937aa5 |
| b61a65e313 |
| 04cbb4191c |
| c5460762e7 |
@@ -230,6 +230,9 @@ jobs:
             sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
             rm -rf ccache-4.11.3-linux-x86_64
             curl -LsSf https://astral.sh/uv/install.sh | sh
+      - run:
+          name: Set CCache size
+          command: ccache --max-size 1G
       - run:
           name: Install Python package
           command: |
@@ -260,7 +263,6 @@ jobs:
           command: |
             ccache --show-stats
             ccache --zero-stats
-            ccache --max-size 400MB
             ccache --cleanup
       - save_cache:
           key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
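Reviewer note: taken together, these two hunks move the cache size into its own early step, raise it from 400MB to 1G, and stop shrinking the cache in the stats step that runs just before `save_cache`, so a warm cache is saved intact.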
@@ -107,8 +107,20 @@ same array:
   >>> a
   array([1, 2, 0], dtype=int32)
 
-Note, unlike NumPy, updates to the same location are nondeterministic:
+Note that unlike NumPy, slicing an array creates a copy, not a view. So
+mutating it does not mutate the original array:
+
+.. code-block:: shell
+
+  >>> a = mx.array([1, 2, 3])
+  >>> b = a[:]
+  >>> b[2] = 0
+  >>> b
+  array([1, 2, 0], dtype=int32)
+  >>> a
+  array([1, 2, 3], dtype=int32)
+
+Also unlike NumPy, updates to the same location are nondeterministic:
 
 .. code-block:: shell
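For the nondeterminism the rewritten note points at, a minimal sketch of a duplicate-index update (hypothetical values; which write lands last is unspecified):

```python
import mlx.core as mx

a = mx.zeros((1,))
idx = mx.array([0, 0])
# Both writes target a[0]; after evaluation a[0] may hold 1.0 or 2.0.
a[idx] = mx.array([1.0, 2.0])
print(a)
```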
@@ -394,7 +394,7 @@ void ScaledDotProductAttention::eval_gpu(
 
   // Define some copy functions to ensure the layout of the inputs is as
   // expected.
-  copies.reserve(3);
+  copies.reserve(inputs.size());
   auto copy_unless = [&copies, &s](
       auto predicate, const array& arr) -> const array& {
     if (!predicate(arr)) {
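Why the reserve size matters: `copy_unless` returns `const array&` references to elements it pushes into `copies`, and those references must stay valid, so the vector can never be allowed to reallocate once the lambda starts handing them out. `inputs.size()` is the safe upper bound, presumably because the op can now carry more than three inputs.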
@@ -1,20 +1,34 @@
+mlx.core.__prefix__:
+    from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+    import sys
+    if sys.version_info >= (3, 10):
+        from typing import TypeAlias
+    else:
+        from typing_extensions import TypeAlias
+
+mlx.core.__suffix__:
+    from typing import Union
+    scalar: TypeAlias = Union[int, float, bool]
+    list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]]
+    bool_: Dtype = ...
+
 mlx.core.distributed.__prefix__:
-    from mlx.core import array, Dtype, Device, Stream
+    from mlx.core import array, Dtype, Device, Stream, scalar
     from mlx.core.distributed import Group
     from typing import Sequence, Optional, Union
 
 mlx.core.fast.__prefix__:
-    from mlx.core import array, Dtype, Device, Stream
+    from mlx.core import array, Dtype, Device, Stream, scalar
     from typing import Sequence, Optional, Union
 
 mlx.core.linalg.__prefix__:
-    from mlx.core import array, Dtype, Device, Stream
+    from mlx.core import array, Dtype, Device, Stream, scalar
     from typing import Sequence, Optional, Tuple, Union
 
 mlx.core.metal.__prefix__:
-    from mlx.core import array, Dtype, Device, Stream
+    from mlx.core import array, Dtype, Device, Stream, scalar
     from typing import Sequence, Optional, Union
 
 mlx.core.random.__prefix__:
-    from mlx.core import array, Dtype, Device, Stream
+    from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32
     from typing import Sequence, Optional, Union
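Assuming nanobind's stubgen splices these `__prefix__`/`__suffix__` blocks verbatim into the generated stubs, the new aliases behave like this runnable sketch (the `as_nested_list` helper is hypothetical, only there to show the recursive alias used in an annotation):

```python
import sys
from typing import Union

if sys.version_info >= (3, 10):
    from typing import TypeAlias
else:
    from typing_extensions import TypeAlias

# The aliases the mlx.core.__suffix__ pattern introduces:
scalar: TypeAlias = Union[int, float, bool]
list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]]

def as_nested_list(x: list_or_scalar) -> list_or_scalar:
    # Accepts 7, 7.0, True, [1, 2], [[1], [2.5, False]], ...
    return x
```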
@@ -556,7 +556,7 @@ class AdamW(Adam):
         eps (float, optional): The term :math:`\epsilon` added to the
             denominator to improve numerical stability. Default: ``1e-8``
         weight_decay (float, optional): The weight decay :math:`\lambda`.
-            Default: ``0``.
+            Default: ``0.01``.
         bias_correction (bool, optional): If set to ``True``, bias correction
             is applied. Default: ``False``
     """
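The corrected docstring matches the constructor's actual default, which follows the usual decoupled-weight-decay convention of 0.01. A usage sketch:

```python
import mlx.optimizers as optim

opt = optim.AdamW(learning_rate=1e-3)  # weight_decay defaults to 0.01
opt_nodecay = optim.AdamW(learning_rate=1e-3, weight_decay=0.0)  # opt out
```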
@@ -320,6 +320,7 @@ void init_array(nb::module_& m) {
       .def_prop_ro(
           "shape",
           [](const mx::array& a) { return nb::cast(a.shape()); },
+          nb::sig("def shape(self) -> tuple[int, ...]"),
           R"pbdoc(
             The shape of the array as a Python tuple.
 
@@ -347,6 +348,7 @@ void init_array(nb::module_& m) {
       .def(
           "item",
           &to_scalar,
+          nb::sig("def item(self) -> scalar"),
           R"pbdoc(
             Access the value of a scalar array.
 
@@ -356,6 +358,7 @@ void init_array(nb::module_& m) {
       .def(
           "tolist",
           &tolist,
+          nb::sig("def tolist(self) -> list_or_scalar"),
           R"pbdoc(
             Convert the array to a Python :class:`list`.
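With these `nb::sig` overrides, the stubs use the aliases defined by the pattern file above instead of nanobind's inferred types. What each annotation corresponds to at runtime:

```python
import mlx.core as mx

a = mx.array([[1, 2], [3, 4]])
a.shape              # -> tuple[int, ...]   e.g. (2, 2)
a.tolist()           # -> list_or_scalar    e.g. [[1, 2], [3, 4]]
mx.array(7).item()   # -> scalar            e.g. 7
```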
@@ -447,6 +447,8 @@ void init_linalg(nb::module_& parent_module) {
       "a"_a,
       nb::kw_only(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def eig(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
       R"pbdoc(
         Compute the eigenvalues and eigenvectors of a square matrix.
 
@@ -523,6 +525,8 @@ void init_linalg(nb::module_& parent_module) {
       "UPLO"_a = "L",
       nb::kw_only(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def eigh(a: array, UPLO: str = 'L', *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
       R"pbdoc(
         Compute the eigenvalues and eigenvectors of a complex Hermitian or
         real symmetric matrix.
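Per the added signatures, both functions return an eigenvalue/eigenvector pair. A minimal sketch, assuming (as with several other `mlx.linalg` decompositions) they run on the CPU stream:

```python
import mlx.core as mx

A = mx.array([[2.0, 1.0], [1.0, 2.0]])
w, v = mx.linalg.eigh(A, stream=mx.cpu)        # symmetric/Hermitian case
vals, vecs = mx.linalg.eig(A, stream=mx.cpu)   # general case; may be complex
```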
@@ -4271,7 +4271,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: Optional[array] = = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Dequantize the matrix ``w`` using quantization parameters.
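The fix removes the stray `= =` from the signature string. A round-trip sketch using the matching defaults (`group_size=64`, `bits=4`, affine mode):

```python
import mlx.core as mx

w = mx.random.normal(shape=(64, 64))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)
```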
@@ -171,7 +171,7 @@ void init_random(nb::module_& parent_module) {
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
+          "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate normally distributed random numbers.
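`Optional[...]` takes a single argument, so `Optional[scalar, array]` was never a valid annotation; `Union[scalar, array, None]` is what was meant. Both forms of `loc`/`scale` the signature now advertises:

```python
import mlx.core as mx

x = mx.random.normal(shape=(2, 3), loc=1.0, scale=0.5)  # scalars
y = mx.random.normal(
    shape=(2, 3), loc=mx.zeros((3,)), scale=mx.ones((3,))  # arrays, broadcast
)
```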
@@ -619,6 +619,17 @@ class TestSDPA(mlx_tests.MLXTestCase):
         out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
 
+    def test_sdpa_noncontiguous_inputs(self):
+        mask = mx.ones(shape=(4, 1, 7, 7), dtype=mx.bool_)
+        mx.random.seed(0)
+        q = mx.random.normal(shape=(4, 7, 32, 64)).swapaxes(1, 2)
+
+        k = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
+        v = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
+        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
+        ref = mlx_ref_attn(q, k, v, scale=1.0, mask=mask)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
     def test_sdpa_promote_mask(self):
         mask = mx.array(2.0, mx.bfloat16)
         D = 64
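The new test materializes `(B, L, H, D)` tensors and swaps axes 1 and 2, so the `(B, H, L, D)` inputs reaching the fused kernel are non-contiguous; it also uses 8 key/value heads against 32 query heads, i.e. grouped-query attention. A standalone sketch of the same pattern (shapes are illustrative):

```python
import mlx.core as mx

B, H, L, D = 1, 8, 16, 64
q = mx.random.normal(shape=(B, L, H, D)).swapaxes(1, 2)  # (B, H, L, D), non-contiguous
k = mx.random.normal(shape=(B, L, H, D)).swapaxes(1, 2)
v = mx.random.normal(shape=(B, L, H, D)).swapaxes(1, 2)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5)
```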
setup.py (4 changes)

@@ -176,10 +176,6 @@ class GenerateStubs(Command):
         # Run again without recursive to specify output file name
         subprocess.run(["rm", f"{out_path}/mlx.pyi"])
         subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
-        # mx.bool_ gets filtered by nanobind because of the trailing
-        # underscore, add it manually:
-        with open(f"{out_path}/__init__.pyi", "a") as fid:
-            fid.write("\nbool_: Dtype = ...")
 
 
 class MLXBdistWheel(bdist_wheel):
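This removal pairs with the stub-pattern hunk above: `bool_: Dtype = ...` is now emitted via the `mlx.core.__suffix__` pattern instead of being appended to `__init__.pyi` by hand, since nanobind filters the trailing-underscore name from the generated stub.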