mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-21 02:51:13 +08:00
186 lines
5.3 KiB
Python
186 lines
5.3 KiB
Python
# Copyright © 2024 Apple Inc.
|
|
|
|
import mlx.core as mx
|
|
import regex
|
|
from sentencepiece import SentencePieceProcessor
|
|
|
|
|
|
class CLIPTokenizer:
    """Minimal BPE tokenizer for CLIP, ported from the Hugging Face
    implementation (https://github.com/huggingface/transformers/)."""

    def __init__(self, bpe_ranks, vocab, max_length=77):
        self.max_length = max_length
        self.bpe_ranks = bpe_ranks
        self.vocab = vocab
        self.pat = regex.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            regex.IGNORECASE,
        )

        # The two special tokens map to themselves and never go through BPE.
        self._cache = {self.bos: self.bos, self.eos: self.eos}

    @property
    def bos(self):
        # Beginning-of-sequence marker as a string.
        return "<|startoftext|>"

    @property
    def bos_token(self):
        # Vocabulary id of the BOS marker.
        return self.vocab[self.bos]

    @property
    def eos(self):
        # End-of-sequence marker as a string.
        return "<|endoftext|>"

    @property
    def eos_token(self):
        # Vocabulary id of the EOS marker.
        return self.vocab[self.eos]

    def bpe(self, text):
        """Split a single word into its byte-pair-encoded sub-word pieces."""
        if text in self._cache:
            return self._cache[text]

        # Start from single characters, marking the last one as end-of-word.
        pieces = list(text[:-1]) + [text[-1] + "</w>"]
        pairs = set(zip(pieces, pieces[1:]))

        if not pairs:
            return pieces

        # Greedily merge the best-ranked adjacent pair until no pair has a
        # known merge rank. Ported from
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
        while pairs:
            best = min(pairs, key=lambda p: self.bpe_ranks.get(p, float("inf")))
            if best not in self.bpe_ranks:
                break

            merged = []
            i = 0
            while i < len(pieces):
                if i + 1 < len(pieces) and (pieces[i], pieces[i + 1]) == best:
                    # Fuse this occurrence of the pair and jump past both.
                    merged.append(pieces[i] + pieces[i + 1])
                    i += 2
                else:
                    merged.append(pieces[i])
                    i += 1

            pieces = merged
            pairs = set(zip(pieces, pieces[1:]))

        self._cache[text] = pieces

        return pieces

    def tokenize(self, text, prepend_bos=True, append_eos=True):
        """Map a string (or list of strings) to a list of token ids."""
        if isinstance(text, list):
            return [self.tokenize(t, prepend_bos, append_eos) for t in text]

        # Lower case cleanup and split according to self.pat. Hugging Face
        # does a much more thorough job here but this should suffice for 95%
        # of cases.
        normalized = regex.sub(r"\s+", " ", text.lower())
        words = regex.findall(self.pat, normalized)

        # BPE-split every word and look each piece up in the vocabulary.
        ids = []
        for word in words:
            for piece in self.bpe(word):
                ids.append(self.vocab[piece])

        if prepend_bos:
            ids.insert(0, self.bos_token)
        if append_eos:
            ids.append(self.eos_token)

        # Truncate, keeping the sequence EOS-terminated when requested.
        if len(ids) > self.max_length:
            ids = ids[: self.max_length]
            if append_eos:
                ids[-1] = self.eos_token

        return ids

    def encode(self, text):
        """Tokenize and pack a batch into an mx.array, padding with EOS."""
        if not isinstance(text, list):
            return self.encode([text])

        batch = self.tokenize(text)
        width = max(len(row) for row in batch)
        padded = [row + [self.eos_token] * (width - len(row)) for row in batch]

        return mx.array(padded)
|
|
|
|
|
|
class T5Tokenizer:
    """Tokenizer backed by a SentencePiece model, as used by T5."""

    def __init__(self, model_file, max_length=512):
        self._tokenizer = SentencePieceProcessor(model_file)
        self.max_length = max_length

    def _piece(self, token_id):
        # SentencePiece raises IndexError for ids the model does not define
        # (e.g. a missing pad/bos token); report those as None instead.
        try:
            return self._tokenizer.id_to_piece(token_id)
        except IndexError:
            return None

    @property
    def pad(self):
        # Pad token as a string, or None when the model defines none.
        return self._piece(self.pad_token)

    @property
    def pad_token(self):
        # Pad token id (negative when undefined).
        return self._tokenizer.pad_id()

    @property
    def bos(self):
        # BOS token as a string, or None when the model defines none.
        return self._piece(self.bos_token)

    @property
    def bos_token(self):
        # BOS token id (negative when undefined).
        return self._tokenizer.bos_id()

    @property
    def eos(self):
        # EOS token as a string, or None when the model defines none.
        return self._piece(self.eos_token)

    @property
    def eos_token(self):
        # EOS token id (negative when undefined).
        return self._tokenizer.eos_id()

    def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
        """Map a string (or list of strings) to a list of token ids,
        optionally adding BOS/EOS and padding up to max_length."""
        if isinstance(text, list):
            return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]

        ids = self._tokenizer.encode(text)

        # Negative special-token ids mean the model does not define them,
        # in which case the corresponding step is skipped.
        if prepend_bos and self.bos_token >= 0:
            ids = [self.bos_token] + ids
        if append_eos and self.eos_token >= 0:
            ids = ids + [self.eos_token]
        if pad and self.pad_token >= 0:
            ids = ids + [self.pad_token] * max(0, self.max_length - len(ids))

        return ids

    def encode(self, text, pad=True):
        """Tokenize and pack a batch into an mx.array of equal-length rows."""
        if not isinstance(text, list):
            return self.encode([text], pad=pad)

        # Fall back to 0 as filler when the model defines no pad token.
        fill = self.pad_token if self.pad_token >= 0 else 0
        batch = self.tokenize(text, pad=pad)
        width = max(len(row) for row in batch)
        rows = [row + [fill] * (width - len(row)) for row in batch]

        return mx.array(rows)
|