Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-11-04 05:28:11 +08:00)
			
		
		
		
	override dtype with quant (#1062)
This commit is contained in:
		@@ -31,7 +31,7 @@ def configure_parser() -> argparse.ArgumentParser:
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--dtype",
 | 
			
		||||
        help="Type to save the parameters, ignored if -q is given.",
 | 
			
		||||
        help="Type to save the non-quantized parameters.",
 | 
			
		||||
        type=str,
 | 
			
		||||
        choices=["float16", "bfloat16", "float32"],
 | 
			
		||||
        default="float16",
 | 
			
		||||
 
 | 
			
		||||
@@ -111,7 +111,7 @@ class MLP(nn.Module):
 | 
			
		||||
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
 | 
			
		||||
 | 
			
		||||
    def __call__(self, x) -> mx.array:
 | 
			
		||||
        return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
 | 
			
		||||
        return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransformerBlock(nn.Module):
 | 
			
		||||
 
 | 
			
		||||
@@ -720,7 +720,7 @@ def convert(
 | 
			
		||||
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
 | 
			
		||||
 | 
			
		||||
    weights = dict(tree_flatten(model.parameters()))
 | 
			
		||||
    dtype = mx.float16 if quantize else getattr(mx, dtype)
 | 
			
		||||
    dtype = getattr(mx, dtype)
 | 
			
		||||
    weights = {k: v.astype(dtype) for k, v in weights.items()}
 | 
			
		||||
 | 
			
		||||
    if quantize and dequantize:
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user