Compare commits

...

40 Commits

Author SHA1 Message Date
Huss
ed67d295f2
Merge 5e8ab40d7d into 4b2a0df237 2025-06-20 16:28:13 +12:00
Huss
5e8ab40d7d
Merge branch 'ml-explore:main' into main 2025-06-19 15:45:31 +03:00
Shubbair
a5752be9d9 Code Arrangement 2024-08-01 15:41:21 +03:00
Shubbair
f84b231cf2 Code Arrangement 2024-08-01 15:29:43 +03:00
Shubbair
7e0bdacef3 Code Arrangement 2024-08-01 15:22:19 +03:00
Shubbair
37bbf3ec54 Updating GAN Code... 2024-08-01 01:04:14 +03:00
Shubbair
4d17f80efb Updating GAN Code... 2024-07-31 20:23:57 +03:00
Shubbair
1ef3ad2c6c Updating GAN Code... 2024-07-31 19:59:36 +03:00
Shubbair
a8ffa9cb18 Updating GAN Code... 2024-07-31 11:50:32 +03:00
Shubbair
f70cef9567 Updating GAN Code... 2024-07-31 11:25:39 +03:00
Shubbair
6f7a6609b9 Updating MLX Notebook 2024-07-30 20:01:14 +03:00
Shubbair
0644cc101b Updating MLX Notebook 2024-07-30 19:50:02 +03:00
Shubbair
ad2b6643c0 Updating GAN Code... 2024-07-30 16:59:35 +03:00
Shubbair
3bea855bd2 Updating GAN Code... 2024-07-30 13:45:09 +03:00
Shubbair
c2d731d8a3 Updating GAN Code... 2024-07-30 13:24:53 +03:00
Shubbair
ba52447385 Updating GAN Code... 2024-07-30 13:21:38 +03:00
Shubbair
1e386b5c20 Updating GAN Code... 2024-07-30 02:56:13 +03:00
Shubbair
7438b54ecd Updating GAN Code... 2024-07-30 02:44:41 +03:00
Shubbair
7fea34d65e Updating GAN Code... 2024-07-30 02:37:09 +03:00
Shubbair
f505fe6e55 Updating GAN Code... 2024-07-30 02:17:12 +03:00
Shubbair
4e80759b39 Updating GAN Code... 2024-07-30 02:06:52 +03:00
Shubbair
306e53c402 Updating GAN Code... 2024-07-29 19:44:16 +03:00
Shubbair
bacaa9ec0e Updating GAN Code... 2024-07-29 01:30:08 +03:00
Shubbair
8d27be1442 Updating GAN Code... 2024-07-29 01:24:50 +03:00
Shubbair
4de0583b49 Updating GAN Code... 2024-07-28 19:18:35 +03:00
Shubbair
a07ef6d03b Updating GAN Code... 2024-07-28 18:11:39 +03:00
Shubbair
c0c8293842 Updating GAN Code... 2024-07-28 17:56:26 +03:00
Shubbair
d17d293df9 Updating GAN Code... 2024-07-28 17:35:36 +03:00
Shubbair
3e63cd93fe Updating GAN Code... 2024-07-28 17:26:24 +03:00
Shubbair
3716501e8d Updating GAN Code... 2024-07-28 17:22:40 +03:00
Shubbair
88a20b7276 Updating GAN Code... 2024-07-28 01:10:19 +03:00
Shubbair
8b1713737a Updating GAN Code... 2024-07-27 01:20:00 +03:00
Shubbair
f8b7094fb8 Updating GAN Code... 2024-07-27 01:19:50 +03:00
Shubbair
147cb3d2bc Updating GAN Code... 2024-07-27 01:09:51 +03:00
Shubbair
a05608c34d Updating GAN Code... 2024-07-27 00:22:29 +03:00
Shubbair
f176cce74d Updating GAN Code... 2024-07-27 00:19:08 +03:00
Shubbair
959c623908 Updating GAN Code... 2024-07-26 16:38:55 +03:00
Shubbair
591074bea8 Updating GAN Code... 2024-07-26 16:36:29 +03:00
Shubbair
d426586b03 Updating GAN Code... 2024-07-26 16:07:40 +03:00
Shubbair
5e7ce1048c Add GAN model 25/7 2024-07-25 21:00:41 +03:00
8 changed files with 914 additions and 0 deletions

BIN
gan/gen_images/img_0.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 208 KiB

BIN
gan/gen_images/img_100.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 142 KiB

BIN
gan/gen_images/img_200.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

BIN
gan/gen_images/img_300.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

BIN
gan/gen_images/img_400.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 79 KiB

195
gan/main.py Normal file
View File

@ -0,0 +1,195 @@
import mnist
import argparse
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
# Generator Block
def GenBlock(in_dim:int,out_dim:int):
return nn.Sequential(
nn.Linear(in_dim,out_dim),
nn.BatchNorm(out_dim, 0.8),
nn.LeakyReLU(0.2)
)
# Generator Model
class Generator(nn.Module):
def __init__(self, z_dim:int = 32, im_dim:int = 784, hidden_dim: int = 256):
super(Generator, self).__init__()
self.gen = nn.Sequential(
GenBlock(z_dim, hidden_dim),
GenBlock(hidden_dim, hidden_dim * 2),
GenBlock(hidden_dim * 2, hidden_dim * 4),
nn.Linear(hidden_dim * 4,im_dim),
)
def __call__(self, noise):
x = self.gen(noise)
return mx.tanh(x)
# make 2D noise with shape n_samples x z_dim
def get_noise(n_samples:list[int], z_dim:int)->list[int]:
return mx.random.normal(shape=(n_samples, z_dim))
#---------------------------------------------#
# Discriminator Block
def DisBlock(in_dim:int,out_dim:int):
return nn.Sequential(
nn.Linear(in_dim,out_dim),
nn.LeakyReLU(negative_slope=0.2),
nn.Dropout(0.3),
)
# Discriminator Model
class Discriminator(nn.Module):
def __init__(self,im_dim:int = 784, hidden_dim:int = 256):
super(Discriminator, self).__init__()
self.disc = nn.Sequential(
DisBlock(im_dim, hidden_dim * 4),
DisBlock(hidden_dim * 4, hidden_dim * 2),
DisBlock(hidden_dim * 2, hidden_dim),
nn.Linear(hidden_dim,1),
nn.Sigmoid()
)
def __call__(self, noise):
return self.disc(noise)
# Discriminator Loss
def disc_loss(gen, disc, real, num_images, z_dim):
noise = mx.array(get_noise(num_images, z_dim))
fake_images = gen(noise)
fake_disc = disc(fake_images)
fake_labels = mx.zeros((fake_images.shape[0],1))
fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True))
real_disc = mx.array(disc(real))
real_labels = mx.ones((real.shape[0],1))
real_loss = mx.mean(nn.losses.binary_cross_entropy(real_disc,real_labels,with_logits=True))
disc_loss = (fake_loss + real_loss) / 2.0
return disc_loss
# Genearator Loss
def gen_loss(gen, disc, num_images, z_dim):
noise = mx.array(get_noise(num_images, z_dim))
fake_images = gen(noise)
fake_disc = mx.array(disc(fake_images))
fake_labels = mx.ones((fake_images.shape[0],1))
gen_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True)
return mx.mean(gen_loss)
# make batch of images
def batch_iterate(batch_size: int, ipt: list[int])-> list[int]:
perm = np.random.permutation(len(ipt))
for s in range(0, len(ipt), batch_size):
ids = perm[s : s + batch_size]
yield ipt[ids]
# plot batch of images at epoch steps
def show_images(epoch_num:int,imgs:list[int],num_imgs:int = 25):
if (imgs.shape[0] > 0):
fig,axes = plt.subplots(5, 5, figsize=(5, 5))
for i, ax in enumerate(axes.flat):
img = mx.array(imgs[i]).reshape(28,28)
ax.imshow(img,cmap='gray')
ax.axis('off')
plt.tight_layout()
plt.savefig('gen_images/img_{}.png'.format(epoch_num))
plt.show()
def main(args:dict):
seed = 42
n_epochs = 500
z_dim = 128
batch_size = 128
lr = 2e-5
mx.random.seed(seed)
# Load the data
train_images,*_ = map(np.array, getattr(mnist,'mnist')())
# Normalization images => [-1,1]
train_images = train_images * 2.0 - 1.0
gen = Generator(z_dim)
mx.eval(gen.parameters())
gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])
disc = Discriminator()
mx.eval(disc.parameters())
disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])
# TODO training...
D_loss_grad = nn.value_and_grad(disc, disc_loss)
G_loss_grad = nn.value_and_grad(gen, gen_loss)
for epoch in tqdm(range(n_epochs)):
for idx,real in enumerate(batch_iterate(batch_size, train_images)):
# TODO Train Discriminator
D_loss,D_grads = D_loss_grad(gen, disc,mx.array(real), batch_size, z_dim)
# Update optimizer
disc_opt.update(disc, D_grads)
# Update gradients
mx.eval(disc.parameters(), disc_opt.state)
# TODO Train Generator
G_loss,G_grads = G_loss_grad(gen, disc, batch_size, z_dim)
# Update optimizer
gen_opt.update(gen, G_grads)
# Update gradients
mx.eval(gen.parameters(), gen_opt.state)
if epoch%100==0:
print("Epoch: {}, iteration: {}, Discriminator Loss:{}, Generator Loss: {}".format(epoch,idx,D_loss,G_loss))
fake_noise = mx.array(get_noise(batch_size, z_dim))
fake = gen(fake_noise)
show_images(epoch,fake)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Train a simple GAN on MNIST with MLX.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
parser.add_argument(
"--dataset",
type=str,
default="mnist",
choices=["mnist", "fashion_mnist"],
help="The dataset to use.",
)
args = parser.parse_args()
if not args.gpu:
mx.set_default_device(mx.cpu)
main(args)

83
gan/mnist.py Normal file
View File

@ -0,0 +1,83 @@
# Copyright © 2023 Apple Inc.
import gzip
import os
import pickle
from urllib import request
import numpy as np
def mnist(
save_dir="/tmp",
base_url="https://raw.githubusercontent.com/fgnt/mnist/master/",
filename="mnist.pkl",
):
"""
Load the MNIST dataset in 4 tensors: train images, train labels,
test images, and test labels.
Checks `save_dir` for already downloaded data otherwise downloads.
Download code modified from:
https://github.com/hsjeong5/MNIST-for-Numpy
"""
def download_and_save(save_file):
filename = [
["training_images", "train-images-idx3-ubyte.gz"],
["test_images", "t10k-images-idx3-ubyte.gz"],
["training_labels", "train-labels-idx1-ubyte.gz"],
["test_labels", "t10k-labels-idx1-ubyte.gz"],
]
mnist = {}
for name in filename:
out_file = os.path.join("/tmp", name[1])
request.urlretrieve(base_url + name[1], out_file)
for name in filename[:2]:
out_file = os.path.join("/tmp", name[1])
with gzip.open(out_file, "rb") as f:
mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
-1, 28 * 28
)
for name in filename[-2:]:
out_file = os.path.join("/tmp", name[1])
with gzip.open(out_file, "rb") as f:
mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
with open(save_file, "wb") as f:
pickle.dump(mnist, f)
save_file = os.path.join(save_dir, filename)
if not os.path.exists(save_file):
download_and_save(save_file)
with open(save_file, "rb") as f:
mnist = pickle.load(f)
def preproc(x):
return x.astype(np.float32) / 255.0
mnist["training_images"] = preproc(mnist["training_images"])
mnist["test_images"] = preproc(mnist["test_images"])
return (
mnist["training_images"],
mnist["training_labels"].astype(np.uint32),
mnist["test_images"],
mnist["test_labels"].astype(np.uint32),
)
def fashion_mnist(save_dir="/tmp"):
return mnist(
save_dir,
base_url="http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
filename="fashion_mnist.pkl",
)
if __name__ == "__main__":
train_x, train_y, test_x, test_y = mnist()
assert train_x.shape == (60000, 28 * 28), "Wrong training set size"
assert train_y.shape == (60000,), "Wrong training set size"
assert test_x.shape == (10000, 28 * 28), "Wrong test set size"
assert test_y.shape == (10000,), "Wrong test set size"

636
gan/playground.ipynb Normal file

File diff suppressed because one or more lines are too long