mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Remove trailing commas in function arguments for unified formatting in flows example
This commit is contained in:
parent
18f9646d56
commit
5206c2740f
@ -8,13 +8,7 @@ from distributions import Normal
|
|||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
def __init__(
|
def __init__(self, n_layers: int, d_in: int, d_hidden: int, d_out: int):
|
||||||
self,
|
|
||||||
n_layers: int,
|
|
||||||
d_in: int,
|
|
||||||
d_hidden: int,
|
|
||||||
d_out: int,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
layer_sizes = [d_in] + [d_hidden] * n_layers + [d_out]
|
layer_sizes = [d_in] + [d_hidden] * n_layers + [d_out]
|
||||||
self.layers = [nn.Linear(idim, odim) for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])]
|
self.layers = [nn.Linear(idim, odim) for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])]
|
||||||
@ -26,13 +20,7 @@ class MLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class RealNVP(nn.Module):
|
class RealNVP(nn.Module):
|
||||||
def __init__(
|
def __init__(self, n_transforms: int, d_params: int, d_hidden: int, n_layers: int):
|
||||||
self,
|
|
||||||
n_transforms: int,
|
|
||||||
d_params: int,
|
|
||||||
d_hidden: int,
|
|
||||||
n_layers: int,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Alternating masks
|
# Alternating masks
|
||||||
@ -53,12 +41,7 @@ class RealNVP(nn.Module):
|
|||||||
log_prob += ldj
|
log_prob += ldj
|
||||||
return log_prob + self.bast_dist.log_prob(x).sum(-1)
|
return log_prob + self.bast_dist.log_prob(x).sum(-1)
|
||||||
|
|
||||||
def sample(
|
def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None, n_transforms: Optional[int] = None):
|
||||||
self,
|
|
||||||
sample_shape: Union[int, Tuple[int, ...]],
|
|
||||||
key: Optional[mx.array] = None,
|
|
||||||
n_transforms: Optional[int] = None,
|
|
||||||
):
|
|
||||||
x = self.bast_dist.sample(sample_shape, key=key)
|
x = self.bast_dist.sample(sample_shape, key=key)
|
||||||
for mask, conditioner in zip(self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]):
|
for mask, conditioner in zip(self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]):
|
||||||
x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x)
|
x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user