mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-20 03:48:15 +08:00
Fixed several type annotations in the MLX stubs which degraded to Unknown/Any (#2560)
* Added scalar to stubs to fix Unkown Type Hint ### Proposed changes Issue #2478 reports that several type annotations in the MLX stubs degrade to Unknown/Any in editors like VS Code with Pylance, due to missing imports (Union, Optional, Tuple) and an undefined scalar type alias. This PR updates the stub generation patterns to: • Add missing typing imports in mlx.core.__prefix__ so that Union, Optional, Tuple, etc. are always available. • Define and export scalar: TypeAlias = Union[int, float, bool] in mlx.core.__suffix__ so that functions typed with Union[scalar, array] resolve correctly instead of falling back to Any. • Update submodule stub prefixes (distributed, fast, linalg, metal, random) to import scalar alongside array, Device, and Stream, ensuring type checkers resolve the union consistently across modules. With these changes, functions like mlx.add now display rich type signatures such as: ``` def add( a: scalar | array, b: scalar | array, stream: Stream | Device | None = None ) -> array ``` instead of degrading to Any. ### Checklist • I have read the CONTRIBUTING document • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes • I have added tests that prove my fix is effective or that my feature works (n/a — stub generation only) • I have updated the necessary documentation (if needed) * add bool to patterns --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:

committed by
GitHub

parent
c5d2937aa5
commit
89a3df9014
@@ -1,20 +1,33 @@
|
||||
mlx.core.__prefix__:
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
import sys
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeAlias
|
||||
else:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
mlx.core.__suffix__:
|
||||
from typing import Union
|
||||
scalar: TypeAlias = Union[int, float, bool]
|
||||
bool_: Dtype = ...
|
||||
|
||||
mlx.core.distributed.__prefix__:
|
||||
from mlx.core import array, Dtype, Device, Stream
|
||||
from mlx.core import array, Dtype, Device, Stream, scalar
|
||||
from mlx.core.distributed import Group
|
||||
from typing import Sequence, Optional, Union
|
||||
|
||||
mlx.core.fast.__prefix__:
|
||||
from mlx.core import array, Dtype, Device, Stream
|
||||
from mlx.core import array, Dtype, Device, Stream, scalar
|
||||
from typing import Sequence, Optional, Union
|
||||
|
||||
mlx.core.linalg.__prefix__:
|
||||
from mlx.core import array, Dtype, Device, Stream
|
||||
from mlx.core import array, Dtype, Device, Stream, scalar
|
||||
from typing import Sequence, Optional, Tuple, Union
|
||||
|
||||
mlx.core.metal.__prefix__:
|
||||
from mlx.core import array, Dtype, Device, Stream
|
||||
from mlx.core import array, Dtype, Device, Stream, scalar
|
||||
from typing import Sequence, Optional, Union
|
||||
|
||||
mlx.core.random.__prefix__:
|
||||
from mlx.core import array, Dtype, Device, Stream
|
||||
from mlx.core import array, Dtype, Device, Stream, scalar
|
||||
from typing import Sequence, Optional, Union
|
||||
|
Reference in New Issue
Block a user