Fix minor typos, add some annotations

This commit is contained in:
Siddharth Mishra-Sharma
2023-12-20 13:10:05 -05:00
parent 5206c2740f
commit 3b98fb3a0f
4 changed files with 7 additions and 4 deletions

View File

@@ -13,7 +13,7 @@ class Bijector:
class AffineBijector(Bijector):
def __init__(self, shift_and_log_scale):
def __init__(self, shift_and_log_scale: mx.array):
self.shift_and_log_scale = shift_and_log_scale
def forward_and_log_det(self, x: mx.array):

View File

@@ -32,16 +32,18 @@ class RealNVP(nn.Module):
# Conditioning MLP
self.conditioner_list = [MLP(n_layers, d_params, d_hidden, 2 * d_params) for _ in range(n_transforms)]
self.bast_dist = Normal(mx.zeros(d_params), mx.ones(d_params))
self.base_dist = Normal(mx.zeros(d_params), mx.ones(d_params))
def log_prob(self, x: mx.array):
"""Flow back to the primal Gaussian and compute log-density, adding the transformation log-determinant along the way."""
log_prob = mx.zeros(x.shape[0])
for mask, conditioner in zip(self.mask_list[::-1], self.conditioner_list[::-1]):
x, ldj = MaskedCoupling(mask, conditioner, AffineBijector).inverse_and_log_det(x)
log_prob += ldj
return log_prob + self.bast_dist.log_prob(x).sum(-1)
return log_prob + self.base_dist.log_prob(x).sum(-1)
def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None, n_transforms: Optional[int] = None):
"""Sample from the primal Gaussian and flow towards the target distribution."""
x = self.bast_dist.sample(sample_shape, key=key)
for mask, conditioner in zip(self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]):
x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x)

View File

@@ -1,4 +1,5 @@
mlx
numpy
tqdm
scikit-learn
scikit-learn
matplotlib

Binary file not shown.

Before

Width:  |  Height:  |  Size: 82 KiB

After

Width:  |  Height:  |  Size: 86 KiB