two minor fixes (#335)

This commit is contained in:
Awni Hannun 2024-01-18 14:18:13 -08:00 committed by GitHub
parent d8680a89f9
commit bcc9fc3581
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 12 additions and 3 deletions

View File

@ -200,3 +200,10 @@ class Model(nn.Module):
): ):
out, cache = self.model(inputs, cache) out, cache = self.model(inputs, cache)
return self.lm_head(out), cache return self.lm_head(out), cache
@staticmethod
def sanitize(weights):
    """Return a new dict of ``weights`` without the rotary ``inv_freq`` entries.

    Every key that contains ``"self_attn.rotary_emb.inv_freq"`` is left out;
    all remaining key/value pairs are carried over unchanged, in their
    original order. The input mapping itself is not modified.
    """
    # The precomputed rotary frequencies are unused, so filter them out.
    unused_marker = "self_attn.rotary_emb.inv_freq"
    kept = {}
    for name, value in weights.items():
        if unused_marker in name:
            continue
        kept[name] = value
    return kept

View File

@ -122,7 +122,7 @@ def generate(
tokens = [] tokens = []
skip = 0 skip = 0
REPLACEMENT_CHAR = '\ufffd' REPLACEMENT_CHAR = "\ufffd"
for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)): for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
if token == tokenizer.eos_token_id: if token == tokenizer.eos_token_id:
@ -136,7 +136,7 @@ def generate(
print(s[skip:], end="", flush=True) print(s[skip:], end="", flush=True)
skip = len(s) skip = len(s)
tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, '') tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
if verbose: if verbose:
print(tokens[skip:], flush=True) print(tokens[skip:], flush=True)
return tokens return tokens
@ -174,6 +174,8 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
weights.update(mx.load(wf)) weights.update(mx.load(wf))
model_class, model_args_class = _get_classes(config=config) model_class, model_args_class = _get_classes(config=config)
if hasattr(model_class, "sanitize"):
weights = model_class.sanitize(weights)
model_args = model_args_class.from_dict(config) model_args = model_args_class.from_dict(config)
model = model_class(model_args) model = model_class(model_args)

View File

@ -69,5 +69,5 @@ if __name__ == "__main__":
x = (x * 255).astype(mx.uint8) x = (x * 255).astype(mx.uint8)
# Save them to disc # Save them to disc
im = Image.fromarray(x.__array__()) im = Image.fromarray(np.array(x))
im.save(args.output) im.save(args.output)