mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
format
This commit is contained in:
parent
5117e2e65d
commit
1900564f59
@ -73,8 +73,7 @@ class StableDiffusion:
|
|||||||
|
|
||||||
# Create the latent variables
|
# Create the latent variables
|
||||||
x_T = self.sampler.sample_prior(
|
x_T = self.sampler.sample_prior(
|
||||||
(n_images, *latent_size, self.autoencoder.latent_channels),
|
(n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
|
||||||
dtype=self.dtype
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Perform the denoising loop
|
# Perform the denoising loop
|
||||||
|
@ -51,7 +51,9 @@ class SimpleEulerSampler:
|
|||||||
|
|
||||||
def sample_prior(self, shape, dtype=mx.float32, key=None):
|
def sample_prior(self, shape, dtype=mx.float32, key=None):
|
||||||
noise = mx.random.normal(shape, key=key)
|
noise = mx.random.normal(shape, key=key)
|
||||||
return (noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()).astype(dtype)
|
return (
|
||||||
|
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
|
||||||
|
).astype(dtype)
|
||||||
|
|
||||||
def sigmas(self, t):
|
def sigmas(self, t):
|
||||||
return _interp(self._sigmas, t)
|
return _interp(self._sigmas, t)
|
||||||
|
@ -5,6 +5,7 @@ import regex
|
|||||||
|
|
||||||
class Tokenizer:
|
class Tokenizer:
|
||||||
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
|
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
|
||||||
|
|
||||||
def __init__(self, bpe_ranks, vocab):
|
def __init__(self, bpe_ranks, vocab):
|
||||||
self.bpe_ranks = bpe_ranks
|
self.bpe_ranks = bpe_ranks
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
@ -46,7 +47,9 @@ class Tokenizer:
|
|||||||
#
|
#
|
||||||
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
|
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
|
||||||
while unique_bigrams:
|
while unique_bigrams:
|
||||||
bigram = min(unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
bigram = min(
|
||||||
|
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
|
||||||
|
)
|
||||||
if bigram not in self.bpe_ranks:
|
if bigram not in self.bpe_ranks:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user