Mirror of https://github.com/ml-explore/mlx-examples.git

Commit: Fix minor typos, add some annotations
```diff
@@ -13,7 +13,7 @@ class Bijector:
 
 
 class AffineBijector(Bijector):
-    def __init__(self, shift_and_log_scale):
+    def __init__(self, shift_and_log_scale: mx.array):
         self.shift_and_log_scale = shift_and_log_scale
 
     def forward_and_log_det(self, x: mx.array):
```
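The hunk above only touches the constructor of `AffineBijector`. For context, here is a minimal sketch of how a shift-and-log-scale bijector of this shape is typically implemented in RealNVP; this is an assumption for illustration, not code from the commit, and the per-dimension log-determinant is left unsummed because the coupling layer masks and sums it.

```python
# Hypothetical sketch (not from this commit) of an elementwise affine bijector.
import mlx.core as mx


class AffineBijectorSketch:
    def __init__(self, shift_and_log_scale: mx.array):
        self.shift_and_log_scale = shift_and_log_scale

    def forward_and_log_det(self, x: mx.array):
        # Split the conditioner output into a per-dimension shift and log-scale.
        shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1)
        y = x * mx.exp(log_scale) + shift
        # Elementwise affine map, so the per-dimension log-Jacobian is log_scale.
        return y, log_scale

    def inverse_and_log_det(self, y: mx.array):
        shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1)
        x = (y - shift) * mx.exp(-log_scale)
        return x, -log_scale
```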
```diff
@@ -32,16 +32,18 @@ class RealNVP(nn.Module):
         # Conditioning MLP
         self.conditioner_list = [MLP(n_layers, d_params, d_hidden, 2 * d_params) for _ in range(n_transforms)]
 
-        self.bast_dist = Normal(mx.zeros(d_params), mx.ones(d_params))
+        self.base_dist = Normal(mx.zeros(d_params), mx.ones(d_params))
 
     def log_prob(self, x: mx.array):
+        """Flow back to the primal Gaussian and compute log-density, adding the transformation log-determinant along the way."""
         log_prob = mx.zeros(x.shape[0])
         for mask, conditioner in zip(self.mask_list[::-1], self.conditioner_list[::-1]):
             x, ldj = MaskedCoupling(mask, conditioner, AffineBijector).inverse_and_log_det(x)
             log_prob += ldj
-        return log_prob + self.bast_dist.log_prob(x).sum(-1)
+        return log_prob + self.base_dist.log_prob(x).sum(-1)
 
     def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None, n_transforms: Optional[int] = None):
-        x = self.bast_dist.sample(sample_shape, key=key)
+        """Sample from the primal Gaussian and flow towards the target distribution."""
+        x = self.base_dist.sample(sample_shape, key=key)
         for mask, conditioner in zip(self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]):
             x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x)
```
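Both `log_prob` and `sample` above delegate to `MaskedCoupling`, which this diff does not show. The following is a hedged sketch of what such a coupling layer typically looks like in RealNVP, reusing the names that appear in the hunk (`mask`, `conditioner`, the bijector class); the actual implementation in `flow/` may differ in detail.

```python
# Hypothetical sketch of a RealNVP masked coupling layer (not from this commit).
import mlx.core as mx


class MaskedCouplingSketch:
    def __init__(self, mask: mx.array, conditioner, bijector_cls):
        self.mask = mask              # 1 where the input passes through unchanged
        self.conditioner = conditioner
        self.bijector_cls = bijector_cls

    def forward_and_log_det(self, x: mx.array):
        # Bijector parameters depend only on the masked (unchanged) dimensions.
        params = self.conditioner(x * self.mask)
        y, log_det = self.bijector_cls(params).forward_and_log_det(x)
        # Only the unmasked dimensions are actually transformed.
        y = self.mask * x + (1.0 - self.mask) * y
        return y, ((1.0 - self.mask) * log_det).sum(-1)

    def inverse_and_log_det(self, y: mx.array):
        params = self.conditioner(y * self.mask)
        x, log_det = self.bijector_cls(params).inverse_and_log_det(y)
        x = self.mask * y + (1.0 - self.mask) * x
        return x, ((1.0 - self.mask) * log_det).sum(-1)
```

Summing the masked log-determinant over the last axis is what makes `ldj` a per-sample scalar, matching the `log_prob += ldj` accumulation in the hunk above.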
```diff
@@ -1,4 +1,5 @@
 mlx
 numpy
 tqdm
 scikit-learn
+matplotlib
```
BIN flow/samples.png (binary file not shown; 82 KiB before, 86 KiB after)
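`flow/samples.png` is regenerated in this commit, and matplotlib is added to the requirements to produce it. A figure like this can be drawn roughly as follows; this is a sketch that assumes a trained RealNVP instance named `model` with the `sample()` signature shown in the diff, not the repository's actual plotting script.

```python
# Hedged sketch: draw samples from a trained flow and save a scatter plot.
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np

# `model` is assumed to be a trained RealNVP from flow/ (hypothetical here).
samples = np.array(model.sample((10_000,), key=mx.random.key(0)))
plt.scatter(samples[:, 0], samples[:, 1], s=1, alpha=0.5)
plt.axis("equal")
plt.savefig("flow/samples.png", dpi=150)
```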