# mlx-examples/flow/flows.py

from typing import Tuple, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from bijectors import MaskedCoupling, AffineBijector
from distributions import Normal
class MLP(nn.Module):
    """Feed-forward network: ``n_layers`` hidden layers of width ``d_hidden``
    with GELU activations, followed by a linear output layer."""

    def __init__(self, n_layers: int, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        # Layer widths: input -> n_layers hidden -> output.
        dims = [d_in, *([d_hidden] * n_layers), d_out]
        self.layers = [
            nn.Linear(d_prev, d_next)
            for d_prev, d_next in zip(dims[:-1], dims[1:])
        ]

    def __call__(self, x):
        # GELU after every layer except the last (linear) one.
        *hidden, last = self.layers
        for layer in hidden:
            x = nn.gelu(layer(x))
        return last(x)
class RealNVP(nn.Module):
    """RealNVP normalizing flow: a stack of masked affine coupling layers
    over a standard-normal base distribution.

    Args:
        n_transforms: number of coupling transforms in the flow.
        d_params: dimensionality of the modeled variable.
        d_hidden: hidden width of each conditioner MLP.
        n_layers: number of hidden layers in each conditioner MLP.
    """

    def __init__(self, n_transforms: int, d_params: int, d_hidden: int, n_layers: int):
        super().__init__()
        # Alternating binary masks so even/odd dimensions swap roles between
        # successive transforms. (Folded the original's separate astype pass
        # into the comprehension.)
        self.mask_list = [
            (mx.arange(d_params) % 2 == i % 2).astype(mx.bool_)
            for i in range(n_transforms)
        ]
        # Masks are fixed structure, not trainable parameters.
        self.freeze(keys=["mask_list"])
        # One conditioner per transform; it emits 2 * d_params values
        # (consumed by AffineBijector — presumably shift and log-scale;
        # TODO confirm against bijectors.py).
        self.conditioner_list = [
            MLP(n_layers, d_params, d_hidden, 2 * d_params)
            for _ in range(n_transforms)
        ]
        # Standard-normal base distribution. (Fixed the original's
        # misspelled attribute name ``bast_dist``.)
        self.base_dist = Normal(mx.zeros(d_params), mx.ones(d_params))

    def log_prob(self, x: mx.array) -> mx.array:
        """Log-density of ``x`` (batch, d_params) via the change-of-variables
        formula: invert every transform, accumulating log-det-Jacobians."""
        total_ldj = mx.zeros(x.shape[0])
        # Inverse pass runs the transforms in reverse order of application.
        for mask, conditioner in zip(self.mask_list[::-1], self.conditioner_list[::-1]):
            x, ldj = MaskedCoupling(mask, conditioner, AffineBijector).inverse_and_log_det(x)
            total_ldj += ldj
        return total_ldj + self.base_dist.log_prob(x).sum(-1)

    def sample(
        self,
        sample_shape: Union[int, Tuple[int, ...]],
        key: Optional[mx.array] = None,
        n_transforms: Optional[int] = None,
    ) -> mx.array:
        """Draw samples: sample the base distribution, then push forward
        through the first ``n_transforms`` transforms (all of them if None)."""
        x = self.base_dist.sample(sample_shape, key=key)
        for mask, conditioner in zip(
            self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]
        ):
            x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x)
        return x

    def __call__(self, x: mx.array) -> mx.array:
        # Training objective: the log-density of the data under the flow.
        return self.log_prob(x)