diff --git a/normalizing_flow/distributions.py b/normalizing_flow/distributions.py index 2199a4a8..5b9fae48 100644 --- a/normalizing_flow/distributions.py +++ b/normalizing_flow/distributions.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. -from typing import Tuple, Optional, Union import math +from typing import Optional, Tuple, Union import mlx.core as mx diff --git a/normalizing_flow/flows.py b/normalizing_flow/flows.py index 0fbd550b..d2ab1704 100644 --- a/normalizing_flow/flows.py +++ b/normalizing_flow/flows.py @@ -1,11 +1,10 @@ # Copyright © 2023-2024 Apple Inc. -from typing import Tuple, Optional, Union +from typing import Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn - -from bijectors import MaskedCoupling, AffineBijector +from bijectors import AffineBijector, MaskedCoupling from distributions import Normal diff --git a/normalizing_flow/main.py b/normalizing_flow/main.py index 0d3cfe56..2956a098 100644 --- a/normalizing_flow/main.py +++ b/normalizing_flow/main.py @@ -1,15 +1,13 @@ # Copyright © 2023-2024 Apple Inc. -from tqdm import trange -import numpy as np -from sklearn import datasets, preprocessing import matplotlib.pyplot as plt - import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim - +import numpy as np from flows import RealNVP +from sklearn import datasets, preprocessing +from tqdm import trange def get_moons_dataset(n_samples=100_000, noise=0.06):