From 624a156d7baaf983c0eabb04a71bcc1f16f19307 Mon Sep 17 00:00:00 2001 From: madroid Date: Sat, 12 Oct 2024 20:04:40 +0800 Subject: [PATCH] FLUX: add load_dataset to __ini__ file --- flux/flux/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py index 8d39d605..33bd815f 100644 --- a/flux/flux/__init__.py +++ b/flux/flux/__init__.py @@ -1,7 +1,5 @@ # Copyright © 2024 Apple Inc. -import math -import time from typing import Tuple import mlx.core as mx @@ -9,6 +7,7 @@ import mlx.nn as nn from mlx.utils import tree_unflatten from tqdm import tqdm +from .datasets import load_dataset from .lora import LoRALinear from .sampler import FluxSampler from .utils import ( @@ -187,7 +186,7 @@ class FluxPipeline: images = [] for i in tqdm(range(len(x_t)), disable=not progress): - images.append(self.decode(x_t[i : i + 1])) + images.append(self.decode(x_t[i: i + 1])) mx.eval(images[-1]) images = mx.concatenate(images, axis=0) mx.eval(images)