diff --git a/stable_diffusion/stable_diffusion/__init__.py b/stable_diffusion/stable_diffusion/__init__.py
index 57fb8b53..778eff39 100644
--- a/stable_diffusion/stable_diffusion/__init__.py
+++ b/stable_diffusion/stable_diffusion/__init__.py
@@ -73,8 +73,7 @@ class StableDiffusion:
 
         # Create the latent variables
         x_T = self.sampler.sample_prior(
-            (n_images, *latent_size, self.autoencoder.latent_channels),
-            dtype=self.dtype
+            (n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
         )
 
         # Perform the denoising loop
diff --git a/stable_diffusion/stable_diffusion/sampler.py b/stable_diffusion/stable_diffusion/sampler.py
index e9cc2c1b..a1edf931 100644
--- a/stable_diffusion/stable_diffusion/sampler.py
+++ b/stable_diffusion/stable_diffusion/sampler.py
@@ -51,7 +51,9 @@ class SimpleEulerSampler:
 
     def sample_prior(self, shape, dtype=mx.float32, key=None):
         noise = mx.random.normal(shape, key=key)
-        return (noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()).astype(dtype)
+        return (
+            noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
+        ).astype(dtype)
 
     def sigmas(self, t):
         return _interp(self._sigmas, t)
diff --git a/stable_diffusion/stable_diffusion/tokenizer.py b/stable_diffusion/stable_diffusion/tokenizer.py
index 41d719ae..07375fc7 100644
--- a/stable_diffusion/stable_diffusion/tokenizer.py
+++ b/stable_diffusion/stable_diffusion/tokenizer.py
@@ -5,6 +5,7 @@ import regex
 
 class Tokenizer:
     """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
+
     def __init__(self, bpe_ranks, vocab):
         self.bpe_ranks = bpe_ranks
         self.vocab = vocab
@@ -46,7 +47,9 @@ class Tokenizer:
         #
         # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
         while unique_bigrams:
-            bigram = min(unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            bigram = min(
+                unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
+            )
 
             if bigram not in self.bpe_ranks:
                 break
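For context on the expression the sampler hunk rewraps: sample_prior draws unit Gaussian noise and scales it by sigma_max / sqrt(sigma_max**2 + 1), which is exactly noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt() from the diff. A minimal standalone sketch of that scaling, with NumPy standing in for mx and a sigma_max argument standing in for self._sigmas[-1] (both substitutions are assumptions for illustration, not part of the patch):

import numpy as np

def sample_prior(shape, sigma_max, seed=None):
    # Unit Gaussian noise, scaled by sigma_max / sqrt(sigma_max**2 + 1);
    # equivalent to noise * sigma * (sigma**2 + 1) ** -0.5 in the diff above.
    noise = np.random.default_rng(seed).standard_normal(shape)
    return noise * sigma_max / np.sqrt(sigma_max**2 + 1)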
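The tokenizer hunk only rewraps the merge-selection line, but the logic it touches is the heart of the BPE loop: each iteration picks the bigram with the lowest merge rank, treating pairs absent from bpe_ranks as infinitely ranked so the loop stops once no known merge remains. A toy sketch of that selection with a hypothetical three-entry rank table (names and data are illustrative, not from the repo):

def next_merge(unique_bigrams, bpe_ranks):
    # The pair with the smallest rank merges first; unranked pairs get
    # rank infinity and only win when no known merge is left.
    return min(unique_bigrams, key=lambda pair: bpe_ranks.get(pair, float("inf")))

# Hypothetical rank table: lower rank = merge learned earlier in training.
ranks = {("h", "e"): 0, ("l", "l"): 1, ("he", "ll"): 2}
print(next_merge({("l", "l"), ("h", "e")}, ranks))  # ('h', 'e')
print(next_merge({("o", "!")}, ranks))  # ('o', '!') is unranked, so the caller breaks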