mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)

Add the Llama and Stable Diffusion examples
parent 7f1328f333
commit b364cc56cd
37 llama/README.md Normal file
@@ -0,0 +1,37 @@
# LLaMA

An example of generating text with LLaMA using MLX.

LLaMA is a set of open source language models from Meta AI Research[^1] ranging from 7B to 65B parameters.

### Setup

Install the dependencies:

```
pip install -r requirements.txt
```

Next, download and convert the model. If you do not have access to the model
weights you will need to [request
access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
from Meta.

Convert the weights with:

```
python convert.py <path_to_torch_weights> mlx_llama_weights.npz
```

### Run

Once you've converted the weights to MLX format, you can interact with the
LLaMA model:

```
python llama.py mlx_llama_weights.npz tokenizer.model "hello"
```

Run `python llama.py --help` for more details.
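
The script can also run in an interactive few-shot mode, where the prompt argument points to a text file containing a prompt template and `{}` is replaced by the question you type at the prompt. For example, using the provided `sample_prompt.txt` (the weights file name is simply whatever you produced with `convert.py`):

```
python llama.py mlx_llama_weights.npz tokenizer.model sample_prompt.txt --few-shot
```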

[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
46 llama/convert.py Normal file
@@ -0,0 +1,46 @@
import argparse
from itertools import starmap

import numpy as np
import torch


def map_torch_to_mlx(key, value):
    if "tok_embedding" in key:
        key = "embedding.weight"

    elif "norm" in key:
        key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")

    elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
        key = key.replace("wq", "query_proj")
        key = key.replace("wk", "key_proj")
        key = key.replace("wv", "value_proj")
        key = key.replace("wo", "out_proj")

    elif "w1" in key or "w2" in key or "w3" in key:
        # The FFN is a separate submodule in PyTorch
        key = key.replace("feed_forward.w1", "linear1")
        key = key.replace("feed_forward.w3", "linear2")
        key = key.replace("feed_forward.w2", "linear3")

    elif "output" in key:
        key = key.replace("output", "out_proj")

    elif "rope" in key:
        return None, None

    return key, value.numpy()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
    parser.add_argument("torch_weights")
    parser.add_argument("output_file")
    args = parser.parse_args()

    state = torch.load(args.torch_weights)
    np.savez(
        args.output_file,
        **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
    )
304 llama/llama.py Normal file
@@ -0,0 +1,304 @@
import argparse
import math
import numpy as np
from sentencepiece import SentencePieceProcessor
import time

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()

        self.num_heads = num_heads

        self.rope = nn.RoPE(dims // num_heads, traditional=True)
        self.query_proj = nn.Linear(dims, dims, bias=False)
        self.key_proj = nn.Linear(dims, dims, bias=False)
        self.value_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        # Extract some shapes
        num_heads = self.num_heads
        B, L, D = queries.shape

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)

        # Add RoPE to the queries and keys and combine them with the cache
        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Finally perform the attention computation
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        # Note that we return the keys and values to possibly be used as a cache
        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = a * mx.sigmoid(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


class Llama(nn.Module):
    def __init__(
        self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, dims)
        self.layers = [
            LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
        ]
        self.norm = nn.RMSNorm(dims)
        self.out_proj = nn.Linear(dims, vocab_size, bias=False)

    def __call__(self, x):
        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
        mask = mask.astype(self.embedding.weight.dtype)

        x = self.embedding(x)
        for l in self.layers:
            x, _ = l(x, mask)
        x = self.norm(x)
        return self.out_proj(x)

    def generate(self, x, temp=1.0):
        cache = []

        # Make an additive causal mask. We will need that to process the prompt.
        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
        mask = mask.astype(self.embedding.weight.dtype)

        # First we process the prompt x the same way as in __call__ but
        # save the caches in cache
        x = self.embedding(x)
        for l in self.layers:
            x, c = l(x, mask=mask)
            # We store the per layer cache in a simple python list
            cache.append(c)
        x = self.norm(x)
        # We only care about the last logits that generate the next token
        y = self.out_proj(x[:, -1])
        y = mx.random.categorical(y * (1 / temp))

        # y now has size [1]
        # Since MLX is lazily evaluated nothing is computed yet.
        # Calling y.item() would force the computation to happen at
        # this point but we can also choose not to do that and let the
        # user choose when to start the computation.
        yield y

        # Now that we have parsed the prompt and generated the first token
        # we need to feed it back into the model and loop to generate the
        # rest.
        while True:
            # Unsqueezing the last dimension to add a sequence length
            # dimension of 1
            x = y[:, None]

            x = self.embedding(x)
            for i in range(len(cache)):
                # We are overwriting the arrays in the cache list. When
                # the computation happens, MLX discards the old cache the
                # moment it is not needed anymore.
                x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
            x = self.norm(x)
            y = self.out_proj(x[:, -1])
            y = mx.random.categorical(y * (1 / temp))

            yield y


def tic():
    return time.time()


def toc(msg, start):
    end = time.time()
    return f"[INFO] {msg}: {end - start:.3f} s"


def generate(args):

    input("Press enter to start generation")
    print("------")

    x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
    skip = 0
    prompt_processing = None
    tokens = []
    start = tic()
    for token in model.generate(x, args.temp):
        tokens.append(token)

        if len(tokens) == 1:
            # Actually perform the computation to measure the prompt processing time
            mx.eval(token)
            prompt_processing = toc("Prompt processing", start)

        if len(tokens) >= args.num_tokens:
            break

        elif (len(tokens) % args.write_every) == 0:
            # It is perfectly ok to eval things we have already eval-ed.
            mx.eval(tokens)
            s = tokenizer.decode([t.item() for t in tokens])
            print(s[skip:], end="", flush=True)
            skip = len(s)

    mx.eval(tokens)
    full_gen = toc("Full generation", start)
    s = tokenizer.decode([t.item() for t in tokens])
    print(s[skip:], end="", flush=True)
    print()
    print("------")
    print(prompt_processing)
    print(full_gen)


def few_shot_generate(args):
    def possible_end(s):
        word = "[Instruction]"
        for i in range(len(word) - 1, 0, -1):
            if s[-i:] == word[:i]:
                return 0
        if s[-len(word) :] == word:
            return 1
        return -1

    def generate(question):
        x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(question)])
        skip = 0
        prompt_processing = None
        tokens = []
        start = tic()
        for token in model.generate(x, args.temp):
            tokens.append(token)

            if len(tokens) == 1:
                # Actually perform the computation to measure the prompt processing time
                mx.eval(token)
                prompt_processing = toc("Prompt processing", start)

            if len(tokens) >= args.num_tokens:
                break

            mx.eval(tokens)
            token_list = [t.item() for t in tokens]
            s = tokenizer.decode(token_list)

            end = possible_end(s)
            if end == 0:
                continue
            if end == 1:
                skip = len(s)
                break

            print(s[skip:], end="", flush=True)
            skip = len(s)
            if token_list[-1] == tokenizer.eos_id():
                break

        mx.eval(tokens)
        full_gen = toc("Full generation", start)
        s = tokenizer.decode([t.item() for t in tokens])
        print(s[skip:], end="", flush=True)

    prompt = open(args.prompt).read().strip()
    while True:
        question = input("Ask a question: ")
        generate(prompt.replace("{}", question))
        print()


def load_model(model_path):
    weights = mx.load(model_path)
    mlp_dims, dims = weights["layers.0.linear1.weight"].shape
    num_heads = dims // 128
    num_layers = max(int(l.split(".")[1]) for l in weights.keys() if "layers" in l) + 1
    vocab_size = weights["out_proj.weight"].shape[0]
    model = Llama(num_layers, vocab_size, dims, mlp_dims, num_heads)
    model.update(tree_unflatten(list(weights.items())))
    mx.eval(model.parameters())
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Llama inference script")
    parser.add_argument("model", help="The model file containing MLX weights")
    parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
    parser.add_argument("prompt", help="The message to be processed by the model")
    parser.add_argument(
        "--few-shot",
        action="store_true",
        help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
    )
    parser.add_argument(
        "--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
    )
    parser.add_argument(
        "--write-every", type=int, default=1, help="After how many tokens to detokenize"
    )
    parser.add_argument(
        "--temp", type=float, default=0.8, help="The sampling temperature"
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

    args = parser.parse_args()

    mx.random.seed(args.seed)

    tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
    print("[INFO] Loading model from disk.")
    model = load_model(args.model)
    if args.few_shot:
        few_shot_generate(args)
    else:
        generate(args)
2 llama/requirements.txt Normal file
@@ -0,0 +1,2 @@
sentencepiece
torch
23 llama/sample_prompt.txt Normal file
@@ -0,0 +1,23 @@
[Instruction] Give the list of U.S. states bordering Canada
[Answer] OK, here is the list of U.S. states located on the border with Canada:
- Alaska
- Michigan
- Maine
- Minnesota
- Montana
- New York
- Washington
- North Dakota
- Ohio
- Vermont
- New Hampshire
- Idaho
- Pennsylvania
[Instruction] Write a paragraph about "functional analysis"
[Answer] OK, here is a paragraph on the topic of functional analysis:
Functional analysis is a branch of mathematical analysis, the core of which is formed by the study of vector spaces endowed with some kind of limit-related structure (for example, inner product, norm, or topology) and the linear functions defined on these spaces and suitably respecting these structures. The historical roots of functional analysis lie in the study of spaces of functions and the formulation of properties of transformations of functions such as the Fourier transform as transformations defining, for example, continuous or unitary operators between function spaces. This point of view turned out to be particularly useful for the study of differential and integral equations.
[Instruction] I am starting a new dog walking business. Can you help me find 2 possible names for the business?
[Answer] OK, here are two possible names for a new dog walking business:
The first option is "Paws on Patrol", and the second option is "The Dog Whisperer".
[Instruction] {}
[Answer]
95 stable_diffusion/README.md Normal file
@@ -0,0 +1,95 @@
Stable Diffusion
================

Stable Diffusion in MLX. The implementation was ported from Hugging Face's
[diffusers](https://huggingface.co/docs/diffusers/index) and we are fetching
and using the weights available on the Hugging Face Hub by Stability AI at
[stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1).


*Image generated using Stable Diffusion in MLX and the prompt 'A big red sign saying MLX in capital letters.'*

Installation
------------

The dependencies are minimal, namely:

- `safetensors` and `huggingface-hub` to load the checkpoints.
- `regex` for the tokenization
- `numpy` because safetensors needs to return some form of array
- `tqdm` and `PIL` for the `txt2image.py` script

You can install all of the above with the `requirements.txt` as follows:

    pip install -r requirements.txt

Usage
------

Although each component in this repository can be used by itself, the fastest
way to get started is by using the `StableDiffusion` class from the
`stable_diffusion` module.

```python
import mlx.core as mx

from stable_diffusion import StableDiffusion

# This will download all the weights from the HF hub and load the models in
# memory
sd = StableDiffusion()

# This creates a python generator that returns the latent produced by the
# reverse diffusion process.
#
# Because MLX is lazily evaluated, iterating over this generator doesn't
# actually perform the computation until mx.eval() is called.
latent_generator = sd.generate_latents("A photo of an astronaut riding a horse on Mars.")

# Here we are evaluating each diffusion step but we could also evaluate
# once at the end.
for x_t in latent_generator:
    mx.simplify(x_t)  # remove possible redundant computation, e.g. reuse scalars
    mx.eval(x_t)

# Now x_t is the last latent from the reverse process aka x_0. We can
# decode it into an image using the stable diffusion VAE.
im = sd.decode(x_t)
```

The above is almost line for line the implementation of the `txt2image.py`
script in the root of the repository. You can use the script as follows:

    python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2
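
If you drive the `StableDiffusion` class yourself rather than using `txt2image.py`, you still need to write the decoded output to disk. Below is a minimal sketch, continuing from the snippet above and assuming only the packages already listed in `requirements.txt`; the output file names are arbitrary and `txt2image.py` itself arranges the images into a grid, so its exact code differs in the details.

```python
import numpy as np
from PIL import Image

# `im` from sd.decode(x_t) has shape (n_images, H, W, 3) with values in [0, 1].
mx.eval(im)
for i in range(im.shape[0]):
    pixels = (np.array(im[i]) * 255).astype(np.uint8)
    Image.fromarray(pixels).save(f"image_{i}.png")
```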

Performance
-----------

The following table compares the performance of the UNet in stable diffusion.
We report throughput in images per second for the provided `txt2image.py`
script and the `diffusers` library using the MPS PyTorch backend.

At the time of writing this comparison, convolutions are still among the least
optimized operations in MLX. Despite that, MLX still achieves **~40% higher
throughput** than PyTorch with a batch size of 16 and ~15% higher when
comparing the optimal batch sizes.

Notably, PyTorch achieves about 50% higher throughput at a batch size of 1,
which is unfortunate because it means a single image can be generated faster
with PyTorch. However, when starting with the models not loaded in memory and
PyTorch's MPS graph kernels not cached, the compilation time more than
accounts for this speed difference.

| Batch size | PyTorch     | MLX         |
| ---------- | ----------- | ----------- |
| 1          | 6.25 im/s   | 4.17 im/s   |
| 2          | 7.14 im/s   | 5.88 im/s   |
| 4          |**7.69 im/s**| 7.14 im/s   |
| 6          | 7.22 im/s   | 8.00 im/s   |
| 8          | 6.89 im/s   | 8.42 im/s   |
| 12         | 6.62 im/s   | 8.51 im/s   |
| 16         | 6.32 im/s   |**8.79 im/s**|

The above experiments were run on an M2 Ultra with PyTorch version 2.1,
diffusers version 0.21.4 and transformers version 4.33.3. For the generation we
used classifier-free guidance, which means the UNet processes twice as many
images as the batch sizes above suggest.
BIN stable_diffusion/generated-mlx.png Normal file
Binary file not shown (PNG, 185 KiB).
6 stable_diffusion/requirements.txt Normal file
@@ -0,0 +1,6 @@
safetensors
huggingface-hub
regex
numpy
tqdm
Pillow
96 stable_diffusion/stable_diffusion/__init__.py Normal file
@@ -0,0 +1,96 @@
import time
from typing import Tuple

import mlx.core as mx

from .model_io import (
    load_unet,
    load_text_encoder,
    load_autoencoder,
    load_diffusion_config,
    load_tokenizer,
    _DEFAULT_MODEL,
)
from .sampler import SimpleEulerSampler


def _repeat(x, n, axis):
    # Make the expanded shape
    s = x.shape
    s.insert(axis + 1, n)

    # Expand
    x = mx.broadcast_to(mx.expand_dims(x, axis + 1), s)

    # Make the flattened shape
    s.pop(axis + 1)
    s[axis] *= n

    return x.reshape(s)


class StableDiffusion:
    def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
        self.dtype = mx.float16 if float16 else mx.float32
        self.diffusion_config = load_diffusion_config(model)
        self.unet = load_unet(model, float16)
        self.text_encoder = load_text_encoder(model, float16)
        self.autoencoder = load_autoencoder(model, float16)
        self.sampler = SimpleEulerSampler(self.diffusion_config)
        self.tokenizer = load_tokenizer(model)

    def generate_latents(
        self,
        text: str,
        n_images: int = 1,
        num_steps: int = 50,
        cfg_weight: float = 7.5,
        negative_text: str = "",
        latent_size: Tuple[int] = (64, 64),
        seed=None,
    ):
        # Set the PRNG state
        seed = seed or int(time.time())
        mx.random.seed(seed)

        # Tokenize the text
        tokens = [self.tokenizer.tokenize(text)]
        if cfg_weight > 1:
            tokens += [self.tokenizer.tokenize(negative_text)]
        lengths = [len(t) for t in tokens]
        N = max(lengths)
        tokens = [t + [0] * (N - len(t)) for t in tokens]
        tokens = mx.array(tokens)

        # Compute the features
        conditioning = self.text_encoder(tokens)

        # Repeat the conditioning for each of the generated images
        if n_images > 1:
            conditioning = _repeat(conditioning, n_images, axis=0)

        # Create the latent variables
        x_T = self.sampler.sample_prior(
            (n_images, *latent_size, self.autoencoder.latent_channels),
            dtype=self.dtype
        )

        # Perform the denoising loop
        x_t = x_T
        for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
            x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
            t_unet = mx.broadcast_to(t, [len(x_t_unet)])
            eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)

            if cfg_weight > 1:
                eps_text, eps_neg = eps_pred.split(2)
                eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)

            x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
            x_t = x_t_prev
            yield x_t

    def decode(self, x_t):
        x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
        x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
        return x
68 stable_diffusion/stable_diffusion/clip.py Normal file
@@ -0,0 +1,68 @@
import mlx.core as mx
import mlx.nn as nn

from .config import CLIPTextModelConfig


class CLIPEncoderLayer(nn.Module):
    """The transformer encoder layer from CLIP."""

    def __init__(self, model_dims: int, num_heads: int):
        super().__init__()

        self.layer_norm1 = nn.LayerNorm(model_dims)
        self.layer_norm2 = nn.LayerNorm(model_dims)

        self.attention = nn.MultiHeadAttention(model_dims, num_heads)
        # Add biases to the attention projections to match CLIP
        self.attention.query_proj.bias = mx.zeros(model_dims)
        self.attention.key_proj.bias = mx.zeros(model_dims)
        self.attention.value_proj.bias = mx.zeros(model_dims)
        self.attention.out_proj.bias = mx.zeros(model_dims)

        self.linear1 = nn.Linear(model_dims, 4 * model_dims)
        self.linear2 = nn.Linear(4 * model_dims, model_dims)

    def __call__(self, x, attn_mask=None):
        y = self.layer_norm1(x)
        y = self.attention(y, y, y, attn_mask)
        x = y + x

        y = self.layer_norm2(x)
        y = self.linear1(y)
        y = nn.gelu_approx(y)
        y = self.linear2(y)
        x = y + x

        return x


class CLIPTextModel(nn.Module):
    """Implements the text encoder transformer from CLIP."""

    def __init__(self, config: CLIPTextModelConfig):
        super().__init__()

        self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
        self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
        self.layers = [
            CLIPEncoderLayer(config.model_dims, config.num_heads)
            for i in range(config.num_layers)
        ]
        self.final_layer_norm = nn.LayerNorm(config.model_dims)

    def __call__(self, x):
        # Extract some shapes
        B, N = x.shape

        # Compute the embeddings
        x = self.token_embedding(x)
        x = x + self.position_embedding.weight[:N]

        # Compute the features from the transformer
        mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
        for l in self.layers:
            x = l(x, mask)

        # Apply the final layernorm and return
        return self.final_layer_norm(x)
46 stable_diffusion/stable_diffusion/config.py Normal file
@@ -0,0 +1,46 @@
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class AutoencoderConfig:
    in_channels: int = 3
    out_channels: int = 3
    latent_channels_out: int = 8
    latent_channels_in: int = 4
    block_out_channels: Tuple[int] = (128, 256, 512, 512)
    layers_per_block: int = 2
    norm_num_groups: int = 32
    scaling_factor: float = 0.18215


@dataclass
class CLIPTextModelConfig:
    num_layers: int = 23
    model_dims: int = 1024
    num_heads: int = 16
    max_length: int = 77
    vocab_size: int = 49408


@dataclass
class UNetConfig:
    in_channels: int = 4
    out_channels: int = 4
    conv_in_kernel: int = 3
    conv_out_kernel: int = 3
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
    layers_per_block: Tuple[int] = (2, 2, 2, 2)
    mid_block_layers: int = 2
    transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
    num_attention_heads: Tuple[int] = (5, 10, 20, 20)
    cross_attention_dim: Tuple[int] = (1024,) * 4
    norm_num_groups: int = 32


@dataclass
class DiffusionConfig:
    beta_schedule: str = "scaled_linear"
    beta_start: float = 0.00085
    beta_end: float = 0.012
    num_train_steps: int = 1000
284 stable_diffusion/stable_diffusion/model_io.py Normal file
@@ -0,0 +1,284 @@
import json
from functools import partial

import numpy as np
from huggingface_hub import hf_hub_download
from safetensors import safe_open as safetensor_open

import mlx.core as mx
from mlx.utils import tree_unflatten

from .clip import CLIPTextModel
from .config import UNetConfig, CLIPTextModelConfig, AutoencoderConfig, DiffusionConfig
from .tokenizer import Tokenizer
from .unet import UNetModel
from .vae import Autoencoder


_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
_MODELS = {
    # See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
    "stabilityai/stable-diffusion-2-1-base": {
        "unet_config": "unet/config.json",
        "unet": "unet/diffusion_pytorch_model.safetensors",
        "text_encoder_config": "text_encoder/config.json",
        "text_encoder": "text_encoder/model.safetensors",
        "vae_config": "vae/config.json",
        "vae": "vae/diffusion_pytorch_model.safetensors",
        "diffusion_config": "scheduler/scheduler_config.json",
        "tokenizer_vocab": "tokenizer/vocab.json",
        "tokenizer_merges": "tokenizer/merges.txt",
    }
}


def _from_numpy(x):
    return mx.array(np.ascontiguousarray(x))


def map_unet_weights(key, value):
    # Map up/downsampling
    if "downsamplers" in key:
        key = key.replace("downsamplers.0.conv", "downsample")
    if "upsamplers" in key:
        key = key.replace("upsamplers.0.conv", "upsample")

    # Map the mid block
    if "mid_block.resnets.0" in key:
        key = key.replace("mid_block.resnets.0", "mid_blocks.0")
    if "mid_block.attentions.0" in key:
        key = key.replace("mid_block.attentions.0", "mid_blocks.1")
    if "mid_block.resnets.1" in key:
        key = key.replace("mid_block.resnets.1", "mid_blocks.2")

    # Map attention layers
    if "to_k" in key:
        key = key.replace("to_k", "key_proj")
    if "to_out.0" in key:
        key = key.replace("to_out.0", "out_proj")
    if "to_q" in key:
        key = key.replace("to_q", "query_proj")
    if "to_v" in key:
        key = key.replace("to_v", "value_proj")

    # Map transformer ffn
    if "ff.net.2" in key:
        key = key.replace("ff.net.2", "linear3")
    if "ff.net.0" in key:
        k1 = key.replace("ff.net.0.proj", "linear1")
        k2 = key.replace("ff.net.0.proj", "linear2")
        v1, v2 = np.split(value, 2)

        return [(k1, _from_numpy(v1)), (k2, _from_numpy(v2))]

    if "conv_shortcut.weight" in key:
        value = value.squeeze()

    if len(value.shape) == 4:
        value = value.transpose(0, 2, 3, 1)

    return [(key, _from_numpy(value))]


def map_clip_text_encoder_weights(key, value):
    # Remove prefixes
    if key.startswith("text_model."):
        key = key[11:]
    if key.startswith("embeddings."):
        key = key[11:]
    if key.startswith("encoder."):
        key = key[8:]

    # Map attention layers
    if "self_attn." in key:
        key = key.replace("self_attn.", "attention.")
    if "q_proj." in key:
        key = key.replace("q_proj.", "query_proj.")
    if "k_proj." in key:
        key = key.replace("k_proj.", "key_proj.")
    if "v_proj." in key:
        key = key.replace("v_proj.", "value_proj.")

    # Map ffn layers
    if "mlp.fc1" in key:
        key = key.replace("mlp.fc1", "linear1")
    if "mlp.fc2" in key:
        key = key.replace("mlp.fc2", "linear2")

    return [(key, _from_numpy(value))]


def map_vae_weights(key, value):
    # Map up/downsampling
    if "downsamplers" in key:
        key = key.replace("downsamplers.0.conv", "downsample")
    if "upsamplers" in key:
        key = key.replace("upsamplers.0.conv", "upsample")

    # Map attention layers
    if "to_k" in key:
        key = key.replace("to_k", "key_proj")
    if "to_out.0" in key:
        key = key.replace("to_out.0", "out_proj")
    if "to_q" in key:
        key = key.replace("to_q", "query_proj")
    if "to_v" in key:
        key = key.replace("to_v", "value_proj")

    # Map the mid block
    if "mid_block.resnets.0" in key:
        key = key.replace("mid_block.resnets.0", "mid_blocks.0")
    if "mid_block.attentions.0" in key:
        key = key.replace("mid_block.attentions.0", "mid_blocks.1")
    if "mid_block.resnets.1" in key:
        key = key.replace("mid_block.resnets.1", "mid_blocks.2")

    # Map the quant/post_quant layers
    if "quant_conv" in key:
        key = key.replace("quant_conv", "quant_proj")
        value = value.squeeze()

    # Map the conv_shortcut to linear
    if "conv_shortcut.weight" in key:
        value = value.squeeze()

    if len(value.shape) == 4:
        value = value.transpose(0, 2, 3, 1)

    return [(key, _from_numpy(value))]


def _flatten(params):
    return [(k, v) for p in params for (k, v) in p]


def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
    dtype = np.float16 if float16 else np.float32
    with safetensor_open(weight_file, framework="numpy") as f:
        weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
    model.update(tree_unflatten(weights))


def _check_key(key: str, part: str):
    if key not in _MODELS:
        raise ValueError(
            f"[{part}] '{key}' model not found, choose one of {{{','.join(_MODELS.keys())}}}"
        )


def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
    """Load the stable diffusion UNet from Huggingface Hub."""
    _check_key(key, "load_unet")

    # Download the config and create the model
    unet_config = _MODELS[key]["unet_config"]
    with open(hf_hub_download(key, unet_config)) as f:
        config = json.load(f)

    n_blocks = len(config["block_out_channels"])
    model = UNetModel(
        UNetConfig(
            in_channels=config["in_channels"],
            out_channels=config["out_channels"],
            block_out_channels=config["block_out_channels"],
            layers_per_block=[config["layers_per_block"]] * n_blocks,
            num_attention_heads=config["attention_head_dim"],
            cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
            norm_num_groups=config["norm_num_groups"],
        )
    )

    # Download the weights and map them into the model
    unet_weights = _MODELS[key]["unet"]
    weight_file = hf_hub_download(key, unet_weights)
    _load_safetensor_weights(map_unet_weights, model, weight_file, float16)

    return model


def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
    """Load the stable diffusion text encoder from Huggingface Hub."""
    _check_key(key, "load_text_encoder")

    # Download the config and create the model
    text_encoder_config = _MODELS[key]["text_encoder_config"]
    with open(hf_hub_download(key, text_encoder_config)) as f:
        config = json.load(f)

    model = CLIPTextModel(
        CLIPTextModelConfig(
            num_layers=config["num_hidden_layers"],
            model_dims=config["hidden_size"],
            num_heads=config["num_attention_heads"],
            max_length=config["max_position_embeddings"],
            vocab_size=config["vocab_size"],
        )
    )

    # Download the weights and map them into the model
    text_encoder_weights = _MODELS[key]["text_encoder"]
    weight_file = hf_hub_download(key, text_encoder_weights)
    _load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)

    return model


def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
    """Load the stable diffusion autoencoder from Huggingface Hub."""
    _check_key(key, "load_autoencoder")

    # Download the config and create the model
    vae_config = _MODELS[key]["vae_config"]
    with open(hf_hub_download(key, vae_config)) as f:
        config = json.load(f)

    model = Autoencoder(
        AutoencoderConfig(
            in_channels=config["in_channels"],
            out_channels=config["out_channels"],
            latent_channels_out=2 * config["latent_channels"],
            latent_channels_in=config["latent_channels"],
            block_out_channels=config["block_out_channels"],
            layers_per_block=config["layers_per_block"],
            norm_num_groups=config["norm_num_groups"],
        )
    )

    # Download the weights and map them into the model
    vae_weights = _MODELS[key]["vae"]
    weight_file = hf_hub_download(key, vae_weights)
    _load_safetensor_weights(map_vae_weights, model, weight_file, float16)

    return model


def load_diffusion_config(key: str = _DEFAULT_MODEL):
    """Load the stable diffusion config from Huggingface Hub."""
    _check_key(key, "load_diffusion_config")

    diffusion_config = _MODELS[key]["diffusion_config"]
    with open(hf_hub_download(key, diffusion_config)) as f:
        config = json.load(f)

    return DiffusionConfig(
        beta_start=config["beta_start"],
        beta_end=config["beta_end"],
        beta_schedule=config["beta_schedule"],
        num_train_steps=config["num_train_timesteps"],
    )


def load_tokenizer(key: str = _DEFAULT_MODEL):
    _check_key(key, "load_tokenizer")

    vocab_file = hf_hub_download(key, _MODELS[key]["tokenizer_vocab"])
    with open(vocab_file, encoding="utf-8") as f:
        vocab = json.load(f)

    merges_file = hf_hub_download(key, _MODELS[key]["tokenizer_merges"])
    with open(merges_file, encoding="utf-8") as f:
        bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
    bpe_merges = [tuple(m.split()) for m in bpe_merges]
    bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))

    return Tokenizer(bpe_ranks, vocab)
70 stable_diffusion/stable_diffusion/sampler.py Normal file
@@ -0,0 +1,70 @@
from .config import DiffusionConfig

import mlx.core as mx


def _linspace(a, b, num):
    x = mx.arange(0, num) / (num - 1)
    return (b - a) * x + a


def _interp(y, x_new):
    """Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
    x_low = x_new.astype(mx.int32)
    x_high = mx.minimum(x_low + 1, len(y) - 1)

    y_low = y[x_low]
    y_high = y[x_high]
    delta_x = x_new - x_low
    y_new = y_low * (1 - delta_x) + delta_x * y_high

    return y_new


class SimpleEulerSampler:
    """A simple Euler integrator that can be used to sample from our diffusion models.

    The method ``step()`` performs one Euler step from x_t to x_t_prev.
    """

    def __init__(self, config: DiffusionConfig):
        # Compute the noise schedule
        if config.beta_schedule == "linear":
            betas = _linspace(
                config.beta_start, config.beta_end, config.num_train_steps
            )
        elif config.beta_schedule == "scaled_linear":
            betas = _linspace(
                config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
            ).square()
        else:
            raise NotImplementedError(f"{config.beta_schedule} is not implemented.")

        alphas = 1 - betas
        alphas_cumprod = mx.cumprod(alphas)

        self._sigmas = mx.concatenate(
            [mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
        )

    def sample_prior(self, shape, dtype=mx.float32, key=None):
        noise = mx.random.normal(shape, key=key)
        return (noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()).astype(dtype)

    def sigmas(self, t):
        return _interp(self._sigmas, t)

    def timesteps(self, num_steps: int, dtype=mx.float32):
        steps = _linspace(len(self._sigmas) - 1, 0, num_steps + 1).astype(dtype)
        return list(zip(steps, steps[1:]))

    def step(self, eps_pred, x_t, t, t_prev):
        sigma = self.sigmas(t).astype(eps_pred.dtype)
        sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)

        # x_t is kept pre-scaled by 1 / sqrt(sigma^2 + 1); undo that scaling,
        # take one Euler step x <- x + eps * (sigma_prev - sigma) in the
        # unscaled space, then rescale for the new sigma.
        dt = sigma_prev - sigma
        x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt

        x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()

        return x_t_prev
95 stable_diffusion/stable_diffusion/tokenizer.py Normal file
@@ -0,0 +1,95 @@
import regex


class Tokenizer:
    """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
    def __init__(self, bpe_ranks, vocab):
        self.bpe_ranks = bpe_ranks
        self.vocab = vocab
        self.pat = regex.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            regex.IGNORECASE,
        )

        self._cache = {self.bos: self.bos, self.eos: self.eos}

    @property
    def bos(self):
        return "<|startoftext|>"

    @property
    def bos_token(self):
        return self.vocab[self.bos]

    @property
    def eos(self):
        return "<|endoftext|>"

    @property
    def eos_token(self):
        return self.vocab[self.eos]

    def bpe(self, text):
        if text in self._cache:
            return self._cache[text]

        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
        unique_bigrams = set(zip(unigrams, unigrams[1:]))

        if not unique_bigrams:
            return unigrams

        # In every iteration try to merge the two most likely bigrams. If none
        # was merged we are done.
        #
        # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
        while unique_bigrams:
            bigram = min(unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break

            new_unigrams = []
            skip = False
            for a, b in zip(unigrams, unigrams[1:]):
                if skip:
                    skip = False
                    continue

                if (a, b) == bigram:
                    new_unigrams.append(a + b)
                    skip = True

                else:
                    new_unigrams.append(a)

            if not skip:
                new_unigrams.append(b)

            unigrams = new_unigrams
            unique_bigrams = set(zip(unigrams, unigrams[1:]))

        self._cache[text] = unigrams

        return unigrams

    def tokenize(self, text, prepend_bos=True, append_eos=True):
        if isinstance(text, list):
            return [self.tokenize(t, prepend_bos, append_eos) for t in text]

        # Lower case cleanup and split according to self.pat. Huggingface does
        # a much more thorough job here but this should suffice for 95% of
        # cases.
        clean_text = regex.sub(r"\s+", " ", text.lower())
        tokens = regex.findall(self.pat, clean_text)

        # Split the tokens according to the byte-pair merge file
        bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]

        # Map to token ids and return
        tokens = [self.vocab[t] for t in bpe_tokens]
        if prepend_bos:
            tokens = [self.bos_token] + tokens
        if append_eos:
            tokens.append(self.eos_token)

        return tokens
423 stable_diffusion/stable_diffusion/unet.py Normal file
@@ -0,0 +1,423 @@
import math
from typing import Optional

import mlx.core as mx
import mlx.nn as nn

from .config import UNetConfig


def upsample_nearest(x, scale: int = 2):
    B, H, W, C = x.shape
    x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
    x = x.reshape(B, H * scale, W * scale, C)

    return x


class TimestepEmbedding(nn.Module):
    def __init__(self, in_channels: int, time_embed_dim: int):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def __call__(self, x):
        x = self.linear_1(x)
        x = nn.silu(x)
        x = self.linear_2(x)

        return x


class TransformerBlock(nn.Module):
    def __init__(
        self,
        model_dims: int,
        num_heads: int,
        hidden_dims: Optional[int] = None,
        memory_dims: Optional[int] = None,
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(model_dims)
        self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
        self.attn1.out_proj.bias = mx.zeros(model_dims)

        memory_dims = memory_dims or model_dims
        self.norm2 = nn.LayerNorm(model_dims)
        self.attn2 = nn.MultiHeadAttention(
            model_dims, num_heads, key_input_dims=memory_dims
        )
        self.attn2.out_proj.bias = mx.zeros(model_dims)

        hidden_dims = hidden_dims or 4 * model_dims
        self.norm3 = nn.LayerNorm(model_dims)
        self.linear1 = nn.Linear(model_dims, hidden_dims)
        self.linear2 = nn.Linear(model_dims, hidden_dims)
        self.linear3 = nn.Linear(hidden_dims, model_dims)

    def __call__(self, x, memory, attn_mask, memory_mask):
        # Self attention
        y = self.norm1(x)
        y = self.attn1(y, y, y, attn_mask)
        x = x + y

        # Cross attention
        y = self.norm2(x)
        y = self.attn2(y, memory, memory, memory_mask)
        x = x + y

        # FFN
        y = self.norm3(x)
        y_a = self.linear1(y)
        y_b = self.linear2(y)
        y = y_a * nn.gelu_approx(y_b)  # approximate gelu?
        y = self.linear3(y)
        x = x + y

        return x


class Transformer2D(nn.Module):
    """A transformer model for inputs with 2 spatial dimensions."""

    def __init__(
        self,
        in_channels: int,
        model_dims: int,
        encoder_dims: int,
        num_heads: int,
        num_layers: int = 1,
        norm_num_groups: int = 32,
    ):
        super().__init__()

        self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
        self.proj_in = nn.Linear(in_channels, model_dims)
        self.transformer_blocks = [
            TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
            for i in range(num_layers)
        ]
        self.proj_out = nn.Linear(model_dims, in_channels)

    def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
        # Save the input to add to the output
        input_x = x

        # Perform the input norm and projection
        B, H, W, C = x.shape
        x = self.norm(x).reshape(B, -1, C)
        x = self.proj_in(x)

        # Apply the transformer
        for block in self.transformer_blocks:
            x = block(x, encoder_x, attn_mask, encoder_attn_mask)

        # Apply the output projection and reshape
        x = self.proj_out(x)
        x = x.reshape(B, H, W, C)

        return x + input_x


class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        groups: int = 32,
        temb_channels: Optional[int] = None,
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels is not None:
            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
        self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        if in_channels != out_channels:
            self.conv_shortcut = nn.Linear(in_channels, out_channels)

    def __call__(self, x, temb=None):
        if temb is not None:
            temb = self.time_emb_proj(nn.silu(temb))

        y = self.norm1(x)
        y = nn.silu(y)
        y = self.conv1(y)
        if temb is not None:
            y = y + temb[:, None, None, :]
        y = self.norm2(y)
        y = nn.silu(y)
        y = self.conv2(y)

        x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))

        return x


class UNetBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        prev_out_channels: Optional[int] = None,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        num_attention_heads: int = 8,
        cross_attention_dim=1280,
        resnet_groups: int = 32,
        add_downsample=True,
        add_upsample=True,
        add_cross_attention=True,
    ):
        super().__init__()

        # Prepare the in channels list for the resnets
        if prev_out_channels is None:
            in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
        else:
            in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
            res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
            in_channels_list = [
                a + b for a, b in zip(in_channels_list, res_channels_list)
            ]

        # Add resnet blocks that also process the time embedding
        self.resnets = [
            ResnetBlock2D(
                in_channels=ic,
                out_channels=out_channels,
                temb_channels=temb_channels,
                groups=resnet_groups,
            )
            for ic in in_channels_list
        ]

        # Add optional cross attention layers
        if add_cross_attention:
            self.attentions = [
                Transformer2D(
                    in_channels=out_channels,
                    model_dims=out_channels,
                    num_heads=num_attention_heads,
                    num_layers=transformer_layers_per_block,
                    encoder_dims=cross_attention_dim,
                )
                for i in range(num_layers)
            ]

        # Add an optional downsampling layer
        if add_downsample:
            self.downsample = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=2, padding=1
            )

        # or upsampling layer
        if add_upsample:
            self.upsample = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=1, padding=1
            )

    def __call__(
        self,
        x,
        encoder_x=None,
        temb=None,
        attn_mask=None,
        encoder_attn_mask=None,
        residual_hidden_states=None,
    ):
        output_states = []

        for i in range(len(self.resnets)):
            if residual_hidden_states is not None:
                x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)

            x = self.resnets[i](x, temb)

            if "attentions" in self:
                x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)

            output_states.append(x)

        if "downsample" in self:
            x = self.downsample(x)
            output_states.append(x)

        if "upsample" in self:
            x = self.upsample(upsample_nearest(x))
            output_states.append(x)

        return x, output_states


class UNetModel(nn.Module):
    """The conditional 2D UNet model that actually performs the denoising."""

    def __init__(self, config: UNetConfig):
        super().__init__()

        self.conv_in = nn.Conv2d(
            config.in_channels,
            config.block_out_channels[0],
            config.conv_in_kernel,
            padding=(config.conv_in_kernel - 1) // 2,
        )

        self.timesteps = nn.SinusoidalPositionalEncoding(
            config.block_out_channels[0],
            max_freq=1,
            min_freq=math.exp(
                -math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
            ),
            scale=1.0,
            cos_first=True,
            full_turns=False,
        )
        self.time_embedding = TimestepEmbedding(
            config.block_out_channels[0],
            config.block_out_channels[0] * 4,
        )

        # Make the downsampling blocks
        block_channels = [config.block_out_channels[0]] + list(
            config.block_out_channels
        )
        self.down_blocks = [
            UNetBlock2D(
                in_channels=in_channels,
                out_channels=out_channels,
                temb_channels=config.block_out_channels[0] * 4,
                num_layers=config.layers_per_block[i],
                transformer_layers_per_block=config.transformer_layers_per_block[i],
                num_attention_heads=config.num_attention_heads[i],
                cross_attention_dim=config.cross_attention_dim[i],
                resnet_groups=config.norm_num_groups,
                add_downsample=(i < len(config.block_out_channels) - 1),
                add_upsample=False,
                add_cross_attention=(i < len(config.block_out_channels) - 1),
            )
            for i, (in_channels, out_channels) in enumerate(
                zip(block_channels, block_channels[1:])
            )
        ]

        # Make the middle block
        self.mid_blocks = [
            ResnetBlock2D(
                in_channels=config.block_out_channels[-1],
                out_channels=config.block_out_channels[-1],
                temb_channels=config.block_out_channels[0] * 4,
                groups=config.norm_num_groups,
            ),
            Transformer2D(
                in_channels=config.block_out_channels[-1],
                model_dims=config.block_out_channels[-1],
                num_heads=config.num_attention_heads[-1],
                num_layers=config.transformer_layers_per_block[-1],
                encoder_dims=config.cross_attention_dim[-1],
            ),
            ResnetBlock2D(
                in_channels=config.block_out_channels[-1],
                out_channels=config.block_out_channels[-1],
                temb_channels=config.block_out_channels[0] * 4,
                groups=config.norm_num_groups,
            ),
        ]

        # Make the upsampling blocks
        block_channels = (
            [config.block_out_channels[0]]
            + list(config.block_out_channels)
            + [config.block_out_channels[-1]]
        )
        self.up_blocks = [
            UNetBlock2D(
                in_channels=in_channels,
                out_channels=out_channels,
                temb_channels=config.block_out_channels[0] * 4,
                prev_out_channels=prev_out_channels,
                num_layers=config.layers_per_block[i] + 1,
                transformer_layers_per_block=config.transformer_layers_per_block[i],
                num_attention_heads=config.num_attention_heads[i],
                cross_attention_dim=config.cross_attention_dim[i],
                resnet_groups=config.norm_num_groups,
                add_downsample=False,
                add_upsample=(i > 0),
                add_cross_attention=(i < len(config.block_out_channels) - 1),
            )
            for i, (in_channels, out_channels, prev_out_channels) in reversed(
                list(
                    enumerate(
                        zip(block_channels, block_channels[1:], block_channels[2:])
                    )
                )
            )
        ]

        self.conv_norm_out = nn.GroupNorm(
            config.norm_num_groups,
            config.block_out_channels[0],
            pytorch_compatible=True,
        )
        self.conv_out = nn.Conv2d(
            config.block_out_channels[0],
            config.out_channels,
            config.conv_out_kernel,
            padding=(config.conv_out_kernel - 1) // 2,
        )

    def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):

        # Compute the time embeddings
        temb = self.timesteps(timestep).astype(x.dtype)
        temb = self.time_embedding(temb)

        # Preprocess the input
        x = self.conv_in(x)

        # Run the downsampling part of the unet
        residuals = [x]
        for block in self.down_blocks:
            x, res = block(
                x,
                encoder_x=encoder_x,
                temb=temb,
                attn_mask=attn_mask,
                encoder_attn_mask=encoder_attn_mask,
            )
            residuals.extend(res)

        # Run the middle part of the unet
        x = self.mid_blocks[0](x, temb)
        x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
        x = self.mid_blocks[2](x, temb)

        # Run the upsampling part of the unet
        for block in self.up_blocks:
            x, _ = block(
                x,
                encoder_x=encoder_x,
                temb=temb,
                attn_mask=attn_mask,
                encoder_attn_mask=encoder_attn_mask,
                residual_hidden_states=residuals,
            )

        # Postprocess the output
        x = self.conv_norm_out(x)
        x = nn.silu(x)
        x = self.conv_out(x)

        return x
266
stable_diffusion/stable_diffusion/vae.py
Normal file
266
stable_diffusion/stable_diffusion/vae.py
Normal file
@ -0,0 +1,266 @@
import math
from typing import List

import mlx.core as mx
import mlx.nn as nn

from .config import AutoencoderConfig
from .unet import ResnetBlock2D, upsample_nearest


class Attention(nn.Module):
    """A single head unmasked attention for use with the VAE."""

    def __init__(self, dims: int, norm_groups: int = 32):
        super().__init__()

        self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
        self.query_proj = nn.Linear(dims, dims)
        self.key_proj = nn.Linear(dims, dims)
        self.value_proj = nn.Linear(dims, dims)
        self.out_proj = nn.Linear(dims, dims)

    def __call__(self, x):
        B, H, W, C = x.shape

        y = self.group_norm(x)

        queries = self.query_proj(y).reshape(B, H * W, C)
        keys = self.key_proj(y).reshape(B, H * W, C)
        values = self.value_proj(y).reshape(B, H * W, C)

        scale = 1 / math.sqrt(queries.shape[-1])
        scores = (queries * scale) @ keys.transpose(0, 2, 1)
        attn = mx.softmax(scores, axis=-1)
        y = (attn @ values).reshape(B, H, W, C)

        y = self.out_proj(y)
        x = x + y

        return x


class EncoderDecoderBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        resnet_groups: int = 32,
        add_downsample=True,
        add_upsample=True,
    ):
        super().__init__()

        # Add the resnet blocks
        self.resnets = [
            ResnetBlock2D(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                groups=resnet_groups,
            )
            for i in range(num_layers)
        ]

        # Add an optional downsampling layer
        if add_downsample:
            self.downsample = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=2, padding=1
            )

        # or upsampling layer
        if add_upsample:
            self.upsample = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=1, padding=1
            )

    def __call__(self, x):
        for resnet in self.resnets:
            x = resnet(x)

        if "downsample" in self:
            x = self.downsample(x)

        if "upsample" in self:
            x = self.upsample(upsample_nearest(x))

        return x


class Encoder(nn.Module):
    """Implements the encoder side of the Autoencoder."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        block_out_channels: List[int] = [64],
        layers_per_block: int = 2,
        resnet_groups: int = 32,
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
        )

        channels = [block_out_channels[0]] + list(block_out_channels)
        self.down_blocks = [
            EncoderDecoderBlock2D(
                in_channels,
                out_channels,
                num_layers=layers_per_block,
                resnet_groups=resnet_groups,
                add_downsample=i < len(block_out_channels) - 1,
                add_upsample=False,
            )
            for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
        ]

        self.mid_blocks = [
            ResnetBlock2D(
                in_channels=block_out_channels[-1],
                out_channels=block_out_channels[-1],
                groups=resnet_groups,
            ),
            Attention(block_out_channels[-1], resnet_groups),
            ResnetBlock2D(
                in_channels=block_out_channels[-1],
                out_channels=block_out_channels[-1],
                groups=resnet_groups,
            ),
        ]

        self.conv_norm_out = nn.GroupNorm(
            resnet_groups, block_out_channels[-1], pytorch_compatible=True
        )
        self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1)

    def __call__(self, x):
        x = self.conv_in(x)

        for l in self.down_blocks:
            x = l(x)

        x = self.mid_blocks[0](x)
        x = self.mid_blocks[1](x)
        x = self.mid_blocks[2](x)

        x = self.conv_norm_out(x)
        x = nn.silu(x)
        x = self.conv_out(x)

        return x


class Decoder(nn.Module):
    """Implements the decoder side of the Autoencoder."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        block_out_channels: List[int] = [64],
        layers_per_block: int = 2,
        resnet_groups: int = 32,
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
        )

        self.mid_blocks = [
            ResnetBlock2D(
                in_channels=block_out_channels[-1],
                out_channels=block_out_channels[-1],
                groups=resnet_groups,
            ),
            Attention(block_out_channels[-1], resnet_groups),
            ResnetBlock2D(
                in_channels=block_out_channels[-1],
                out_channels=block_out_channels[-1],
                groups=resnet_groups,
            ),
        ]

        channels = list(reversed(block_out_channels))
        channels = [channels[0]] + channels
        self.up_blocks = [
            EncoderDecoderBlock2D(
                in_channels,
                out_channels,
                num_layers=layers_per_block,
                resnet_groups=resnet_groups,
                add_downsample=False,
                add_upsample=i < len(block_out_channels) - 1,
            )
            for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
        ]

        self.conv_norm_out = nn.GroupNorm(
            resnet_groups, block_out_channels[0], pytorch_compatible=True
        )
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

    def __call__(self, x):
        x = self.conv_in(x)

        x = self.mid_blocks[0](x)
        x = self.mid_blocks[1](x)
        x = self.mid_blocks[2](x)

        for l in self.up_blocks:
            x = l(x)

        x = self.conv_norm_out(x)
        x = nn.silu(x)
        x = self.conv_out(x)

        return x


class Autoencoder(nn.Module):
    """The autoencoder that allows us to perform diffusion in the latent space."""

    def __init__(self, config: AutoencoderConfig):
        super().__init__()

        self.latent_channels = config.latent_channels_in
        self.scaling_factor = config.scaling_factor
        self.encoder = Encoder(
            config.in_channels,
            config.latent_channels_out,
            config.block_out_channels,
            config.layers_per_block,
            resnet_groups=config.norm_num_groups,
        )
        self.decoder = Decoder(
            config.latent_channels_in,
            config.out_channels,
            config.block_out_channels,
            config.layers_per_block + 1,
            resnet_groups=config.norm_num_groups,
        )

        self.quant_proj = nn.Linear(
            config.latent_channels_out, config.latent_channels_out
        )
        self.post_quant_proj = nn.Linear(
            config.latent_channels_in, config.latent_channels_in
        )

    def decode(self, z):
        return self.decoder(self.post_quant_proj(z))

    def __call__(self, x, key=None):
        x = self.encoder(x)
        x = self.quant_proj(x)

        # Split the projected features into the posterior mean and log-variance,
        # then sample a latent with the reparameterization trick
        mean, logvar = x.split(2, axis=-1)
        std = mx.exp(0.5 * logvar)
        z = mx.random.normal(mean.shape, key=key) * std + mean

        x_hat = self.decode(z)

        return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
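A rough sketch of how this autoencoder sits around the diffusion loop. The `vae` instance, the image shape, and the use of `scaling_factor` below are illustrative assumptions; the actual wiring lives in the StableDiffusion model and sampler code.

```
import mlx.core as mx

# Assume `vae` is an Autoencoder built from an AutoencoderConfig and the
# input is an NHWC float image (values roughly in [-1, 1])
img = mx.random.normal((1, 512, 512, 3))

out = vae(img, key=mx.random.key(0))        # encode, reparameterize, decode
z = out["z"] * vae.scaling_factor           # latents as consumed by the UNet
recon = vae.decode(z / vae.scaling_factor)  # map latents back to image space
```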
55
stable_diffusion/txt2image.py
Normal file
@ -0,0 +1,55 @@
import argparse
from PIL import Image
from tqdm import tqdm

import mlx.core as mx

from stable_diffusion import StableDiffusion


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate images from a textual prompt using stable diffusion"
    )
    parser.add_argument("prompt")
    parser.add_argument("--n_images", type=int, default=4)
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--cfg", type=float, default=7.5)
    parser.add_argument("--negative_prompt", default="")
    parser.add_argument("--n_rows", type=int, default=1)
    parser.add_argument("--decoding_batch_size", type=int, default=1)
    parser.add_argument("--output", default="out.png")
    args = parser.parse_args()

    sd = StableDiffusion()

    # Generate the latent vectors using diffusion
    latents = sd.generate_latents(
        args.prompt,
        n_images=args.n_images,
        cfg_weight=args.cfg,
        num_steps=args.steps,
        negative_text=args.negative_prompt,
    )
    for x_t in tqdm(latents, total=args.steps):
        mx.simplify(x_t)
        mx.simplify(x_t)
        mx.eval(x_t)

    # Decode them into images
    decoded = []
    for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
        decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
        mx.eval(decoded[-1])

    # Arrange them on a grid
    x = mx.concatenate(decoded, axis=0)
    x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
    B, H, W, C = x.shape
    x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
    x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
    x = (x * 255).astype(mx.uint8)

    # Save them to disk
    im = Image.fromarray(x.__array__())
    im.save(args.output)
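Given the arguments defined above, a typical invocation would look something like this (the prompt and flag values are only examples):

```
python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2 --steps 50 --output out.png
```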