Remove trailing commas in function arguments for unified formatting in flows example

This commit is contained in:
Siddharth Mishra-Sharma 2023-12-18 10:03:09 -05:00
parent 18f9646d56
commit 5206c2740f

View File

@ -8,13 +8,7 @@ from distributions import Normal
class MLP(nn.Module):
def __init__(
self,
n_layers: int,
d_in: int,
d_hidden: int,
d_out: int,
):
def __init__(self, n_layers: int, d_in: int, d_hidden: int, d_out: int):
super().__init__()
layer_sizes = [d_in] + [d_hidden] * n_layers + [d_out]
self.layers = [nn.Linear(idim, odim) for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])]
@ -26,13 +20,7 @@ class MLP(nn.Module):
class RealNVP(nn.Module):
def __init__(
self,
n_transforms: int,
d_params: int,
d_hidden: int,
n_layers: int,
):
def __init__(self, n_transforms: int, d_params: int, d_hidden: int, n_layers: int):
super().__init__()
# Alternating masks
@ -53,12 +41,7 @@ class RealNVP(nn.Module):
log_prob += ldj
return log_prob + self.bast_dist.log_prob(x).sum(-1)
def sample(
self,
sample_shape: Union[int, Tuple[int, ...]],
key: Optional[mx.array] = None,
n_transforms: Optional[int] = None,
):
def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None, n_transforms: Optional[int] = None):
x = self.bast_dist.sample(sample_shape, key=key)
for mask, conditioner in zip(self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]):
x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x)