Merge branch 'ml-explore:main' into adding-support-for-mamba2

This commit is contained in:
Gökdeniz Gülmez 2024-11-21 21:06:45 +01:00 committed by GitHub
commit e4eae973e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 2 deletions

View File

@@ -61,8 +61,8 @@ def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
model_mb = model_bytes // 2**20
max_rec_mb = max_rec_size // 2**20
print(
- "[WARNING] Generating with a model that requires {model_mb} MB "
- "which is close to the maximum recommended size of {max_rec_mb} "
+ f"[WARNING] Generating with a model that requires {model_mb} MB "
+ f"which is close to the maximum recommended size of {max_rec_mb} "
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
)

View File

@@ -30,6 +30,7 @@ if __name__ == "__main__":
parser.add_argument("--preload-models", action="store_true")
parser.add_argument("--output", default="out.png")
parser.add_argument("--verbose", "-v", action="store_true")
+ parser.add_argument("--seed", type=int)
args = parser.parse_args()
# Load the models
@@ -94,6 +95,7 @@ if __name__ == "__main__":
cfg_weight=args.cfg,
num_steps=args.steps,
negative_text=args.negative_prompt,
+ seed=args.seed,
)
for x_t in tqdm(latents, total=int(args.steps * args.strength)):
mx.eval(x_t)