Normalizing flow example (#133)

* Implement normalizing flow Real NVP example

* Add requirements and basic usage to normalizing flow example

* Minor changes to README in normalizing flow example

* Remove trailing commas in function arguments for unified formatting in flows example

* Fix minor typos, add some annotations

* format + nits in README

* readme fix

* mov, minor changes in main, copywright

* remove debug

* fix

* Simplified class structure in distributions; better code re-use in bijectors

* Remove rogue space

* change name again

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Siddharth Mishra-Sharma
2024-01-13 19:58:48 -05:00
committed by GitHub
parent cd3cff0858
commit 19b6167d81
7 changed files with 339 additions and 0 deletions

View File

@@ -0,0 +1,31 @@
# Copyright © 2023-2024 Apple Inc.
from typing import Tuple, Optional, Union
import math
import mlx.core as mx
class Normal:
def __init__(self, mu: mx.array, sigma: mx.array):
super().__init__()
self.mu = mu
self.sigma = sigma
def sample(
self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None
):
return mx.random.normal(sample_shape, key=key) * self.sigma + self.mu
def log_prob(self, x: mx.array):
return (
-0.5 * math.log(2 * math.pi)
- mx.log(self.sigma)
- 0.5 * ((x - self.mu) / self.sigma) ** 2
)
def sample_and_log_prob(
self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None
):
x = self.sample(sample_shape, key=key)
return x, self.log_prob(x)