diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py index 629ebe99..2525b181 100644 --- a/llms/mlx_lm/models/olmo.py +++ b/llms/mlx_lm/models/olmo.py @@ -7,12 +7,6 @@ import mlx.nn as nn from .base import BaseModelArgs -try: - import hf_olmo -except ImportError: - print("To run olmo install ai2-olmo: pip install ai2-olmo") - exit(1) - @dataclass class ModelArgs(BaseModelArgs): @@ -168,6 +162,11 @@ class OlmoModel(nn.Module): class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() + try: + import hf_olmo + except ImportError: + print("To run olmo install ai2-olmo: pip install ai2-olmo") + exit(1) self.model = OlmoModel(args) def __call__(