override dtype with quant (#1062)

This commit is contained in:
Awni Hannun 2024-10-22 09:56:45 -07:00 committed by GitHub
parent 743763bc2e
commit 66e7bcb886
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 3 additions and 3 deletions

View File

@ -31,7 +31,7 @@ def configure_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--dtype",
help="Type to save the parameters, ignored if -q is given.",
help="Type to save the non-quantized parameters.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="float16",

View File

@ -111,7 +111,7 @@ class MLP(nn.Module):
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):

View File

@ -720,7 +720,7 @@ def convert(
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
weights = dict(tree_flatten(model.parameters()))
dtype = mx.float16 if quantize else getattr(mx, dtype)
dtype = getattr(mx, dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()}
if quantize and dequantize: