work with tuple shape (#393)

This commit is contained in:
Awni Hannun 2024-02-01 13:03:47 -08:00 committed by GitHub
parent 0340113e02
commit ec14583c2a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 5 additions and 22 deletions

View File

@@ -48,8 +48,6 @@ latent_generator = sd.generate_latents("A photo of an astronaut riding a horse o
# Here we are evaluating each diffusion step but we could also evaluate
# once at the end.
for x_t in latent_generator:
mx.simplify(x_t) # remove possible redundant computation eg reuse
# scalars etc
mx.eval(x_t)
# Now x_t is the last latent from the reverse process aka x_0. We can

View File

@@ -1,4 +1,4 @@
mlx
mlx>=0.1
safetensors
huggingface-hub
regex

View File

@@ -16,21 +16,6 @@ from .model_io import (
from .sampler import SimpleEulerSampler
def _repeat(x, n, axis):
# Make the expanded shape
s = x.shape
s.insert(axis + 1, n)
# Expand
x = mx.broadcast_to(mx.expand_dims(x, axis + 1), s)
# Make the flattened shape
s.pop(axis + 1)
s[axis] *= n
return x.reshape(s)
class StableDiffusion:
def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
self.dtype = mx.float16 if float16 else mx.float32
@@ -62,7 +47,7 @@ class StableDiffusion:
# Repeat the conditioning for each of the generated images
if n_images > 1:
conditioning = _repeat(conditioning, n_images, axis=0)
conditioning = mx.repeat(conditioning, n_images, axis=0)
return conditioning

View File

@@ -1,4 +1,4 @@
mlx
mlx>=0.1
numba
numpy
torch

View File

@@ -303,7 +303,7 @@ class TestWhisper(unittest.TestCase):
def check_segment(seg, expected):
for k, v in expected.items():
if isinstance(v, float):
self.assertAlmostEqual(seg[k], v, places=3)
self.assertAlmostEqual(seg[k], v, places=2)
else:
self.assertEqual(seg[k], v)

View File

@@ -50,7 +50,7 @@ def detect_language(
mel = mel[None]
# skip encoder forward pass if already-encoded audio features were given
if mel.shape[-2:] != [model.dims.n_audio_ctx, model.dims.n_audio_state]:
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
mel = model.encoder(mel)
# forward pass using a single token, startoftranscript