Fixed several type annotations in the MLX stubs which degraded to Unknown/Any (#2560)

* Added scalar to stubs to fix Unkown Type Hint

### Proposed changes

Issue #2478 reports that several type annotations in the MLX stubs degrade to Unknown/Any in editors like VS Code with Pylance, due to missing imports (Union, Optional, Tuple) and an undefined scalar type alias.

This PR updates the stub generation patterns to:
	•	Add missing typing imports in mlx.core.__prefix__ so that Union, Optional, Tuple, etc. are always available.
	•	Define and export scalar: TypeAlias = Union[int, float, bool] in mlx.core.__suffix__ so that functions typed with Union[scalar, array] resolve correctly instead of falling back to Any.
	•	Update submodule stub prefixes (distributed, fast, linalg, metal, random) to import scalar alongside array, Device, and Stream, ensuring type checkers resolve the union consistently across modules.

With these changes, functions like mlx.add now display rich type signatures such as:

```
def add(
    a: scalar | array,
    b: scalar | array,
    stream: Stream | Device | None = None
) -> array
```

instead of degrading to Any.

### Checklist

	•	I have read the CONTRIBUTING document
	•	I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
	•	I have added tests that prove my fix is effective or that my feature works (n/a — stub generation only)
	•	I have updated the necessary documentation (if needed)

* add bool to patterns

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Manuel Villanueva
2025-09-03 14:52:08 -05:00
committed by GitHub
parent c5d2937aa5
commit 89a3df9014
2 changed files with 18 additions and 9 deletions

View File

@@ -1,20 +1,33 @@
mlx.core.__prefix__:
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import sys
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
mlx.core.__suffix__:
from typing import Union
scalar: TypeAlias = Union[int, float, bool]
bool_: Dtype = ...
mlx.core.distributed.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from mlx.core.distributed import Group
from typing import Sequence, Optional, Union
mlx.core.fast.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Union
mlx.core.linalg.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Tuple, Union
mlx.core.metal.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Union
mlx.core.random.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Union

View File

@@ -176,10 +176,6 @@ class GenerateStubs(Command):
# Run again without recursive to specify output file name
subprocess.run(["rm", f"{out_path}/mlx.pyi"])
subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
# mx.bool_ gets filtered by nanobind because of the trailing
# underscore, add it manually:
with open(f"{out_path}/__init__.pyi", "a") as fid:
fid.write("\nbool_: Dtype = ...")
class MLXBdistWheel(bdist_wheel):