mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add isort pre-commit and run (#68)
This commit is contained in:
parent
e89c571de7
commit
5b9be57ac3
@ -7,3 +7,9 @@ repos:
|
|||||||
rev: 22.10.0
|
rev: 22.10.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
|
- repo: https://github.com/pycqa/isort
|
||||||
|
rev: 5.12.0
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
args:
|
||||||
|
- --profile=black
|
@ -1,7 +1,6 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from time_utils import time_fn
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import mlx.core as mx
|
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
from time_utils import time_fn
|
from time_utils import time_fn
|
||||||
|
|
||||||
B = 8
|
B = 8
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import argparse
|
import argparse
|
||||||
import mlx.core as mx
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
import os
|
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||||
device_name = device_name.decode("utf-8").strip("\n")
|
device_name = device_name.decode("utf-8").strip("\n")
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import argparse
|
import argparse
|
||||||
import mlx.core as mx
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
results_dir = "./results"
|
results_dir = "./results"
|
||||||
|
|
||||||
|
@ -4,8 +4,8 @@ import math
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch.mps
|
import torch.mps
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def sync_if_needed(x):
|
def sync_if_needed(x):
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import mlx.core as mx
|
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
from time_utils import time_fn
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
from .mlx_sample_extensions import *
|
from .mlx_sample_extensions import *
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
from mlx import extension
|
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
|
|
||||||
|
from mlx import extension
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
setup(
|
setup(
|
||||||
name="mlx_sample_extensions",
|
name="mlx_sample_extensions",
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
num_features = 100
|
num_features = 100
|
||||||
num_examples = 1_000
|
num_examples = 1_000
|
||||||
num_iters = 10_000
|
num_iters = 10_000
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
num_features = 100
|
num_features = 100
|
||||||
num_examples = 1_000
|
num_examples = 1_000
|
||||||
num_iters = 10_000
|
num_iters = 10_000
|
||||||
|
@ -6,7 +6,7 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from setuptools import Extension, setup, find_namespace_packages
|
from setuptools import Extension, find_namespace_packages, setup
|
||||||
from setuptools.command.build_ext import build_ext
|
from setuptools.command.build_ext import build_ext
|
||||||
|
|
||||||
import mlx
|
import mlx
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
from mlx.nn.layers import *
|
|
||||||
from mlx.nn import losses
|
from mlx.nn import losses
|
||||||
|
from mlx.nn.layers import *
|
||||||
from mlx.nn.utils import value_and_grad
|
from mlx.nn.utils import value_and_grad
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
from mlx.nn.layers.base import Module
|
|
||||||
from mlx.nn.layers.activations import (
|
from mlx.nn.layers.activations import (
|
||||||
GELU,
|
GELU,
|
||||||
ReLU,
|
ReLU,
|
||||||
@ -11,6 +10,7 @@ from mlx.nn.layers.activations import (
|
|||||||
relu,
|
relu,
|
||||||
silu,
|
silu,
|
||||||
)
|
)
|
||||||
|
from mlx.nn.layers.base import Module
|
||||||
from mlx.nn.layers.containers import Sequential
|
from mlx.nn.layers.containers import Sequential
|
||||||
from mlx.nn.layers.convolution import Conv1d, Conv2d
|
from mlx.nn.layers.convolution import Conv1d, Conv2d
|
||||||
from mlx.nn.layers.dropout import Dropout
|
from mlx.nn.layers.dropout import Dropout
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import Any, Callable, List, Union, Optional
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.utils import tree_flatten, tree_unflatten
|
from mlx.utils import tree_flatten, tree_unflatten
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.nn.layers.base import Module
|
from mlx.nn.layers.base import Module
|
||||||
|
@ -5,9 +5,8 @@ import unittest
|
|||||||
from itertools import permutations
|
from itertools import permutations
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class TestVersion(mlx_tests.MLXTestCase):
|
class TestVersion(mlx_tests.MLXTestCase):
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
from itertools import permutations
|
from itertools import permutations
|
||||||
|
|
||||||
import math
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
from itertools import permutations
|
from itertools import permutations
|
||||||
|
|
||||||
import math
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class TestBlas(mlx_tests.MLXTestCase):
|
class TestBlas(mlx_tests.MLXTestCase):
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
from itertools import permutations
|
from itertools import permutations
|
||||||
|
|
||||||
import math
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import unittest
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import itertools
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import itertools
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class TestFFT(mlx_tests.MLXTestCase):
|
class TestFFT(mlx_tests.MLXTestCase):
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import unittest
|
|
||||||
import os
|
import os
|
||||||
import mlx.core as mx
|
|
||||||
import numpy as np
|
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class TestLoad(mlx_tests.MLXTestCase):
|
class TestLoad(mlx_tests.MLXTestCase):
|
||||||
|
@ -1,15 +1,14 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||||
|
|
||||||
|
|
||||||
class TestNN(mlx_tests.MLXTestCase):
|
class TestNN(mlx_tests.MLXTestCase):
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
from itertools import permutations
|
from itertools import permutations
|
||||||
|
|
||||||
import math
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class TestOps(mlx_tests.MLXTestCase):
|
class TestOps(mlx_tests.MLXTestCase):
|
||||||
|
@ -5,7 +5,6 @@ import unittest
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.optimizers as opt
|
import mlx.optimizers as opt
|
||||||
import mlx.utils
|
import mlx.utils
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from itertools import permutations, combinations
|
from itertools import combinations, permutations
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class TestReduce(mlx_tests.MLXTestCase):
|
class TestReduce(mlx_tests.MLXTestCase):
|
||||||
|
@ -4,7 +4,6 @@ import unittest
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.utils
|
import mlx.utils
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
2
setup.py
2
setup.py
@ -9,7 +9,7 @@ import sysconfig
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from subprocess import run
|
from subprocess import run
|
||||||
|
|
||||||
from setuptools import Extension, setup, find_namespace_packages
|
from setuptools import Extension, find_namespace_packages, setup
|
||||||
from setuptools.command.build_ext import build_ext
|
from setuptools.command.build_ext import build_ext
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user