diff --git a/stable_diffusion/README.md b/stable_diffusion/README.md index 711ec10c..cf2b7113 100644 --- a/stable_diffusion/README.md +++ b/stable_diffusion/README.md @@ -48,8 +48,6 @@ latent_generator = sd.generate_latents("A photo of an astronaut riding a horse o # Here we are evaluating each diffusion step but we could also evaluate # once at the end. for x_t in latent_generator: - mx.simplify(x_t) # remove possible redundant computation eg reuse - # scalars etc mx.eval(x_t) # Now x_t is the last latent from the reverse process aka x_0. We can diff --git a/stable_diffusion/requirements.txt b/stable_diffusion/requirements.txt index c2fa7225..aa76c437 100644 --- a/stable_diffusion/requirements.txt +++ b/stable_diffusion/requirements.txt @@ -1,4 +1,4 @@ -mlx +mlx>=0.1 safetensors huggingface-hub regex diff --git a/stable_diffusion/stable_diffusion/__init__.py b/stable_diffusion/stable_diffusion/__init__.py index f9325ae6..9a8052c5 100644 --- a/stable_diffusion/stable_diffusion/__init__.py +++ b/stable_diffusion/stable_diffusion/__init__.py @@ -16,21 +16,6 @@ from .model_io import ( from .sampler import SimpleEulerSampler -def _repeat(x, n, axis): - # Make the expanded shape - s = x.shape - s.insert(axis + 1, n) - - # Expand - x = mx.broadcast_to(mx.expand_dims(x, axis + 1), s) - - # Make the flattened shape - s.pop(axis + 1) - s[axis] *= n - - return x.reshape(s) - - class StableDiffusion: def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False): self.dtype = mx.float16 if float16 else mx.float32 @@ -62,7 +47,7 @@ class StableDiffusion: # Repeat the conditioning for each of the generated images if n_images > 1: - conditioning = _repeat(conditioning, n_images, axis=0) + conditioning = mx.repeat(conditioning, n_images, axis=0) return conditioning diff --git a/whisper/requirements.txt b/whisper/requirements.txt index e4dbf8d1..23d43200 100644 --- a/whisper/requirements.txt +++ b/whisper/requirements.txt @@ -1,4 +1,4 @@ -mlx +mlx>=0.1 numba numpy torch diff --git a/whisper/test.py b/whisper/test.py index 835fc179..9fc3f0d5 100644 --- a/whisper/test.py +++ b/whisper/test.py @@ -303,7 +303,7 @@ class TestWhisper(unittest.TestCase): def check_segment(seg, expected): for k, v in expected.items(): if isinstance(v, float): - self.assertAlmostEqual(seg[k], v, places=3) + self.assertAlmostEqual(seg[k], v, places=2) else: self.assertEqual(seg[k], v) diff --git a/whisper/whisper/decoding.py b/whisper/whisper/decoding.py index c2105972..d0598496 100644 --- a/whisper/whisper/decoding.py +++ b/whisper/whisper/decoding.py @@ -50,7 +50,7 @@ def detect_language( mel = mel[None] # skip encoder forward pass if already-encoded audio features were given - if mel.shape[-2:] != [model.dims.n_audio_ctx, model.dims.n_audio_state]: + if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): mel = model.encoder(mel) # forward pass using a single token, startoftranscript