FLUX: add load_dataset to __ini__ file

This commit is contained in:
madroid 2024-10-12 20:04:40 +08:00
parent b22611d2a9
commit 624a156d7b

View File

@ -1,7 +1,5 @@
# Copyright © 2024 Apple Inc.
import math
import time
from typing import Tuple
import mlx.core as mx
@ -9,6 +7,7 @@ import mlx.nn as nn
from mlx.utils import tree_unflatten
from tqdm import tqdm
from .datasets import load_dataset
from .lora import LoRALinear
from .sampler import FluxSampler
from .utils import (
@ -187,7 +186,7 @@ class FluxPipeline:
images = []
for i in tqdm(range(len(x_t)), disable=not progress):
images.append(self.decode(x_t[i : i + 1]))
images.append(self.decode(x_t[i: i + 1]))
mx.eval(images[-1])
images = mx.concatenate(images, axis=0)
mx.eval(images)