mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
two minor fixes (#335)
This commit is contained in:
parent
d8680a89f9
commit
bcc9fc3581
@ -200,3 +200,10 @@ class Model(nn.Module):
|
||||
):
|
||||
out, cache = self.model(inputs, cache)
|
||||
return self.lm_head(out), cache
|
||||
|
||||
@staticmethod
|
||||
def sanitize(weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
|
@ -122,7 +122,7 @@ def generate(
|
||||
|
||||
tokens = []
|
||||
skip = 0
|
||||
REPLACEMENT_CHAR = '\ufffd'
|
||||
REPLACEMENT_CHAR = "\ufffd"
|
||||
|
||||
for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
|
||||
if token == tokenizer.eos_token_id:
|
||||
@ -136,7 +136,7 @@ def generate(
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
|
||||
tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, '')
|
||||
tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
|
||||
if verbose:
|
||||
print(tokens[skip:], flush=True)
|
||||
return tokens
|
||||
@ -174,6 +174,8 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
|
||||
weights.update(mx.load(wf))
|
||||
|
||||
model_class, model_args_class = _get_classes(config=config)
|
||||
if hasattr(model_class, "sanitize"):
|
||||
weights = model_class.sanitize(weights)
|
||||
|
||||
model_args = model_args_class.from_dict(config)
|
||||
model = model_class(model_args)
|
||||
|
@ -69,5 +69,5 @@ if __name__ == "__main__":
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
|
||||
# Save them to disc
|
||||
im = Image.fromarray(x.__array__())
|
||||
im = Image.fromarray(np.array(x))
|
||||
im.save(args.output)
|
||||
|
Loading…
Reference in New Issue
Block a user