mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Add the tokenizers
This commit is contained in:
136
flux/flux/tokenizers.py
Normal file
136
flux/flux/tokenizers.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import regex
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
||||
class CLIPTokenizer:
    """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""

    def __init__(self, bpe_ranks, vocab):
        # bpe_ranks: dict mapping a (left, right) merge pair to its priority
        #            (lower rank merges first).
        # vocab: dict mapping sub-word strings to integer token ids.
        self.bpe_ranks = bpe_ranks
        self.vocab = vocab
        self.pat = regex.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            regex.IGNORECASE,
        )

        # Special tokens bypass BPE. Cache them as single-element *lists* so
        # that bpe() always returns a list of sub-word strings; caching the
        # raw string (the previous behavior) made tokenize() iterate it
        # character by character, mis-tokenizing special tokens in the input.
        self._cache = {self.bos: [self.bos], self.eos: [self.eos]}

    @property
    def bos(self):
        """The beginning-of-sequence token string."""
        return "<|startoftext|>"

    @property
    def bos_token(self):
        """The beginning-of-sequence token id."""
        return self.vocab[self.bos]

    @property
    def eos(self):
        """The end-of-sequence token string."""
        return "<|endoftext|>"

    @property
    def eos_token(self):
        """The end-of-sequence token id."""
        return self.vocab[self.eos]

    def bpe(self, text):
        """Split ``text`` into a list of sub-word strings via byte-pair merging.

        The last character is tagged with the ``</w>`` end-of-word marker
        before merging. Results are memoized in ``self._cache``.
        """
        if text in self._cache:
            return self._cache[text]

        if not text:
            # Guard the empty string; text[-1] below would raise IndexError.
            return []

        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
        unique_bigrams = set(zip(unigrams, unigrams[1:]))

        if not unique_bigrams:
            return unigrams

        # In every iteration try to merge the two most likely bigrams. If none
        # was merged we are done.
        #
        # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
        while unique_bigrams:
            bigram = min(
                unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
            )
            if bigram not in self.bpe_ranks:
                break

            new_unigrams = []
            skip = False
            for a, b in zip(unigrams, unigrams[1:]):
                if skip:
                    # ``b`` of the previous pair was consumed by a merge.
                    skip = False
                    continue

                if (a, b) == bigram:
                    new_unigrams.append(a + b)
                    skip = True
                else:
                    new_unigrams.append(a)

            # The final element is only pending if it was not merged above.
            if not skip:
                new_unigrams.append(b)

            unigrams = new_unigrams
            unique_bigrams = set(zip(unigrams, unigrams[1:]))

        self._cache[text] = unigrams

        return unigrams

    def tokenize(self, text, prepend_bos=True, append_eos=True):
        """Tokenize ``text`` (a string or a list of strings) into token ids.

        Returns a list of ids (or a list of such lists for list input),
        optionally wrapped with the BOS/EOS ids.
        """
        if isinstance(text, list):
            return [self.tokenize(t, prepend_bos, append_eos) for t in text]

        # Lower case cleanup and split according to self.pat. Hugging Face does
        # a much more thorough job here but this should suffice for 95% of
        # cases.
        clean_text = regex.sub(r"\s+", " ", text.lower())
        tokens = regex.findall(self.pat, clean_text)

        # Split the tokens according to the byte-pair merge file
        bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]

        # Map to token ids and return. NOTE(review): unknown sub-words raise
        # KeyError here — presumably acceptable for CLIP's closed vocab.
        tokens = [self.vocab[t] for t in bpe_tokens]
        if prepend_bos:
            tokens = [self.bos_token] + tokens
        if append_eos:
            tokens.append(self.eos_token)

        return tokens
|
||||
|
||||
|
||||
class T5Tokenizer:
    """Tokenizer backed by a SentencePiece model, as used by T5."""

    def __init__(self, model_file):
        # SentencePiece performs the actual sub-word segmentation.
        self._tokenizer = SentencePieceProcessor(model_file)

    def _piece(self, token_id):
        # Map a token id back to its piece string; SentencePiece raises
        # IndexError for ids the model does not define.
        try:
            return self._tokenizer.id_to_piece(token_id)
        except IndexError:
            return None

    @property
    def bos(self):
        """The beginning-of-sequence piece string, or None if undefined."""
        return self._piece(self.bos_token)

    @property
    def bos_token(self):
        """The beginning-of-sequence token id reported by SentencePiece."""
        return self._tokenizer.bos_id()

    @property
    def eos(self):
        """The end-of-sequence piece string, or None if undefined."""
        return self._piece(self.eos_token)

    @property
    def eos_token(self):
        """The end-of-sequence token id reported by SentencePiece."""
        return self._tokenizer.eos_id()

    def tokenize(self, text, prepend_bos=True, append_eos=True):
        """Encode ``text`` to a list of token ids, optionally adding BOS/EOS.

        BOS/EOS are only added when the model actually defines them
        (i.e. their ids are positive).
        """
        ids = self._tokenizer.encode(text)

        prefix = [self.bos_token] if prepend_bos and self.bos_token > 0 else []
        suffix = [self.eos_token] if append_eos and self.eos_token > 0 else []
        return prefix + ids + suffix
|
@@ -10,6 +10,7 @@ from .autoencoder import AutoEncoder, AutoEncoderParams
|
||||
from .clip import CLIPTextModel, CLIPTextModelConfig
|
||||
from .model import Flux, FluxParams
|
||||
from .t5 import T5Config, T5Encoder
|
||||
from .tokenizers import CLIPTokenizer, T5Tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -141,6 +142,24 @@ def load_ae(name: str, hf_download: bool = True):
|
||||
return ae
|
||||
|
||||
|
||||
def load_clip(name: str):
    """Build the CLIP text encoder for ``name`` and load its pretrained weights."""
    repo_id = configs[name].repo_id

    # Fetch and parse the text-encoder configuration, then build the model.
    with open(hf_hub_download(repo_id, "text_encoder/config.json")) as f:
        model = CLIPTextModel(CLIPTextModelConfig.from_dict(json.load(f)))

    # Fetch the checkpoint, sanitize the parameter names, and load them.
    ckpt_path = hf_hub_download(repo_id, "text_encoder/model.safetensors")
    raw_weights = mx.load(ckpt_path)
    model.load_weights(list(model.sanitize(raw_weights).items()))

    return model
|
||||
|
||||
|
||||
def load_t5(name: str):
|
||||
# Load the config
|
||||
config_path = hf_hub_download(configs[name].repo_id, "text_encoder_2/config.json")
|
||||
@@ -169,19 +188,20 @@ def load_t5(name: str):
|
||||
return t5
|
||||
|
||||
|
||||
def load_clip_tokenizer(name: str):
    """Download the CLIP vocab and merge files for ``name`` and build a CLIPTokenizer."""
    repo_id = configs[name].repo_id

    # Sub-word string -> token id mapping.
    vocab_file = hf_hub_download(repo_id, "tokenizer/vocab.json")
    with open(vocab_file, encoding="utf-8") as f:
        vocab = json.load(f)

    # BPE merge list; skip the header line and the trailing special entries.
    merges_file = hf_hub_download(repo_id, "tokenizer/merges.txt")
    with open(merges_file, encoding="utf-8") as f:
        bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
    bpe_merges = [tuple(m.split()) for m in bpe_merges]
    # Earlier merges get lower ranks (higher priority).
    bpe_ranks = {merge: rank for rank, merge in enumerate(bpe_merges)}

    return CLIPTokenizer(bpe_ranks, vocab)
|
||||
|
||||
def load_t5_tokenizer(name: str):
    """Download the T5 SentencePiece model for ``name`` and wrap it in a T5Tokenizer."""
    spiece_path = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
    return T5Tokenizer(spiece_path)
|
||||
|
Reference in New Issue
Block a user