mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
fix(mlx-m): lazy load hf_olmo (#424)
This commit is contained in:
parent
9b387007ab
commit
da7adae5ec
@ -7,12 +7,6 @@ import mlx.nn as nn
|
|||||||
|
|
||||||
from .base import BaseModelArgs
|
from .base import BaseModelArgs
|
||||||
|
|
||||||
try:
|
|
||||||
import hf_olmo
|
|
||||||
except ImportError:
|
|
||||||
print("To run olmo install ai2-olmo: pip install ai2-olmo")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArgs(BaseModelArgs):
|
class ModelArgs(BaseModelArgs):
|
||||||
@ -168,6 +162,11 @@ class OlmoModel(nn.Module):
|
|||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
try:
|
||||||
|
import hf_olmo
|
||||||
|
except ImportError:
|
||||||
|
print("To run olmo install ai2-olmo: pip install ai2-olmo")
|
||||||
|
exit(1)
|
||||||
self.model = OlmoModel(args)
|
self.model = OlmoModel(args)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
|
Loading…
Reference in New Issue
Block a user