mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
reformatted
This commit is contained in:
parent
d7cab9d5f5
commit
298178d669
@ -8,6 +8,7 @@ import json
|
|||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import torch
|
import torch
|
||||||
@ -149,7 +150,7 @@ def quantize(weights, config, args):
|
|||||||
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||||
max_file_size_bytes = max_file_size_gibibyte << 30
|
max_file_size_bytes = max_file_size_gibibyte << 30
|
||||||
shards = []
|
shards = []
|
||||||
shard : Dict[str, mx.array] = {}
|
shard: Dict[str, mx.array] = {}
|
||||||
shard_size = 0
|
shard_size = 0
|
||||||
for k, v in weights.items():
|
for k, v in weights.items():
|
||||||
if shard_size + v.nbytes > max_file_size_bytes:
|
if shard_size + v.nbytes > max_file_size_bytes:
|
||||||
|
Loading…
Reference in New Issue
Block a user