Compare commits

...

6 Commits

Author SHA1 Message Date
XXXXRT666
8f163a367d typing: add type hints to mlx.core.array, linalg, distributed, and random (#2565)
* Add type annotations to mlx methods

* Missing list_or_scalar
2025-09-04 09:08:11 -07:00
Manuel Villanueva
89a3df9014 Fixed several type annotations in the MLX stubs which degraded to Unknown/Any (#2560)
* Added scalar to stubs to fix Unknown Type Hint

### Proposed changes

Issue #2478 reports that several type annotations in the MLX stubs degrade to Unknown/Any in editors like VS Code with Pylance, due to missing imports (Union, Optional, Tuple) and an undefined scalar type alias.

This PR updates the stub generation patterns to:
	•	Add missing typing imports in mlx.core.__prefix__ so that Union, Optional, Tuple, etc. are always available.
	•	Define and export scalar: TypeAlias = Union[int, float, bool] in mlx.core.__suffix__ so that functions typed with Union[scalar, array] resolve correctly instead of falling back to Any.
	•	Update submodule stub prefixes (distributed, fast, linalg, metal, random) to import scalar alongside array, Device, and Stream, ensuring type checkers resolve the union consistently across modules.

With these changes, functions like mlx.add now display rich type signatures such as:

```
def add(
    a: scalar | array,
    b: scalar | array,
    stream: Stream | Device | None = None
) -> array
```

instead of degrading to Any.

### Checklist

	•	I have read the CONTRIBUTING document
	•	I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
	•	I have added tests that prove my fix is effective or that my feature works (n/a — stub generation only)
	•	I have updated the necessary documentation (if needed)

* add bool to patterns

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-09-03 12:52:08 -07:00
Krishi Saripalli
c5d2937aa5 chore: Update Docs With Slice Copy Example (#2559)
* chore: updated docs with slice copy example

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-09-02 22:07:02 -07:00
Awni Hannun
b61a65e313 fix copies in sdpa (#2563) 2025-09-02 11:00:36 -07:00
wrmsr
04cbb4191c Fix dequantize python sig (#2562) 2025-09-01 11:50:20 -07:00
Artur Antonov
c5460762e7 Fix AdamW weight_decay default value in docstring (#2557) 2025-08-31 21:29:30 -07:00
10 changed files with 54 additions and 14 deletions

View File

@@ -107,8 +107,20 @@ same array:
>>> a
array([1, 2, 0], dtype=int32)
Note that unlike NumPy, slicing an array creates a copy, not a view. So
mutating it does not mutate the original array:
Note, unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> b = a[:]
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 3], dtype=int32)
Also unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell

View File

@@ -394,7 +394,7 @@ void ScaledDotProductAttention::eval_gpu(
// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(3);
copies.reserve(inputs.size());
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {

View File

@@ -1,20 +1,34 @@
mlx.core.__prefix__:
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import sys
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
mlx.core.__suffix__:
from typing import Union
scalar: TypeAlias = Union[int, float, bool]
list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]]
bool_: Dtype = ...
mlx.core.distributed.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from mlx.core.distributed import Group
from typing import Sequence, Optional, Union
mlx.core.fast.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Union
mlx.core.linalg.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Tuple, Union
mlx.core.metal.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Union
mlx.core.random.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32
from typing import Sequence, Optional, Union

View File

@@ -556,7 +556,7 @@ class AdamW(Adam):
eps (float, optional): The term :math:`\epsilon` added to the
denominator to improve numerical stability. Default: ``1e-8``
weight_decay (float, optional): The weight decay :math:`\lambda`.
Default: ``0``.
Default: ``0.01``.
bias_correction (bool, optional): If set to ``True``, bias correction
is applied. Default: ``False``
"""

View File

@@ -320,6 +320,7 @@ void init_array(nb::module_& m) {
.def_prop_ro(
"shape",
[](const mx::array& a) { return nb::cast(a.shape()); },
nb::sig("def shape(self) -> tuple[int, ...]"),
R"pbdoc(
The shape of the array as a Python tuple.
@@ -347,6 +348,7 @@ void init_array(nb::module_& m) {
.def(
"item",
&to_scalar,
nb::sig("def item(self) -> scalar"),
R"pbdoc(
Access the value of a scalar array.
@@ -356,6 +358,7 @@ void init_array(nb::module_& m) {
.def(
"tolist",
&tolist,
nb::sig("def tolist(self) -> list_or_scalar"),
R"pbdoc(
Convert the array to a Python :class:`list`.

View File

@@ -447,6 +447,8 @@ void init_linalg(nb::module_& parent_module) {
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def eig(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
R"pbdoc(
Compute the eigenvalues and eigenvectors of a square matrix.
@@ -523,6 +525,8 @@ void init_linalg(nb::module_& parent_module) {
"UPLO"_a = "L",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def eigh(a: array, UPLO: str = 'L', *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
R"pbdoc(
Compute the eigenvalues and eigenvectors of a complex Hermitian or
real symmetric matrix.

View File

@@ -4271,7 +4271,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def dequantize(w: array, /, scales: array, biases: Optional[array] = = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
"def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Dequantize the matrix ``w`` using quantization parameters.

View File

@@ -171,7 +171,7 @@ void init_random(nb::module_& parent_module) {
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate normally distributed random numbers.

View File

@@ -619,6 +619,17 @@ class TestSDPA(mlx_tests.MLXTestCase):
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_noncontiguous_inputs(self):
mask = mx.ones(shape=(4, 1, 7, 7), dtype=mx.bool_)
mx.random.seed(0)
q = mx.random.normal(shape=(4, 7, 32, 64)).swapaxes(1, 2)
k = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
v = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
ref = mlx_ref_attn(q, k, v, scale=1.0, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_promote_mask(self):
mask = mx.array(2.0, mx.bfloat16)
D = 64

View File

@@ -176,10 +176,6 @@ class GenerateStubs(Command):
# Run again without recursive to specify output file name
subprocess.run(["rm", f"{out_path}/mlx.pyi"])
subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
# mx.bool_ gets filtered by nanobind because of the trailing
# underscore, add it manually:
with open(f"{out_path}/__init__.pyi", "a") as fid:
fid.write("\nbool_: Dtype = ...")
class MLXBdistWheel(bdist_wheel):