diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b8babecad..bb29c8fcd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,3 +7,9 @@ repos: rev: 22.10.0 hooks: - id: black +- repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: + - --profile=black \ No newline at end of file diff --git a/benchmarks/numpy/single_ops.py b/benchmarks/numpy/single_ops.py index 2631c50fc..390748130 100644 --- a/benchmarks/numpy/single_ops.py +++ b/benchmarks/numpy/single_ops.py @@ -1,7 +1,6 @@ # Copyright © 2023 Apple Inc. import numpy as np - from time_utils import time_fn diff --git a/benchmarks/python/batch_matmul_bench.py b/benchmarks/python/batch_matmul_bench.py index bf0f74b82..0d822493b 100644 --- a/benchmarks/python/batch_matmul_bench.py +++ b/benchmarks/python/batch_matmul_bench.py @@ -1,8 +1,8 @@ # Copyright © 2023 Apple Inc. import argparse -import mlx.core as mx +import mlx.core as mx from time_utils import time_fn B = 8 diff --git a/benchmarks/python/blas/bench_gemm.py b/benchmarks/python/blas/bench_gemm.py index 608a7cb95..3681cd777 100644 --- a/benchmarks/python/blas/bench_gemm.py +++ b/benchmarks/python/blas/bench_gemm.py @@ -1,13 +1,14 @@ # Copyright © 2023 Apple Inc. -import numpy as np import argparse -import mlx.core as mx -import time -import torch -import os import math +import os import subprocess +import time + +import mlx.core as mx +import numpy as np +import torch device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) device_name = device_name.decode("utf-8").strip("\n") diff --git a/benchmarks/python/blas/bench_gemv.py b/benchmarks/python/blas/bench_gemv.py index 41ac4ec3a..5f491ffc8 100644 --- a/benchmarks/python/blas/bench_gemv.py +++ b/benchmarks/python/blas/bench_gemv.py @@ -1,14 +1,14 @@ # Copyright © 2023 Apple Inc. -import matplotlib.pyplot as plt -import numpy as np import argparse -import mlx.core as mx -import time -import torch import os import subprocess +import time +import matplotlib.pyplot as plt +import mlx.core as mx +import numpy as np +import torch results_dir = "./results" diff --git a/benchmarks/python/llama_torch_bench.py b/benchmarks/python/llama_torch_bench.py index 72329f373..5ee4ddd9c 100644 --- a/benchmarks/python/llama_torch_bench.py +++ b/benchmarks/python/llama_torch_bench.py @@ -4,8 +4,8 @@ import math import time import torch -import torch.nn as nn import torch.mps +import torch.nn as nn def sync_if_needed(x): diff --git a/benchmarks/python/single_ops.py b/benchmarks/python/single_ops.py index 631eef4e2..5a7720b0e 100644 --- a/benchmarks/python/single_ops.py +++ b/benchmarks/python/single_ops.py @@ -1,8 +1,8 @@ # Copyright © 2023 Apple Inc. import argparse -import mlx.core as mx +import mlx.core as mx from time_utils import time_fn diff --git a/examples/extensions/mlx_sample_extensions/__init__.py b/examples/extensions/mlx_sample_extensions/__init__.py index c194d099c..4f3acce9e 100644 --- a/examples/extensions/mlx_sample_extensions/__init__.py +++ b/examples/extensions/mlx_sample_extensions/__init__.py @@ -1,4 +1,5 @@ # Copyright © 2023 Apple Inc. import mlx.core as mx + from .mlx_sample_extensions import * diff --git a/examples/extensions/setup.py b/examples/extensions/setup.py index 74e59cb9a..d432f67f7 100644 --- a/examples/extensions/setup.py +++ b/examples/extensions/setup.py @@ -1,8 +1,9 @@ # Copyright © 2023 Apple Inc. -from mlx import extension from setuptools import setup +from mlx import extension + if __name__ == "__main__": setup( name="mlx_sample_extensions", diff --git a/examples/python/linear_regression.py b/examples/python/linear_regression.py index 47b0253d9..1de4fc531 100644 --- a/examples/python/linear_regression.py +++ b/examples/python/linear_regression.py @@ -1,8 +1,9 @@ # Copyright © 2023 Apple Inc. -import mlx.core as mx import time +import mlx.core as mx + num_features = 100 num_examples = 1_000 num_iters = 10_000 diff --git a/examples/python/logistic_regression.py b/examples/python/logistic_regression.py index 8df1fd994..ce6c1f53f 100644 --- a/examples/python/logistic_regression.py +++ b/examples/python/logistic_regression.py @@ -1,8 +1,9 @@ # Copyright © 2023 Apple Inc. -import mlx.core as mx import time +import mlx.core as mx + num_features = 100 num_examples = 1_000 num_iters = 10_000 diff --git a/python/mlx/extension.py b/python/mlx/extension.py index 2924ef0c6..842cf667f 100644 --- a/python/mlx/extension.py +++ b/python/mlx/extension.py @@ -6,7 +6,7 @@ import subprocess import sys from pathlib import Path -from setuptools import Extension, setup, find_namespace_packages +from setuptools import Extension, find_namespace_packages, setup from setuptools.command.build_ext import build_ext import mlx diff --git a/python/mlx/nn/__init__.py b/python/mlx/nn/__init__.py index 2af4f3a51..9bb7cc63d 100644 --- a/python/mlx/nn/__init__.py +++ b/python/mlx/nn/__init__.py @@ -1,5 +1,5 @@ # Copyright © 2023 Apple Inc. -from mlx.nn.layers import * from mlx.nn import losses +from mlx.nn.layers import * from mlx.nn.utils import value_and_grad diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index 30a185044..133a04494 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -1,6 +1,5 @@ # Copyright © 2023 Apple Inc. -from mlx.nn.layers.base import Module from mlx.nn.layers.activations import ( GELU, ReLU, @@ -11,6 +10,7 @@ from mlx.nn.layers.activations import ( relu, silu, ) +from mlx.nn.layers.base import Module from mlx.nn.layers.containers import Sequential from mlx.nn.layers.convolution import Conv1d, Conv2d from mlx.nn.layers.dropout import Dropout diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 991b13813..2d80553a1 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -1,7 +1,7 @@ # Copyright © 2023 Apple Inc. import textwrap -from typing import Any, Callable, List, Union, Optional +from typing import Any, Callable, List, Optional, Union import mlx.core as mx from mlx.utils import tree_flatten, tree_unflatten diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py index 68d9303ac..7fffe4291 100644 --- a/python/mlx/nn/layers/transformer.py +++ b/python/mlx/nn/layers/transformer.py @@ -1,7 +1,7 @@ # Copyright © 2023 Apple Inc. import math -from typing import Optional, Any +from typing import Any, Optional import mlx.core as mx from mlx.nn.layers.base import Module diff --git a/python/tests/test_array.py b/python/tests/test_array.py index d3539ac55..3c38a7ef5 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -5,9 +5,8 @@ import unittest from itertools import permutations import mlx.core as mx -import numpy as np - import mlx_tests +import numpy as np class TestVersion(mlx_tests.MLXTestCase): diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 5cead2f33..72cc0dc55 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -3,7 +3,6 @@ import unittest import mlx.core as mx - import mlx_tests diff --git a/python/tests/test_bf16.py b/python/tests/test_bf16.py index 27a4e5623..95bf076e8 100644 --- a/python/tests/test_bf16.py +++ b/python/tests/test_bf16.py @@ -1,13 +1,12 @@ # Copyright © 2023 Apple Inc. +import math import unittest from itertools import permutations -import math import mlx.core as mx -import numpy as np - import mlx_tests +import numpy as np try: import torch diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 7a7133e3d..b2a762681 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1,13 +1,12 @@ # Copyright © 2023 Apple Inc. +import math import unittest from itertools import permutations -import math import mlx.core as mx -import numpy as np - import mlx_tests +import numpy as np class TestBlas(mlx_tests.MLXTestCase): diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 13b51e711..1501334e9 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1,13 +1,12 @@ # Copyright © 2023 Apple Inc. +import math import unittest from itertools import permutations -import math import mlx.core as mx -import numpy as np - import mlx_tests +import numpy as np try: import torch diff --git a/python/tests/test_device.py b/python/tests/test_device.py index fff368b9a..8aac105bc 100644 --- a/python/tests/test_device.py +++ b/python/tests/test_device.py @@ -3,7 +3,6 @@ import unittest import mlx.core as mx - import mlx_tests diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index 88149729c..b5145efd0 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -1,11 +1,9 @@ # Copyright © 2023 Apple Inc. +import unittest from functools import partial -import unittest - import mlx.core as mx - import mlx_tests diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index ce273bb22..4be12e21f 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -1,12 +1,11 @@ # Copyright © 2023 Apple Inc. +import itertools import unittest -import itertools import mlx.core as mx -import numpy as np - import mlx_tests +import numpy as np class TestFFT(mlx_tests.MLXTestCase): diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 6f345fa99..e63588d03 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -1,12 +1,12 @@ # Copyright © 2023 Apple Inc. -import unittest import os -import mlx.core as mx -import numpy as np import tempfile +import unittest +import mlx.core as mx import mlx_tests +import numpy as np class TestLoad(mlx_tests.MLXTestCase): diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 77eb6bf79..e8bae0588 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1,15 +1,14 @@ # Copyright © 2023 Apple Inc. +import os +import tempfile import unittest import mlx.core as mx import mlx.nn as nn -from mlx.utils import tree_flatten, tree_map, tree_unflatten -import numpy as np -import os -import tempfile - import mlx_tests +import numpy as np +from mlx.utils import tree_flatten, tree_map, tree_unflatten class TestNN(mlx_tests.MLXTestCase): diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 217a52589..4761bee3e 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1,13 +1,12 @@ # Copyright © 2023 Apple Inc. +import math import unittest from itertools import permutations -import math import mlx.core as mx -import numpy as np - import mlx_tests +import numpy as np class TestOps(mlx_tests.MLXTestCase): diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index 3d4f4518f..e136f6b3e 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -5,7 +5,6 @@ import unittest import mlx.core as mx import mlx.optimizers as opt import mlx.utils - import mlx_tests diff --git a/python/tests/test_random.py b/python/tests/test_random.py index e850e4235..1603371b3 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -3,7 +3,6 @@ import unittest import mlx.core as mx - import mlx_tests diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index da6207325..c8edc344d 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -1,12 +1,11 @@ # Copyright © 2023 Apple Inc. import unittest -from itertools import permutations, combinations +from itertools import combinations, permutations import mlx.core as mx -import numpy as np - import mlx_tests +import numpy as np class TestReduce(mlx_tests.MLXTestCase): diff --git a/python/tests/test_tree.py b/python/tests/test_tree.py index f5003c4a7..cab137b78 100644 --- a/python/tests/test_tree.py +++ b/python/tests/test_tree.py @@ -4,7 +4,6 @@ import unittest import mlx.core as mx import mlx.utils - import mlx_tests diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 50f7e6ef6..c56eda80d 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -3,7 +3,6 @@ import unittest import mlx.core as mx - import mlx_tests diff --git a/setup.py b/setup.py index 980c4635a..d98b47c82 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ import sysconfig from pathlib import Path from subprocess import run -from setuptools import Extension, setup, find_namespace_packages +from setuptools import Extension, find_namespace_packages, setup from setuptools.command.build_ext import build_ext