This commit is contained in:
Awni Hannun 2023-11-30 11:52:47 -08:00
parent 5117e2e65d
commit 1900564f59
3 changed files with 8 additions and 4 deletions

View File

@ -73,8 +73,7 @@ class StableDiffusion:
# Create the latent variables
x_T = self.sampler.sample_prior(
(n_images, *latent_size, self.autoencoder.latent_channels),
dtype=self.dtype
(n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
)
# Perform the denoising loop

View File

@ -51,7 +51,9 @@ class SimpleEulerSampler:
def sample_prior(self, shape, dtype=mx.float32, key=None):
noise = mx.random.normal(shape, key=key)
return (noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()).astype(dtype)
return (
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
).astype(dtype)
def sigmas(self, t):
return _interp(self._sigmas, t)

View File

@ -5,6 +5,7 @@ import regex
class Tokenizer:
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
def __init__(self, bpe_ranks, vocab):
self.bpe_ranks = bpe_ranks
self.vocab = vocab
@ -46,7 +47,9 @@ class Tokenizer:
#
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
while unique_bigrams:
bigram = min(unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
bigram = min(
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
)
if bigram not in self.bpe_ranks:
break