From 5206c2740f02c5e2094d8b58b6f56fe0beff0b00 Mon Sep 17 00:00:00 2001 From: Siddharth Mishra-Sharma Date: Mon, 18 Dec 2023 10:03:09 -0500 Subject: [PATCH] Remove trailing commas in function arguments for unified formatting in flows example --- flow/flows.py | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/flow/flows.py b/flow/flows.py index 59aed19e..54b676b9 100644 --- a/flow/flows.py +++ b/flow/flows.py @@ -8,13 +8,7 @@ from distributions import Normal class MLP(nn.Module): - def __init__( - self, - n_layers: int, - d_in: int, - d_hidden: int, - d_out: int, - ): + def __init__(self, n_layers: int, d_in: int, d_hidden: int, d_out: int): super().__init__() layer_sizes = [d_in] + [d_hidden] * n_layers + [d_out] self.layers = [nn.Linear(idim, odim) for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])] @@ -26,13 +20,7 @@ class MLP(nn.Module): class RealNVP(nn.Module): - def __init__( - self, - n_transforms: int, - d_params: int, - d_hidden: int, - n_layers: int, - ): + def __init__(self, n_transforms: int, d_params: int, d_hidden: int, n_layers: int): super().__init__() # Alternating masks @@ -53,12 +41,7 @@ class RealNVP(nn.Module): log_prob += ldj return log_prob + self.bast_dist.log_prob(x).sum(-1) - def sample( - self, - sample_shape: Union[int, Tuple[int, ...]], - key: Optional[mx.array] = None, - n_transforms: Optional[int] = None, - ): + def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None, n_transforms: Optional[int] = None): x = self.bast_dist.sample(sample_shape, key=key) for mask, conditioner in zip(self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]): x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x)