mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Remove trailing commas in function arguments for unified formatting in flows example
This commit is contained in:
parent
18f9646d56
commit
5206c2740f
@ -8,13 +8,7 @@ from distributions import Normal
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_layers: int,
|
||||
d_in: int,
|
||||
d_hidden: int,
|
||||
d_out: int,
|
||||
):
|
||||
def __init__(self, n_layers: int, d_in: int, d_hidden: int, d_out: int):
|
||||
super().__init__()
|
||||
layer_sizes = [d_in] + [d_hidden] * n_layers + [d_out]
|
||||
self.layers = [nn.Linear(idim, odim) for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])]
|
||||
@ -26,13 +20,7 @@ class MLP(nn.Module):
|
||||
|
||||
|
||||
class RealNVP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_transforms: int,
|
||||
d_params: int,
|
||||
d_hidden: int,
|
||||
n_layers: int,
|
||||
):
|
||||
def __init__(self, n_transforms: int, d_params: int, d_hidden: int, n_layers: int):
|
||||
super().__init__()
|
||||
|
||||
# Alternating masks
|
||||
@ -53,12 +41,7 @@ class RealNVP(nn.Module):
|
||||
log_prob += ldj
|
||||
return log_prob + self.bast_dist.log_prob(x).sum(-1)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
sample_shape: Union[int, Tuple[int, ...]],
|
||||
key: Optional[mx.array] = None,
|
||||
n_transforms: Optional[int] = None,
|
||||
):
|
||||
def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None, n_transforms: Optional[int] = None):
|
||||
x = self.bast_dist.sample(sample_shape, key=key)
|
||||
for mask, conditioner in zip(self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]):
|
||||
x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x)
|
||||
|
Loading…
Reference in New Issue
Block a user