mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 22:58:08 +08:00
work with tuple shape (#393)
This commit is contained in:
@@ -16,21 +16,6 @@ from .model_io import (
|
||||
from .sampler import SimpleEulerSampler
|
||||
|
||||
|
||||
def _repeat(x, n, axis):
|
||||
# Make the expanded shape
|
||||
s = x.shape
|
||||
s.insert(axis + 1, n)
|
||||
|
||||
# Expand
|
||||
x = mx.broadcast_to(mx.expand_dims(x, axis + 1), s)
|
||||
|
||||
# Make the flattened shape
|
||||
s.pop(axis + 1)
|
||||
s[axis] *= n
|
||||
|
||||
return x.reshape(s)
|
||||
|
||||
|
||||
class StableDiffusion:
|
||||
def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
|
||||
self.dtype = mx.float16 if float16 else mx.float32
|
||||
@@ -62,7 +47,7 @@ class StableDiffusion:
|
||||
|
||||
# Repeat the conditioning for each of the generated images
|
||||
if n_images > 1:
|
||||
conditioning = _repeat(conditioning, n_images, axis=0)
|
||||
conditioning = mx.repeat(conditioning, n_images, axis=0)
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
Reference in New Issue
Block a user