mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-11-01 11:38:08 +08:00 
			
		
		
		
	override dtype with quant (#1062)
This commit is contained in:
		| @@ -31,7 +31,7 @@ def configure_parser() -> argparse.ArgumentParser: | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--dtype", | ||||
|         help="Type to save the parameters, ignored if -q is given.", | ||||
|         help="Type to save the non-quantized parameters.", | ||||
|         type=str, | ||||
|         choices=["float16", "bfloat16", "float32"], | ||||
|         default="float16", | ||||
|   | ||||
| @@ -111,7 +111,7 @@ class MLP(nn.Module): | ||||
|         self.up_proj = nn.Linear(dim, hidden_dim, bias=False) | ||||
|  | ||||
|     def __call__(self, x) -> mx.array: | ||||
|         return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) | ||||
|         return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x)) | ||||
|  | ||||
|  | ||||
| class TransformerBlock(nn.Module): | ||||
|   | ||||
| @@ -720,7 +720,7 @@ def convert( | ||||
|     model, config, tokenizer = fetch_from_hub(model_path, lazy=True) | ||||
|  | ||||
|     weights = dict(tree_flatten(model.parameters())) | ||||
|     dtype = mx.float16 if quantize else getattr(mx, dtype) | ||||
|     dtype = getattr(mx, dtype) | ||||
|     weights = {k: v.astype(dtype) for k, v in weights.items()} | ||||
|  | ||||
|     if quantize and dequantize: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun