Only download local shard (#1240)

This commit is contained in:
Awni Hannun
2025-02-02 13:58:44 -08:00
committed by GitHub
parent e8afb59de4
commit 9c2ef38d4d
4 changed files with 159 additions and 65 deletions

View File

@@ -627,6 +627,7 @@ def load_config(model_path: Path) -> dict:
def load_model(
model_path: Path,
lazy: bool = False,
strict: bool = True,
model_config: dict = {},
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> nn.Module:
@@ -638,6 +639,8 @@ def load_model(
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
strict (bool): Whether or not to raise an exception if weights don't
match. Default: ``True``
model_config (dict, optional): Optional configuration parameters for the
model. Defaults to an empty dictionary.
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
@@ -660,7 +663,7 @@ def load_model(
# Try weight for back-compat
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
if not weight_files:
if not weight_files and strict:
logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}")
@@ -694,7 +697,7 @@ def load_model(
class_predicate=class_predicate,
)
model.load_weights(list(weights.items()))
model.load_weights(list(weights.items()), strict=strict)
if not lazy:
mx.eval(model.parameters())