mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
Compare commits
6 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
8f163a367d | ||
![]() |
89a3df9014 | ||
![]() |
c5d2937aa5 | ||
![]() |
b61a65e313 | ||
![]() |
04cbb4191c | ||
![]() |
c5460762e7 |
@@ -107,8 +107,20 @@ same array:
|
||||
>>> a
|
||||
array([1, 2, 0], dtype=int32)
|
||||
|
||||
Note that unlike NumPy, slicing an array creates a copy, not a view. So
|
||||
mutating it does not mutate the original array:
|
||||
|
||||
Note, unlike NumPy, updates to the same location are nondeterministic:
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1, 2, 3])
|
||||
>>> b = a[:]
|
||||
>>> b[2] = 0
|
||||
>>> b
|
||||
array([1, 2, 0], dtype=int32)
|
||||
>>> a
|
||||
array([1, 2, 3], dtype=int32)
|
||||
|
||||
Also unlike NumPy, updates to the same location are nondeterministic:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
|
@@ -394,7 +394,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
|
||||
// Define some copy functions to ensure the layout of the inputs is as
|
||||
// expected.
|
||||
copies.reserve(3);
|
||||
copies.reserve(inputs.size());
|
||||
auto copy_unless = [&copies, &s](
|
||||
auto predicate, const array& arr) -> const array& {
|
||||
if (!predicate(arr)) {
|
||||
|
@@ -1,20 +1,34 @@
|
||||
mlx.core.__prefix__:
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
import sys
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeAlias
|
||||
else:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
mlx.core.__suffix__:
|
||||
from typing import Union
|
||||
scalar: TypeAlias = Union[int, float, bool]
|
||||
list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]]
|
||||
bool_: Dtype = ...
|
||||
|
||||
mlx.core.distributed.__prefix__:
|
||||
from mlx.core import array, Dtype, Device, Stream
|
||||
from mlx.core import array, Dtype, Device, Stream, scalar
|
||||
from mlx.core.distributed import Group
|
||||
from typing import Sequence, Optional, Union
|
||||
|
||||
mlx.core.fast.__prefix__:
|
||||
from mlx.core import array, Dtype, Device, Stream
|
||||
from mlx.core import array, Dtype, Device, Stream, scalar
|
||||
from typing import Sequence, Optional, Union
|
||||
|
||||
mlx.core.linalg.__prefix__:
|
||||
from mlx.core import array, Dtype, Device, Stream
|
||||
from mlx.core import array, Dtype, Device, Stream, scalar
|
||||
from typing import Sequence, Optional, Tuple, Union
|
||||
|
||||
mlx.core.metal.__prefix__:
|
||||
from mlx.core import array, Dtype, Device, Stream
|
||||
from mlx.core import array, Dtype, Device, Stream, scalar
|
||||
from typing import Sequence, Optional, Union
|
||||
|
||||
mlx.core.random.__prefix__:
|
||||
from mlx.core import array, Dtype, Device, Stream
|
||||
from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32
|
||||
from typing import Sequence, Optional, Union
|
||||
|
@@ -556,7 +556,7 @@ class AdamW(Adam):
|
||||
eps (float, optional): The term :math:`\epsilon` added to the
|
||||
denominator to improve numerical stability. Default: ``1e-8``
|
||||
weight_decay (float, optional): The weight decay :math:`\lambda`.
|
||||
Default: ``0``.
|
||||
Default: ``0.01``.
|
||||
bias_correction (bool, optional): If set to ``True``, bias correction
|
||||
is applied. Default: ``False``
|
||||
"""
|
||||
|
@@ -320,6 +320,7 @@ void init_array(nb::module_& m) {
|
||||
.def_prop_ro(
|
||||
"shape",
|
||||
[](const mx::array& a) { return nb::cast(a.shape()); },
|
||||
nb::sig("def shape(self) -> tuple[int, ...]"),
|
||||
R"pbdoc(
|
||||
The shape of the array as a Python tuple.
|
||||
|
||||
@@ -347,6 +348,7 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"item",
|
||||
&to_scalar,
|
||||
nb::sig("def item(self) -> scalar"),
|
||||
R"pbdoc(
|
||||
Access the value of a scalar array.
|
||||
|
||||
@@ -356,6 +358,7 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"tolist",
|
||||
&tolist,
|
||||
nb::sig("def tolist(self) -> list_or_scalar"),
|
||||
R"pbdoc(
|
||||
Convert the array to a Python :class:`list`.
|
||||
|
||||
|
@@ -447,6 +447,8 @@ void init_linalg(nb::module_& parent_module) {
|
||||
"a"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def eig(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
|
||||
R"pbdoc(
|
||||
Compute the eigenvalues and eigenvectors of a square matrix.
|
||||
|
||||
@@ -523,6 +525,8 @@ void init_linalg(nb::module_& parent_module) {
|
||||
"UPLO"_a = "L",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def eigh(a: array, UPLO: str = 'L', *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
|
||||
R"pbdoc(
|
||||
Compute the eigenvalues and eigenvectors of a complex Hermitian or
|
||||
real symmetric matrix.
|
||||
|
@@ -4271,7 +4271,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def dequantize(w: array, /, scales: array, biases: Optional[array] = = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Dequantize the matrix ``w`` using quantization parameters.
|
||||
|
||||
|
@@ -171,7 +171,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate normally distributed random numbers.
|
||||
|
||||
|
@@ -619,6 +619,17 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
def test_sdpa_noncontiguous_inputs(self):
|
||||
mask = mx.ones(shape=(4, 1, 7, 7), dtype=mx.bool_)
|
||||
mx.random.seed(0)
|
||||
q = mx.random.normal(shape=(4, 7, 32, 64)).swapaxes(1, 2)
|
||||
|
||||
k = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
|
||||
v = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
||||
ref = mlx_ref_attn(q, k, v, scale=1.0, mask=mask)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
def test_sdpa_promote_mask(self):
|
||||
mask = mx.array(2.0, mx.bfloat16)
|
||||
D = 64
|
||||
|
4
setup.py
4
setup.py
@@ -176,10 +176,6 @@ class GenerateStubs(Command):
|
||||
# Run again without recursive to specify output file name
|
||||
subprocess.run(["rm", f"{out_path}/mlx.pyi"])
|
||||
subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
|
||||
# mx.bool_ gets filtered by nanobind because of the trailing
|
||||
# underscore, add it manually:
|
||||
with open(f"{out_path}/__init__.pyi", "a") as fid:
|
||||
fid.write("\nbool_: Dtype = ...")
|
||||
|
||||
|
||||
class MLXBdistWheel(bdist_wheel):
|
||||
|
Reference in New Issue
Block a user