Awni Hannun 2023-11-30 11:52:47 -08:00
parent 5117e2e65d
commit 1900564f59
3 changed files with 8 additions and 4 deletions

View File

@@ -73,8 +73,7 @@ class StableDiffusion:
         # Create the latent variables
         x_T = self.sampler.sample_prior(
-            (n_images, *latent_size, self.autoencoder.latent_channels),
-            dtype=self.dtype
+            (n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
         )
         # Perform the denoising loop
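For context, the reflowed call still builds the same latent shape; a minimal sketch with made-up values (n_images, latent_size and latent_channels below are stand-ins for the attributes referenced in the hunk):

# Illustrative values only; the real ones come from the pipeline and autoencoder config.
n_images = 4
latent_size = (64, 64)   # e.g. a 512x512 image downsampled 8x by the autoencoder
latent_channels = 4      # stand-in for self.autoencoder.latent_channels

shape = (n_images, *latent_size, latent_channels)
print(shape)  # (4, 64, 64, 4) -- the shape handed to sample_prior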

View File

@@ -51,7 +51,9 @@ class SimpleEulerSampler:
     def sample_prior(self, shape, dtype=mx.float32, key=None):
         noise = mx.random.normal(shape, key=key)
-        return (noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()).astype(dtype)
+        return (
+            noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
+        ).astype(dtype)
 
     def sigmas(self, t):
         return _interp(self._sigmas, t)
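The reflowed return is the same computation: unit-variance noise scaled by sigma_max / sqrt(sigma_max^2 + 1). A standalone sketch of that scaling, with a made-up sigma value:

import mlx.core as mx

# Stand-in for self._sigmas[-1]; the real value comes from the sampler's noise schedule.
sigma_max = mx.array(14.6)

noise = mx.random.normal((1, 64, 64, 4))
# Same expression as the diff: the prior sample ends up with variance
# sigma_max^2 / (sigma_max^2 + 1) rather than sigma_max^2.
x_T = (noise * sigma_max * (sigma_max.square() + 1).rsqrt()).astype(mx.float32)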

View File

@@ -5,6 +5,7 @@ import regex
 class Tokenizer:
     """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
+
     def __init__(self, bpe_ranks, vocab):
         self.bpe_ranks = bpe_ranks
         self.vocab = vocab
@@ -46,7 +47,9 @@ class Tokenizer:
         #
         # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
         while unique_bigrams:
-            bigram = min(unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            bigram = min(
+                unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
+            )
             if bigram not in self.bpe_ranks:
                 break
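The reformatted min(...) call is the usual BPE merge rule: among the bigrams currently in the word, pick the one with the lowest merge rank, and stop once none of them is a known merge. A self-contained toy sketch of that loop (the ranks and the word are made up, and byte-level details of the real tokenizer are omitted):

# Toy merge table: lower rank = merged earlier during BPE training.
bpe_ranks = {("l", "o"): 0, ("lo", "w"): 1, ("e", "r"): 2}
word = ["l", "o", "w", "e", "r"]

while len(word) > 1:
    unique_bigrams = set(zip(word[:-1], word[1:]))
    # Same selection as the diff: lowest-ranked known bigram; unknown pairs rank as +inf.
    bigram = min(unique_bigrams, key=lambda pair: bpe_ranks.get(pair, float("inf")))
    if bigram not in bpe_ranks:
        break
    # Merge every occurrence of the chosen bigram.
    merged, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and (word[i], word[i + 1]) == bigram:
            merged.append(word[i] + word[i + 1])
            i += 2
        else:
            merged.append(word[i])
            i += 1
    word = merged

print(word)  # ['low', 'er']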