fix(mlx-m): lazy load hf_olmo (#424)

This commit is contained in:
Anchen 2024-02-09 04:02:43 +11:00 committed by GitHub
parent 9b387007ab
commit da7adae5ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,12 +7,6 @@ import mlx.nn as nn
from .base import BaseModelArgs
try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
exit(1)
@dataclass
class ModelArgs(BaseModelArgs):
@ -168,6 +162,11 @@ class OlmoModel(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
exit(1)
self.model = OlmoModel(args)
def __call__(