mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	* export and import functions * refactor + works for few primitives * nit * allow primitives with state * nit * nit * simplify serialize / deserialize * fix for constants * python bindings * maybe fix serialize failure case * add example * more primitives, training kind of works * same result for python and c++ * some fixes * fix export * template it up * some simplification * rebase * allow kwargs and multiple functions * exporter * more primitives for exporting * deal with endianness * handle invalid stream * add docstring
		
			
				
	
	
		
			77 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			77 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright © 2024 Apple Inc.
 | 
						|
 | 
						|
import mlx.core as mx
 | 
						|
import mlx.nn as nn
 | 
						|
import mlx.optimizers as optim
 | 
						|
import mlx.utils
 | 
						|
 | 
						|
 | 
						|
class MLP(nn.Module):
    """A simple multi-layer perceptron.

    The network is a stack of ``num_layers + 1`` linear layers going from
    ``input_dim`` through ``num_layers`` hidden widths of ``hidden_dim`` to
    ``output_dim``. A ReLU follows every layer except the last one.

    Args:
        num_layers: Number of hidden widths inserted between input and output.
        input_dim: Size of the input features.
        hidden_dim: Size of every hidden layer.
        output_dim: Size of the output produced by the final linear layer.
    """

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        # Pair each size with its successor to get (in, out) for every layer.
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        """Apply the layers to ``x``, with ReLU after all but the last."""
        for layer in self.layers[:-1]:
            x = nn.relu(layer(x))
        # The final layer's output is returned without an activation.
        return self.layers[-1](x)
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Sizes shared by the model definition and the example batch below.
    batch_size = 8
    input_dim = 32
    output_dim = 10
 | 
						|
 | 
						|
    def init():
        """Build the model and optimizer and flatten their combined state.

        Returns:
            A tuple ``(model, optimizer, tree_structure, state)`` where
            ``tree_structure`` holds the keys and ``state`` the values obtained
            by flattening ``[model.parameters(), optimizer.state]``.
        """
        # Seed for the parameter initialization
        mx.random.seed(0)
        model = MLP(
            num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
        )
        optimizer = optim.SGD(learning_rate=1e-1)
        optimizer.init(model.parameters())
        # Flatten to (key, value) pairs, then split into two parallel tuples so
        # the values can be passed around as plain positional arguments.
        flat_pairs = mlx.utils.tree_flatten([model.parameters(), optimizer.state])
        tree_structure, state = zip(*flat_pairs)
        return model, optimizer, tree_structure, state
 | 
						|
 | 
						|
    # Export the model parameter initialization
    model, optimizer, tree_structure, state = init()
    mx.eval(state)
    # The exported function takes no inputs and returns only the flattened
    # state, i.e. the last element of what init() returns.
    mx.export_function("init_mlp.mlxfn", lambda: init()[-1])
 | 
						|
 | 
						|
    def loss_fn(params, X, y):
        """Return the mean cross-entropy of the model on ``(X, y)``.

        The enclosing ``model`` is updated in place with ``params`` before the
        forward pass, so the loss is computed with exactly those parameters.
        """
        model.update(params)
        logits = model(X)
        return nn.losses.cross_entropy(logits, y, reduction="mean")
 | 
						|
 | 
						|
    def step(*inputs):
        """Run one optimizer update on the flattened training state.

        Args:
            *inputs: The flattened state values (in ``tree_structure`` order),
                followed by the batch inputs ``X`` and the targets ``y``.

        Returns:
            The updated flattened state values followed by the loss.
        """
        *flat_state, X, y = inputs
        # Rebuild the [parameters, optimizer state] tree from the flat values.
        params, opt_state = mlx.utils.tree_unflatten(
            list(zip(tree_structure, flat_state))
        )
        optimizer.state = opt_state
        loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
        params = optimizer.apply_gradients(grads, params)
        # Flatten again so the outputs mirror the layout of the inputs.
        _, new_state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
        return *new_state, loss
 | 
						|
 | 
						|
    # Make some random data
    mx.random.seed(42)
    example_X = mx.random.normal(shape=(batch_size, input_dim))
    example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))

    # Export one step of SGD, using the current state and the example batch
    # as the example inputs.
    mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)

    # Import the exported step back and use it to train
    imported_step = mx.import_function("train_mlp.mlxfn")

    for it in range(100):
        # All outputs but the last are the new state; the last is the loss.
        *state, loss = imported_step(*state, example_X, example_y)
        if it % 10 == 0:
            print(f"Loss {loss.item():.6}")
 |