mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
76 lines
2.4 KiB
Python
76 lines
2.4 KiB
Python
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
from typing import Optional, Tuple, Union
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
from bijectors import AffineBijector, MaskedCoupling
|
|
from distributions import Normal
|
|
|
|
|
|
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers with GELU applied after all but the last.

    The network has ``n_layers + 1`` linear maps: ``d_in -> d_hidden``,
    then ``n_layers - 1`` maps of ``d_hidden -> d_hidden``, and finally
    ``d_hidden -> d_out``. With ``n_layers == 0`` it is a single
    ``d_in -> d_out`` linear map.
    """

    def __init__(self, n_layers: int, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        # Feature width at every stage: input, each hidden stage, output.
        widths = [d_in, *([d_hidden] * n_layers), d_out]
        # Pair each width with its successor to get (fan_in, fan_out);
        # zip stops at the shorter sequence, so the last width is only
        # ever used as a fan_out.
        linears = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            linears.append(nn.Linear(fan_in, fan_out))
        self.layers = linears

    def __call__(self, x):
        """Apply the hidden layers with GELU, then the final layer without it."""
        *hidden, head = self.layers
        for linear in hidden:
            x = nn.gelu(linear(x))
        # No activation on the output layer.
        return head(x)
|
|
|
|
|
|
class RealNVP(nn.Module):
    """Chain of ``MaskedCoupling`` transforms built on ``AffineBijector``.

    Each of the ``n_transforms`` steps pairs a boolean mask over the
    ``d_params`` features with its own conditioner ``MLP``. The base
    distribution is ``Normal(mx.zeros(d_params), mx.ones(d_params))``.
    """

    def __init__(self, n_transforms: int, d_params: int, d_hidden: int, n_layers: int):
        super().__init__()

        # Alternating masks: step i marks the feature indices whose
        # parity equals the parity of i, so consecutive steps mark
        # complementary halves of the features.
        feature_parity = [mx.arange(d_params) % 2 for _ in range(n_transforms)]
        self.mask_list = [
            (parity == step % 2).astype(mx.bool_)
            for step, parity in enumerate(feature_parity)
        ]

        # Exclude the masks from the trainable parameters.
        self.freeze(keys=["mask_list"])

        # Conditioning MLP, one per transform. Each emits 2 * d_params
        # values per input (presumably two affine parameters per
        # feature -- confirm against AffineBijector).
        conditioners = []
        for _ in range(n_transforms):
            conditioners.append(MLP(n_layers, d_params, d_hidden, 2 * d_params))
        self.conditioner_list = conditioners

        loc = mx.zeros(d_params)
        scale = mx.ones(d_params)
        self.base_dist = Normal(loc, scale)

    def log_prob(self, x: mx.array):
        """
        Run ``x`` through the inverse of every coupling, last one first,
        accumulating each returned log-determinant, then add the base
        distribution's log-density of the result summed over the last axis.

        The accumulator is created with shape ``(x.shape[0],)``.
        """
        total_log_det = mx.zeros(x.shape[0])
        # Invert the transforms in the opposite order to `sample`.
        steps = zip(reversed(self.mask_list), reversed(self.conditioner_list))
        for mask, conditioner in steps:
            coupling = MaskedCoupling(mask, conditioner, AffineBijector)
            x, ldj = coupling.inverse_and_log_det(x)
            total_log_det = total_log_det + ldj
        base_log_density = self.base_dist.log_prob(x).sum(-1)
        return total_log_det + base_log_density

    def sample(
        self,
        sample_shape: Union[int, Tuple[int, ...]],
        key: Optional[mx.array] = None,
        n_transforms: Optional[int] = None,
    ):
        """
        Draw from the base distribution and push the draw forward through
        the couplings in order.

        Only the first ``n_transforms`` couplings are applied; ``None``
        (the default) applies all of them.
        """
        x = self.base_dist.sample(sample_shape, key=key)
        # slice(None) selects everything, slice(k) selects the first k.
        prefix = slice(n_transforms)
        steps = zip(self.mask_list[prefix], self.conditioner_list[prefix])
        for mask, conditioner in steps:
            coupling = MaskedCoupling(mask, conditioner, AffineBijector)
            # The log-determinant is not needed when sampling.
            x, _unused_log_det = coupling.forward_and_log_det(x)
        return x

    def __call__(self, x: mx.array):
        """Alias for :meth:`log_prob`."""
        return self.log_prob(x)
|