fix(mlx-m): lazy load hf_olmo (#424)

This commit is contained in:
Anchen
2024-02-09 04:02:43 +11:00
committed by GitHub
parent 9b387007ab
commit da7adae5ec

View File

@@ -7,12 +7,6 @@ import mlx.nn as nn
from .base import BaseModelArgs
try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
exit(1)
@dataclass
class ModelArgs(BaseModelArgs):
@@ -168,6 +162,11 @@ class OlmoModel(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
exit(1)
self.model = OlmoModel(args)
def __call__(