mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
format
This commit is contained in:
parent
5117e2e65d
commit
1900564f59
@ -73,8 +73,7 @@ class StableDiffusion:
|
||||
|
||||
# Create the latent variables
|
||||
x_T = self.sampler.sample_prior(
|
||||
(n_images, *latent_size, self.autoencoder.latent_channels),
|
||||
dtype=self.dtype
|
||||
(n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
|
||||
)
|
||||
|
||||
# Perform the denoising loop
|
||||
|
@ -51,7 +51,9 @@ class SimpleEulerSampler:
|
||||
|
||||
def sample_prior(self, shape, dtype=mx.float32, key=None):
|
||||
noise = mx.random.normal(shape, key=key)
|
||||
return (noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()).astype(dtype)
|
||||
return (
|
||||
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
|
||||
).astype(dtype)
|
||||
|
||||
def sigmas(self, t):
|
||||
return _interp(self._sigmas, t)
|
||||
|
@ -5,6 +5,7 @@ import regex
|
||||
|
||||
class Tokenizer:
|
||||
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
|
||||
|
||||
def __init__(self, bpe_ranks, vocab):
|
||||
self.bpe_ranks = bpe_ranks
|
||||
self.vocab = vocab
|
||||
@ -46,7 +47,9 @@ class Tokenizer:
|
||||
#
|
||||
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
|
||||
while unique_bigrams:
|
||||
bigram = min(unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||
bigram = min(
|
||||
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
|
||||
)
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user