mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
work with tuple shape (#393)
This commit is contained in:
parent
0340113e02
commit
ec14583c2a
@ -48,8 +48,6 @@ latent_generator = sd.generate_latents("A photo of an astronaut riding a horse o
|
|||||||
# Here we are evaluating each diffusion step but we could also evaluate
|
# Here we are evaluating each diffusion step but we could also evaluate
|
||||||
# once at the end.
|
# once at the end.
|
||||||
for x_t in latent_generator:
|
for x_t in latent_generator:
|
||||||
mx.simplify(x_t) # remove possible redundant computation eg reuse
|
|
||||||
# scalars etc
|
|
||||||
mx.eval(x_t)
|
mx.eval(x_t)
|
||||||
|
|
||||||
# Now x_t is the last latent from the reverse process aka x_0. We can
|
# Now x_t is the last latent from the reverse process aka x_0. We can
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
mlx
|
mlx>=0.1
|
||||||
safetensors
|
safetensors
|
||||||
huggingface-hub
|
huggingface-hub
|
||||||
regex
|
regex
|
||||||
|
@ -16,21 +16,6 @@ from .model_io import (
|
|||||||
from .sampler import SimpleEulerSampler
|
from .sampler import SimpleEulerSampler
|
||||||
|
|
||||||
|
|
||||||
def _repeat(x, n, axis):
|
|
||||||
# Make the expanded shape
|
|
||||||
s = x.shape
|
|
||||||
s.insert(axis + 1, n)
|
|
||||||
|
|
||||||
# Expand
|
|
||||||
x = mx.broadcast_to(mx.expand_dims(x, axis + 1), s)
|
|
||||||
|
|
||||||
# Make the flattened shape
|
|
||||||
s.pop(axis + 1)
|
|
||||||
s[axis] *= n
|
|
||||||
|
|
||||||
return x.reshape(s)
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion:
|
class StableDiffusion:
|
||||||
def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
|
def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
|
||||||
self.dtype = mx.float16 if float16 else mx.float32
|
self.dtype = mx.float16 if float16 else mx.float32
|
||||||
@ -62,7 +47,7 @@ class StableDiffusion:
|
|||||||
|
|
||||||
# Repeat the conditioning for each of the generated images
|
# Repeat the conditioning for each of the generated images
|
||||||
if n_images > 1:
|
if n_images > 1:
|
||||||
conditioning = _repeat(conditioning, n_images, axis=0)
|
conditioning = mx.repeat(conditioning, n_images, axis=0)
|
||||||
|
|
||||||
return conditioning
|
return conditioning
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
mlx
|
mlx>=0.1
|
||||||
numba
|
numba
|
||||||
numpy
|
numpy
|
||||||
torch
|
torch
|
||||||
|
@ -303,7 +303,7 @@ class TestWhisper(unittest.TestCase):
|
|||||||
def check_segment(seg, expected):
|
def check_segment(seg, expected):
|
||||||
for k, v in expected.items():
|
for k, v in expected.items():
|
||||||
if isinstance(v, float):
|
if isinstance(v, float):
|
||||||
self.assertAlmostEqual(seg[k], v, places=3)
|
self.assertAlmostEqual(seg[k], v, places=2)
|
||||||
else:
|
else:
|
||||||
self.assertEqual(seg[k], v)
|
self.assertEqual(seg[k], v)
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ def detect_language(
|
|||||||
mel = mel[None]
|
mel = mel[None]
|
||||||
|
|
||||||
# skip encoder forward pass if already-encoded audio features were given
|
# skip encoder forward pass if already-encoded audio features were given
|
||||||
if mel.shape[-2:] != [model.dims.n_audio_ctx, model.dims.n_audio_state]:
|
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
|
||||||
mel = model.encoder(mel)
|
mel = model.encoder(mel)
|
||||||
|
|
||||||
# forward pass using a single token, startoftranscript
|
# forward pass using a single token, startoftranscript
|
||||||
|
Loading…
Reference in New Issue
Block a user