mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 01:42:31 +08:00
only download local shard (#1240)
This commit is contained in:
@@ -627,6 +627,7 @@ def load_config(model_path: Path) -> dict:
|
||||
def load_model(
|
||||
model_path: Path,
|
||||
lazy: bool = False,
|
||||
strict: bool = True,
|
||||
model_config: dict = {},
|
||||
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
|
||||
) -> nn.Module:
|
||||
@@ -638,6 +639,8 @@ def load_model(
|
||||
lazy (bool): If False eval the model parameters to make sure they are
|
||||
loaded in memory before returning, otherwise they will be loaded
|
||||
when needed. Default: ``False``
|
||||
strict (bool): Whether or not to raise an exception if weights don't
|
||||
match. Default: ``True``
|
||||
model_config (dict, optional): Optional configuration parameters for the
|
||||
model. Defaults to an empty dictionary.
|
||||
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
||||
@@ -660,7 +663,7 @@ def load_model(
|
||||
# Try weight for back-compat
|
||||
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
|
||||
|
||||
if not weight_files:
|
||||
if not weight_files and strict:
|
||||
logging.error(f"No safetensors found in {model_path}")
|
||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||
|
||||
@@ -694,7 +697,7 @@ def load_model(
|
||||
class_predicate=class_predicate,
|
||||
)
|
||||
|
||||
model.load_weights(list(weights.items()))
|
||||
model.load_weights(list(weights.items()), strict=strict)
|
||||
|
||||
if not lazy:
|
||||
mx.eval(model.parameters())
|
||||
|
||||
Reference in New Issue
Block a user