From ecb174ca9d11d12c7e25519cfe5244e060869de4 Mon Sep 17 00:00:00 2001 From: Danilo Peixoto Date: Sun, 21 Jan 2024 17:53:12 -0300 Subject: [PATCH] Type annotations for `mlx.core` module (#512) --- .circleci/config.yml | 73 +++++++++++++++++++++++++++++++++++++++----- setup.py | 2 +- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 7157949b7..d408bca11 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -26,16 +26,21 @@ jobs: command: | pip install --upgrade cmake pip install --upgrade pybind11[global] + pip install pybind11-stubgen pip install numpy sudo apt-get update sudo apt-get install libblas-dev - run: - name: Build python package + name: Install Python package command: | CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop - run: - name: Run the python tests + name: Generate package stubs + command: | + python3 setup.py generate_stubs + - run: + name: Run Python tests command: | python3 -m unittest discover python/tests # TODO: Reenable when extension api becomes stable @@ -65,19 +70,26 @@ jobs: conda activate runner-env pip install --upgrade cmake pip install --upgrade pybind11[global] + pip install pybind11-stubgen pip install numpy pip install torch pip install tensorflow pip install unittest-xml-reporting - run: - name: Build python package + name: Install Python package command: | eval "$(conda shell.bash hook)" conda activate runner-env CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop - run: - name: Run the python tests + name: Generate package stubs + command: | + eval "$(conda shell.bash hook)" + conda activate runner-env + python setup.py generate_stubs + - run: + name: Run Python tests command: | eval "$(conda shell.bash hook)" conda activate runner-env @@ -121,10 +133,26 @@ jobs: conda activate runner-env pip install --upgrade cmake pip install --upgrade pybind11[global] + pip install pybind11-stubgen pip install numpy pip install twine - run: - name: Build package + name: Install Python package + command: | + eval "$(conda shell.bash hook)" + conda activate runner-env + DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ + PYPI_RELEASE=1 \ + CMAKE_BUILD_PARALLEL_LEVEL="" \ + python setup.py install + - run: + name: Generate package stubs + command: | + eval "$(conda shell.bash hook)" + conda activate runner-env + python setup.py generate_stubs + - run: + name: Publish Python package command: | eval "$(conda shell.bash hook)" conda activate runner-env @@ -157,10 +185,26 @@ jobs: conda activate runner-env pip install --upgrade cmake pip install --upgrade pybind11[global] + pip install pybind11-stubgen pip install numpy pip install twine - run: - name: Build package + name: Install Python package + command: | + eval "$(conda shell.bash hook)" + conda activate runner-env + DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ + DEV_RELEASE=1 \ + CMAKE_BUILD_PARALLEL_LEVEL="" \ + python setup.py install + - run: + name: Generate package stubs + command: | + eval "$(conda shell.bash hook)" + conda activate runner-env + python setup.py generate_stubs + - run: + name: Publish Python package command: | eval "$(conda shell.bash hook)" conda activate runner-env @@ -193,10 +237,25 @@ jobs: conda activate runner-env pip install --upgrade cmake pip install --upgrade pybind11[global] + pip install pybind11-stubgen pip install numpy pip install twine - run: - name: Build package + name: Install Python package + command: | + eval "$(conda shell.bash hook)" + conda activate runner-env + DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ + CMAKE_BUILD_PARALLEL_LEVEL="" \ + python setup.py install + - run: + name: Generate package stubs + command: | + eval "$(conda shell.bash hook)" + conda activate runner-env + python setup.py generate_stubs + - run: + name: Build package distribution command: | eval "$(conda shell.bash hook)" conda activate runner-env diff --git a/setup.py b/setup.py index 5ff41b57f..8e1fc93ae 100644 --- a/setup.py +++ b/setup.py @@ -148,7 +148,7 @@ if __name__ == "__main__": where="python", exclude=["src", "tests", "tests.*"] ) package_dir = {"": "python"} - package_data = {"mlx": ["lib/*", "include/*", "share/*"]} + package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} setup( name="mlx",