mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Normalizing flow example (#133)
* Implement normalizing flow Real NVP example * Add requirements and basic usage to normalizing flow example * Minor changes to README in normalizing flow example * Remove trailing commas in function arguments for unified formatting in flows example * Fix minor typos, add some annotations * format + nits in README * readme fix * mov, minor changes in main, copywright * remove debug * fix * Simplified class structure in distributions; better code re-use in bijectors * Remove rogue space * change name again * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:

committed by
GitHub

parent
cd3cff0858
commit
19b6167d81
31
normalizing_flow/distributions.py
Normal file
31
normalizing_flow/distributions.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from typing import Tuple, Optional, Union
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
class Normal:
|
||||
def __init__(self, mu: mx.array, sigma: mx.array):
|
||||
super().__init__()
|
||||
self.mu = mu
|
||||
self.sigma = sigma
|
||||
|
||||
def sample(
|
||||
self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None
|
||||
):
|
||||
return mx.random.normal(sample_shape, key=key) * self.sigma + self.mu
|
||||
|
||||
def log_prob(self, x: mx.array):
|
||||
return (
|
||||
-0.5 * math.log(2 * math.pi)
|
||||
- mx.log(self.sigma)
|
||||
- 0.5 * ((x - self.mu) / self.sigma) ** 2
|
||||
)
|
||||
|
||||
def sample_and_log_prob(
|
||||
self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None
|
||||
):
|
||||
x = self.sample(sample_shape, key=key)
|
||||
return x, self.log_prob(x)
|
Reference in New Issue
Block a user