mlx-examples/flow/flows.py

69 lines
2.1 KiB
Python
Raw Normal View History

from typing import Tuple, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from bijectors import MaskedCoupling, AffineBijector
from distributions import Normal
class MLP(nn.Module):
    """Feed-forward stack of ``nn.Linear`` layers with GELU in between.

    The network has ``n_layers + 1`` linear layers: the first maps ``d_in`` to
    ``d_hidden`` (or straight to ``d_out`` when ``n_layers`` is 0), the middle
    ones map ``d_hidden`` to ``d_hidden``, and the last one maps to ``d_out``.

    Args:
        n_layers: Number of hidden layers of width ``d_hidden``.
        d_in: Size of the input feature axis.
        d_hidden: Width of every hidden layer.
        d_out: Size of the output feature axis.
    """

    def __init__(self, n_layers: int, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        widths = [d_in, *([d_hidden] * n_layers), d_out]
        # Pair each width with its successor; zip stops at the shorter input,
        # so the final width is only ever used as an output size.
        self.layers = [
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])
        ]

    def __call__(self, x):
        """Apply the network to ``x`` and return the last layer's raw output."""
        *hidden, head = self.layers
        for layer in hidden:
            x = nn.gelu(layer(x))
        # No activation after the output layer.
        return head(x)
class RealNVP(nn.Module):
    """Flow model built from a stack of masked affine coupling layers.

    Each of the ``n_transforms`` layers pairs a fixed boolean mask over the
    ``d_params`` features with its own conditioner ``MLP``. The masks alternate
    between the even-indexed and the odd-indexed features from one layer to the
    next. The base distribution is ``Normal(zeros, ones)`` over ``d_params``
    features.

    Args:
        n_transforms: Number of coupling layers.
        d_params: Size of the trailing (feature) axis of the data.
        d_hidden: Hidden width of every conditioner MLP.
        n_layers: Number of hidden layers in every conditioner MLP.
    """

    def __init__(
        self,
        n_transforms: int,
        d_params: int,
        d_hidden: int,
        n_layers: int,
    ):
        super().__init__()

        # Alternating masks: layer i marks the features whose index has the
        # same parity as i.
        feature_parity = mx.arange(d_params) % 2
        self.mask_list = [
            (feature_parity == i % 2).astype(mx.bool_) for i in range(n_transforms)
        ]
        # Keep the masks out of the trainable parameters.
        self.freeze(keys=["mask_list"])

        # Conditioning MLPs, one per coupling layer. Each emits 2 * d_params
        # values per sample (presumably the per-feature parameters consumed by
        # AffineBijector -- confirm against bijectors.py).
        self.conditioner_list = [
            MLP(n_layers, d_params, d_hidden, 2 * d_params)
            for _ in range(n_transforms)
        ]

        # NOTE(review): ``bast_dist`` looks like a typo for ``base_dist``; the
        # attribute name is kept so existing attribute access keeps working.
        self.bast_dist = Normal(mx.zeros(d_params), mx.ones(d_params))

    def log_prob(self, x: mx.array) -> mx.array:
        """Return the model log-density of ``x``.

        ``x`` is pushed through the inverse of every coupling layer, last layer
        first; the log-determinants returned by the layers are accumulated and
        added to the base distribution's log-probability summed over the
        trailing axis.

        Args:
            x: Array whose trailing axis holds the ``d_params`` features. Any
                number of leading batch axes is accepted.

        Returns:
            Array of log-probabilities shaped like ``x`` without its trailing
            axis.
        """
        # One accumulator entry per sample, i.e. every axis except the feature
        # axis. This matches the ``sum(-1)`` reduction below and, unlike
        # ``x.shape[0]``, does not assume exactly one batch axis.
        log_prob = mx.zeros(x.shape[:-1])
        for mask, conditioner in zip(
            reversed(self.mask_list), reversed(self.conditioner_list)
        ):
            coupling = MaskedCoupling(mask, conditioner, AffineBijector)
            x, ldj = coupling.inverse_and_log_det(x)
            log_prob += ldj
        return log_prob + self.bast_dist.log_prob(x).sum(-1)

    def sample(
        self,
        sample_shape: Union[int, Tuple[int, ...]],
        key: Optional[mx.array] = None,
        n_transforms: Optional[int] = None,
    ) -> mx.array:
        """Draw samples by transforming draws from the base distribution.

        Args:
            sample_shape: Forwarded unchanged to the base distribution's
                ``sample`` method.
            key: Optional key forwarded to the base distribution's ``sample``
                method.
            n_transforms: When given, only the coupling layers selected by
                ``[:n_transforms]`` are applied; ``None`` applies all of them.

        Returns:
            The base samples after the selected forward transforms.
        """
        x = self.bast_dist.sample(sample_shape, key=key)
        selected_layers = zip(
            self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]
        )
        for mask, conditioner in selected_layers:
            coupling = MaskedCoupling(mask, conditioner, AffineBijector)
            # The log-determinant is not needed when sampling.
            x, _ = coupling.forward_and_log_det(x)
        return x

    def __call__(self, x: mx.array) -> mx.array:
        """Alias for :meth:`log_prob`."""
        return self.log_prob(x)