mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Fix minor typos, add some annotations
This commit is contained in:
@@ -13,7 +13,7 @@ class Bijector:
|
||||
|
||||
|
||||
class AffineBijector(Bijector):
|
||||
def __init__(self, shift_and_log_scale):
|
||||
def __init__(self, shift_and_log_scale: mx.array):
|
||||
self.shift_and_log_scale = shift_and_log_scale
|
||||
|
||||
def forward_and_log_det(self, x: mx.array):
|
||||
|
@@ -32,16 +32,18 @@ class RealNVP(nn.Module):
|
||||
# Conditioning MLP
|
||||
self.conditioner_list = [MLP(n_layers, d_params, d_hidden, 2 * d_params) for _ in range(n_transforms)]
|
||||
|
||||
self.bast_dist = Normal(mx.zeros(d_params), mx.ones(d_params))
|
||||
self.base_dist = Normal(mx.zeros(d_params), mx.ones(d_params))
|
||||
|
||||
def log_prob(self, x: mx.array):
|
||||
"""Flow back to the primal Gaussian and compute log-density, adding the transformation log-determinant along the way."""
|
||||
log_prob = mx.zeros(x.shape[0])
|
||||
for mask, conditioner in zip(self.mask_list[::-1], self.conditioner_list[::-1]):
|
||||
x, ldj = MaskedCoupling(mask, conditioner, AffineBijector).inverse_and_log_det(x)
|
||||
log_prob += ldj
|
||||
return log_prob + self.bast_dist.log_prob(x).sum(-1)
|
||||
return log_prob + self.base_dist.log_prob(x).sum(-1)
|
||||
|
||||
def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None, n_transforms: Optional[int] = None):
|
||||
"""Sample from the primal Gaussian and flow towards the target distribution."""
|
||||
x = self.bast_dist.sample(sample_shape, key=key)
|
||||
for mask, conditioner in zip(self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]):
|
||||
x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x)
|
||||
|
@@ -1,4 +1,5 @@
|
||||
mlx
|
||||
numpy
|
||||
tqdm
|
||||
scikit-learn
|
||||
scikit-learn
|
||||
matplotlib
|
BIN
flow/samples.png
BIN
flow/samples.png
Binary file not shown.
Before Width: | Height: | Size: 82 KiB After Width: | Height: | Size: 86 KiB |
Reference in New Issue
Block a user