Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-06 08:24:39 +08:00)
jagrit's commit files
python/mlx/nn/layers/convolution.py (new file, 124 lines)
@@ -0,0 +1,124 @@
import math
from typing import Union

import mlx.core as mx
from mlx.nn.layers.base import Module


class Conv1d(Module):
    """Applies a 1-dimensional convolution over the multi-channel input sequence.

    The channels are expected to be last i.e. the input shape should be ``NLC`` where:
        - ``N`` is the batch dimension
        - ``L`` is the sequence length
        - ``C`` is the number of input channels

    Args:
        in_channels (int): The number of input channels
        out_channels (int): The number of output channels
        kernel_size (int): The size of the convolution filters
        stride (int, optional): The stride when applying the filter.
            Default: 1.
        padding (int, optional): How many positions to 0-pad the input with.
            Default: 0.
        bias (bool, optional): If ``True`` add a learnable bias to the output.
            Default: ``True``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ):
        super().__init__()

        scale = math.sqrt(1 / (in_channels * kernel_size))
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(out_channels, kernel_size, in_channels),
        )
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.stride = stride

    def _extra_repr(self):
        return (
            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
            f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
            f"padding={self.padding}, bias={'bias' in self}"
        )

    def __call__(self, x):
        y = mx.conv1d(x, self.weight, self.stride, self.padding)
        if "bias" in self:
            y = y + self.bias
        return y


class Conv2d(Module):
    """Applies a 2-dimensional convolution over the multi-channel input image.

    The channels are expected to be last i.e. the input shape should be ``NHWC`` where:
        - ``N`` is the batch dimension
        - ``H`` is the input image height
        - ``W`` is the input image width
        - ``C`` is the number of input channels

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int or tuple): The size of the convolution filters.
        stride (int or tuple, optional): The size of the stride when
            applying the filter. Default: 1.
        padding (int or tuple, optional): How many positions to 0-pad
            the input with. Default: 0.
        bias (bool, optional): If ``True`` add a learnable bias to the
            output. Default: ``True``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple],
        stride: Union[int, tuple] = 1,
        padding: Union[int, tuple] = 0,
        bias: bool = True,
    ):
        super().__init__()

        kernel_size, stride, padding = map(
            lambda x: (x, x) if isinstance(x, int) else x,
            (kernel_size, stride, padding),
        )
        scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(out_channels, *kernel_size, in_channels),
        )
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.stride = stride

    def _extra_repr(self):
        return (
            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
            f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
            f"padding={self.padding}, bias={'bias' in self}"
        )

    def __call__(self, x):
        y = mx.conv2d(x, self.weight, self.stride, self.padding)
        if "bias" in self:
            y = y + self.bias
        return y
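A minimal usage sketch of the two layers above. The channels-last (``NLC`` / ``NHWC``) layouts follow the docstrings; the ``mlx.nn`` import path is an assumption based on this commit's package layout, not verified here.

import mlx.core as mx
import mlx.nn as nn  # assumed re-export of mlx.nn.layers

# 1-D: batch of 4 sequences, length 10, 3 input channels (NLC layout)
x1 = mx.random.normal((4, 10, 3))
conv1 = nn.Conv1d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
print(conv1(x1).shape)  # (4, 10, 8): padding=1 keeps the length at stride=1

# 2-D: batch of 4 images, 32x32, 3 channels (NHWC layout)
x2 = mx.random.normal((4, 32, 32, 3))
conv2 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1)
print(conv2(x2).shape)  # (4, 16, 16, 16)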
python/mlx/nn/layers/embedding.py (new file, 28 lines)
@@ -0,0 +1,28 @@
import math

import mlx.core as mx
from mlx.nn.layers.base import Module


class Embedding(Module):
    """Implements a simple lookup table that maps each input integer to a
    high-dimensional vector.

    Typically used to embed discrete tokens for processing by neural networks.

    Args:
        num_embeddings (int): How many possible discrete tokens can we embed.
            Usually called the vocabulary size.
        dims (int): The dimensionality of the embeddings.
    """

    def __init__(self, num_embeddings: int, dims: int):
        super().__init__()
        scale = math.sqrt(1 / dims)
        self.weight = mx.random.normal((num_embeddings, dims)) * scale

    def _extra_repr(self):
        return f"{self.weight.shape[0]}, {self.weight.shape[1]}"

    def __call__(self, x):
        return self.weight[x]
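A small illustration of the lookup above: token ids index rows of ``weight``, so the output shape is the input shape with ``dims`` appended. Same hypothetical ``mlx.nn`` import path as in the earlier sketch.

import mlx.core as mx
import mlx.nn as nn  # assumed import path

emb = nn.Embedding(num_embeddings=100, dims=16)
tokens = mx.array([[1, 5, 7], [2, 2, 9]])  # shape (2, 3), integer token ids
print(emb(tokens).shape)  # (2, 3, 16): one 16-d vector per token id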
python/mlx/nn/layers/linear.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import math

import mlx.core as mx
from mlx.nn.layers.base import Module


class Linear(Module):
    """Applies an affine transformation to the input.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool): If set to ``False`` then the layer will not use a bias
    """

    def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
        super().__init__()
        scale = math.sqrt(1 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims),
        )
        if bias:
            self.bias = mx.zeros((output_dims,))

    def _extra_repr(self):
        return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"

    def __call__(self, x):
        x = x @ self.weight.T
        if "bias" in self:
            x = x + self.bias
        return x
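A quick sketch of the affine map above, under the same assumed import path. Note the weight is stored as ``(output_dims, input_dims)``, so the forward pass multiplies by its transpose.

import mlx.core as mx
import mlx.nn as nn  # assumed import path

lin = nn.Linear(input_dims=32, output_dims=8)
x = mx.random.normal((4, 32))
y = lin(x)      # equivalent to x @ lin.weight.T + lin.bias
print(y.shape)  # (4, 8)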
python/src/load.h (new file, 19 lines)
@@ -0,0 +1,19 @@
#pragma once

#include <pybind11/pybind11.h>
#include <unordered_map>
#include <variant>
#include "mlx/ops.h"

namespace py = pybind11;
using namespace mlx::core;

using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;

DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
void mlx_save_helper(py::object file, array a, bool retain_graph = true);
void mlx_savez_helper(
    py::object file,
    py::args args,
    const py::kwargs& kwargs,
    bool compressed = false);
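The helpers declared above back the Python-side serialization bindings. A hedged sketch of how they are typically used from Python; the exact names ``mx.save``, ``mx.load``, and ``mx.savez`` are inferred from the helper names, not confirmed by this header.

import mlx.core as mx

a = mx.ones((4, 4))
mx.save("a.npy", a)              # mlx_save_helper: one array to a .npy file
b = mx.load("a.npy")             # mlx_load_helper: an array (or a dict for .npz files)
mx.savez("both.npz", x=a, y=b)   # mlx_savez_helper: several arrays keyed by name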
python/src/mlx.cpp (new file, 31 lines)
@@ -0,0 +1,31 @@
#include <pybind11/pybind11.h>

#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)

namespace py = pybind11;

void init_array(py::module_&);
void init_device(py::module_&);
void init_stream(py::module_&);
void init_metal(py::module_&);
void init_ops(py::module_&);
void init_transforms(py::module_&);
void init_random(py::module_&);
void init_fft(py::module_&);

PYBIND11_MODULE(core, m) {
  m.doc() = "mlx: A framework for machine learning on Apple Silicon.";

  auto reprlib_fix = py::module_::import("mlx._reprlib_fix");

  init_device(m);
  init_stream(m);
  init_array(m);
  init_metal(m);
  init_ops(m);
  init_transforms(m);
  init_random(m);
  init_fft(m);
  m.attr("__version__") = TOSTRING(_VERSION_);
}
python/src/transforms.cpp (new file, 723 lines)
@@ -0,0 +1,723 @@
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
using IntOrVec = std::variant<int, std::vector<int>>;
|
||||
using StrOrVec = std::variant<std::string, std::vector<std::string>>;
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
|
||||
std::vector<T> vals;
|
||||
if (auto pv = std::get_if<T>(&v); pv) {
|
||||
vals.push_back(*pv);
|
||||
} else {
|
||||
vals = std::get<std::vector<T>>(v);
|
||||
}
|
||||
return vals;
|
||||
}
|
||||
|
||||
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
|
||||
std::function<void(py::handle)> recurse;
|
||||
recurse = [&](py::handle subtree) {
|
||||
if (py::isinstance<py::list>(subtree) ||
|
||||
py::isinstance<py::tuple>(subtree)) {
|
||||
for (auto item : subtree) {
|
||||
recurse(item);
|
||||
}
|
||||
} else if (py::isinstance<py::dict>(subtree)) {
|
||||
for (auto item : py::cast<py::dict>(subtree)) {
|
||||
recurse(item.second);
|
||||
}
|
||||
} else {
|
||||
visitor(subtree);
|
||||
}
|
||||
};
|
||||
|
||||
recurse(tree);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
void validate_subtrees(const std::vector<py::object>& subtrees) {
|
||||
int len = py::cast<T>(subtrees[0]).size();
|
||||
for (auto& subtree : subtrees) {
|
||||
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
|
||||
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
|
||||
throw std::invalid_argument(
|
||||
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
py::object tree_map(
|
||||
const std::vector<py::object>& trees,
|
||||
std::function<py::object(const std::vector<py::object>&)> transform) {
|
||||
std::function<py::object(const std::vector<py::object>&)> recurse;
|
||||
|
||||
recurse = [&](const std::vector<py::object>& subtrees) {
|
||||
if (py::isinstance<py::list>(subtrees[0])) {
|
||||
py::list l;
|
||||
std::vector<py::object> items(subtrees.size());
|
||||
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
|
||||
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (py::isinstance<py::list>(subtrees[j])) {
|
||||
items[j] = py::cast<py::list>(subtrees[j])[i];
|
||||
} else {
|
||||
items[j] = subtrees[j];
|
||||
}
|
||||
}
|
||||
l.append(recurse(items));
|
||||
}
|
||||
return py::cast<py::object>(l);
|
||||
} else if (py::isinstance<py::tuple>(subtrees[0])) {
|
||||
// Check the rest of the subtrees
|
||||
std::vector<py::object> items(subtrees.size());
|
||||
int len = py::cast<py::tuple>(subtrees[0]).size();
|
||||
py::tuple l(len);
|
||||
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (py::isinstance<py::tuple>(subtrees[j])) {
|
||||
items[j] = py::cast<py::tuple>(subtrees[j])[i];
|
||||
} else {
|
||||
items[j] = subtrees[j];
|
||||
}
|
||||
}
|
||||
l[i] = recurse(items);
|
||||
}
|
||||
return py::cast<py::object>(l);
|
||||
} else if (py::isinstance<py::dict>(subtrees[0])) {
|
||||
std::vector<py::object> items(subtrees.size());
|
||||
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
|
||||
py::dict d;
|
||||
for (auto item : py::cast<py::dict>(subtrees[0])) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (py::isinstance<py::dict>(subtrees[j])) {
|
||||
auto subdict = py::cast<py::dict>(subtrees[j]);
|
||||
if (!subdict.contains(item.first)) {
|
||||
throw std::invalid_argument(
|
||||
"[tree_map] Tree is not a valid prefix tree of the first tree.");
|
||||
}
|
||||
items[j] = subdict[item.first];
|
||||
} else {
|
||||
items[j] = subtrees[j];
|
||||
}
|
||||
}
|
||||
d[item.first] = recurse(items);
|
||||
}
|
||||
return py::cast<py::object>(d);
|
||||
} else {
|
||||
return transform(subtrees);
|
||||
}
|
||||
};
|
||||
return recurse(trees);
|
||||
}
|
||||
|
||||
py::object tree_map(
|
||||
py::object tree,
|
||||
std::function<py::object(py::handle)> transform) {
|
||||
return tree_map({tree}, [&](std::vector<py::object> inputs) {
|
||||
return transform(inputs[0]);
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<array> tree_flatten(py::object tree, bool strict = true) {
|
||||
std::vector<array> flat_tree;
|
||||
|
||||
tree_visit(tree, [&](py::handle obj) {
|
||||
if (py::isinstance<array>(obj)) {
|
||||
flat_tree.push_back(py::cast<array>(obj));
|
||||
} else if (strict) {
|
||||
throw std::invalid_argument("Argument is not an array");
|
||||
}
|
||||
});
|
||||
|
||||
return flat_tree;
|
||||
}
|
||||
|
||||
py::object tree_unflatten(
|
||||
py::object tree,
|
||||
const std::vector<array>& values,
|
||||
int index = 0) {
|
||||
return tree_map(tree, [&](py::handle obj) {
|
||||
if (py::isinstance<array>(obj)) {
|
||||
return py::cast(values[index++]);
|
||||
} else {
|
||||
return py::cast<py::object>(obj);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
auto validate_argnums_argnames(
|
||||
const std::optional<IntOrVec>& argnums,
|
||||
const StrOrVec& argnames) {
|
||||
auto vec_names = to_vector(argnames);
|
||||
|
||||
if (!argnums.has_value()) {
|
||||
// argnums was not provided and argnames was empty
|
||||
if (vec_names.empty()) {
|
||||
return std::make_pair(std::vector<int>{0}, vec_names);
|
||||
} else {
|
||||
return std::make_pair(std::vector<int>{}, vec_names);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(to_vector(*argnums), vec_names);
|
||||
}
|
||||
|
||||
auto py_value_and_grad(
|
||||
const py::function& fun,
|
||||
std::vector<int> argnums,
|
||||
std::vector<std::string> argnames,
|
||||
const std::string& error_msg_tag,
|
||||
bool scalar_func_only) {
|
||||
// Sanitize argnums
|
||||
if (argnums.size() == 0 && argnames.size() == 0) {
|
||||
throw std::invalid_argument(
|
||||
error_msg_tag + " Gradient wrt no argument requested");
|
||||
}
|
||||
if (argnums.size() > 0) {
|
||||
std::sort(argnums.begin(), argnums.end());
|
||||
if (argnums[0] < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag
|
||||
<< " Can't compute the gradient of negative argument index "
|
||||
<< argnums[0];
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
|
||||
const py::args& args, const py::kwargs& kwargs) {
|
||||
// Sanitize the input
|
||||
if (argnums.size() > 0 && argnums.back() >= args.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " Can't compute the gradient of argument index "
|
||||
<< argnums.back() << " because the function is called with only "
|
||||
<< args.size() << " arguments.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
for (auto& key : argnames) {
|
||||
if (!kwargs.contains(key)) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag
|
||||
<< " Can't compute the gradient of keyword argument '" << key
|
||||
<< "' because the function is called with the "
|
||||
<< "following keyword arguments {";
|
||||
for (auto item : kwargs) {
|
||||
msg << item.first.cast<std::string>() << ",";
|
||||
}
|
||||
msg << "}";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
// Collect the arrays
|
||||
std::vector<array> arrays;
|
||||
std::vector<int> counts(1, 0);
|
||||
for (auto i : argnums) {
|
||||
auto argsi = tree_flatten(args[i]);
|
||||
arrays.insert(arrays.end(), argsi.begin(), argsi.end());
|
||||
counts.push_back(argsi.size());
|
||||
}
|
||||
for (auto& key : argnames) {
|
||||
auto argsk = tree_flatten(kwargs[key.c_str()]);
|
||||
arrays.insert(arrays.end(), argsk.begin(), argsk.end());
|
||||
counts.push_back(argsk.size());
|
||||
}
|
||||
std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
|
||||
std::vector<int> gradient_indices(arrays.size());
|
||||
std::iota(gradient_indices.begin(), gradient_indices.end(), 0);
|
||||
|
||||
// value_out will hold the output of the python function in order to be
|
||||
// able to reconstruct the python tree of extra return values
|
||||
py::object py_value_out;
|
||||
auto value_and_grads = value_and_grad(
|
||||
[&fun,
|
||||
&args,
|
||||
&kwargs,
|
||||
&argnums,
|
||||
&argnames,
|
||||
&counts,
|
||||
&py_value_out,
|
||||
&error_msg_tag,
|
||||
scalar_func_only](const std::vector<array>& a) {
|
||||
// Copy the arguments
|
||||
py::args args_cpy = py::tuple(args.size());
|
||||
py::kwargs kwargs_cpy = py::kwargs();
|
||||
int j = 0;
|
||||
for (int i = 0; i < args.size(); ++i) {
|
||||
if (j < argnums.size() && i == argnums[j]) {
|
||||
args_cpy[i] = tree_unflatten(args[i], a, counts[j]);
|
||||
j++;
|
||||
} else {
|
||||
args_cpy[i] = args[i];
|
||||
}
|
||||
}
|
||||
for (auto& key : argnames) {
|
||||
kwargs_cpy[key.c_str()] =
|
||||
tree_unflatten(kwargs[key.c_str()], a, counts[j]);
|
||||
j++;
|
||||
}
|
||||
for (auto item : kwargs) {
|
||||
if (kwargs_cpy.contains(item.first)) {
|
||||
continue;
|
||||
}
|
||||
kwargs_cpy[item.first] = item.second;
|
||||
}
|
||||
|
||||
// Call the python function
|
||||
py_value_out = fun(*args_cpy, **kwargs_cpy);
|
||||
|
||||
// Validate the return value of the python function
|
||||
if (!py::isinstance<array>(py_value_out)) {
|
||||
if (scalar_func_only) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be a "
|
||||
<< "scalar array; but " << py_value_out.get_type()
|
||||
<< " was returned.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!py::isinstance<py::tuple>(py_value_out)) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be either a "
|
||||
<< "scalar array or a tuple with the first value being a "
|
||||
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but "
|
||||
<< py_value_out.get_type() << " was returned.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
py::tuple ret = py::cast<py::tuple>(py_value_out);
|
||||
if (ret.size() == 0) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be either a "
|
||||
<< "scalar array or a non-empty tuple. The first value should be a "
|
||||
<< "scalar array and the rest can be anything. Instead, "
|
||||
<< "we got an empty tuple.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!py::isinstance<array>(ret[0])) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be either a "
|
||||
<< "scalar array or a tuple with the first value being a "
|
||||
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but it "
|
||||
<< "was a tuple with the first value being of type "
|
||||
<< ret[0].get_type() << " .";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
return tree_flatten(py_value_out, false);
|
||||
},
|
||||
gradient_indices)(arrays);
|
||||
|
||||
auto value = value_and_grads.first;
|
||||
auto gradients = value_and_grads.second;
|
||||
|
||||
// Put the gradients back in their container.
|
||||
// We have the following cases:
|
||||
//
|
||||
// 1. Single python positional argument has a gradient (eg argnums=[0])
|
||||
// 2. Many python positional arguments have gradients (eg argnums=[0, 1])
|
||||
// 3. A python keyword argument has gradients
|
||||
//
|
||||
// In case 1 we return the original python variable but with the gradients.
|
||||
// In case 2 we return a tuple of the above.
|
||||
// In case 3 we return a tuple containing a tuple and dict (sth like
|
||||
// (tuple(), dict(x=mx.array(5))) ).
|
||||
py::object positional_grads;
|
||||
py::object keyword_grads;
|
||||
py::object py_grads;
|
||||
|
||||
// Collect the gradients for the positional arguments
|
||||
if (argnums.size() == 1) {
|
||||
positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
|
||||
} else if (argnums.size() > 1) {
|
||||
py::tuple grads_(argnums.size());
|
||||
for (int i = 0; i < argnums.size(); i++) {
|
||||
grads_[i] = tree_unflatten(args[argnums[i]], gradients, counts[i]);
|
||||
}
|
||||
positional_grads = py::cast<py::object>(grads_);
|
||||
} else {
|
||||
positional_grads = py::none();
|
||||
}
|
||||
|
||||
// No keyword argument gradients so return the tuple of gradients
|
||||
if (argnames.size() == 0) {
|
||||
py_grads = positional_grads;
|
||||
} else {
|
||||
py::dict grads_;
|
||||
for (int i = 0; i < argnames.size(); i++) {
|
||||
auto& k = argnames[i];
|
||||
grads_[k.c_str()] = tree_unflatten(
|
||||
kwargs[k.c_str()], gradients, counts[i + argnums.size()]);
|
||||
}
|
||||
keyword_grads = py::cast<py::object>(grads_);
|
||||
|
||||
py_grads =
|
||||
py::cast<py::object>(py::make_tuple(positional_grads, keyword_grads));
|
||||
}
|
||||
|
||||
// Put the values back in the container
|
||||
py::object return_value = tree_unflatten(py_value_out, value);
|
||||
return std::make_pair(return_value, py_grads);
|
||||
};
|
||||
}
|
||||
|
||||
auto py_vmap(
|
||||
const py::function& fun,
|
||||
const py::object& in_axes,
|
||||
const py::object& out_axes) {
|
||||
return [fun, in_axes, out_axes](const py::args& args) {
|
||||
auto axes_to_flat_tree = [](const py::object& tree,
|
||||
const py::object& axes) {
|
||||
auto tree_axes = tree_map(
|
||||
{tree, axes},
|
||||
[](const std::vector<py::object>& inputs) { return inputs[1]; });
|
||||
std::vector<int> flat_axes;
|
||||
tree_visit(tree_axes, [&flat_axes](py::handle obj) {
|
||||
if (obj.is_none()) {
|
||||
flat_axes.push_back(-1);
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
flat_axes.push_back(py::cast<int>(py::cast<py::int_>(obj)));
|
||||
} else {
|
||||
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||
}
|
||||
});
|
||||
return flat_axes;
|
||||
};
|
||||
|
||||
// Inputs must be array or tree of arrays
|
||||
auto inputs = tree_flatten(args, true);
|
||||
auto flat_in_axes = axes_to_flat_tree(args, in_axes);
|
||||
|
||||
// py_value_out will hold the output of the python function in order to be
|
||||
// able to reconstruct the python tree of extra return values
|
||||
py::object py_outputs;
|
||||
|
||||
auto vmap_fn =
|
||||
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
|
||||
// Call the python function
|
||||
py_outputs = fun(*tree_unflatten(args, a));
|
||||
|
||||
// Flatten the outputs
|
||||
return tree_flatten(py_outputs, true);
|
||||
};
|
||||
|
||||
auto [trace_inputs, trace_outputs] =
|
||||
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
||||
|
||||
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes);
|
||||
|
||||
// Perform the vmap
|
||||
auto outputs = detail::vmap_replace(
|
||||
inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes);
|
||||
|
||||
// Put the outputs back in the container
|
||||
return tree_unflatten(py_outputs, outputs);
|
||||
};
|
||||
}
|
||||
|
||||
void init_transforms(py::module_& m) {
|
||||
m.def(
|
||||
"eval",
|
||||
[](const py::args& args, bool retain_graph) {
|
||||
std::vector<array> arrays = tree_flatten(args);
|
||||
eval(arrays, retain_graph);
|
||||
},
|
||||
"retain_graph"_a = false,
|
||||
R"pbdoc(
|
||||
Evaluate an :class:`array` or tree of :class:`array`.
|
||||
|
||||
Args:
|
||||
*args (arrays or trees of arrays): Each argument can be a single array
|
||||
or a tree of arrays. If a tree is given the nodes can be a Python
|
||||
:class:`list`, :class:`tuple` or :class:`dict` but the leaves must all be
|
||||
an :class:`array`.
|
||||
retain_graph (bool): Indicate that the graph structure should be
|
||||
preserved. This option is intended to enable function transforms
|
||||
which contain control flow based on the value of an array.
|
||||
)pbdoc");
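A short sketch of what the ``eval`` docstring above describes: computation stays lazy until ``eval`` is called on an array or a tree of arrays (illustrative only).

import mlx.core as mx

a = mx.ones((2, 2))
b = {"sum": a + a, "prod": a * a}   # graph nodes, not yet computed
mx.eval(b)                          # evaluates every array leaf in the tree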
|
||||
m.def(
|
||||
"jvp",
|
||||
[](const py::function& fun,
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents) {
|
||||
auto vfun = [&fun](const std::vector<array>& primals) {
|
||||
py::args args = py::tuple(primals.size());
|
||||
for (int i = 0; i < primals.size(); ++i) {
|
||||
args[i] = primals[i];
|
||||
}
|
||||
auto out = fun(*args);
|
||||
if (py::isinstance<array>(out)) {
|
||||
return std::vector<array>{py::cast<array>(out)};
|
||||
} else {
|
||||
return py::cast<std::vector<array>>(out);
|
||||
}
|
||||
};
|
||||
return jvp(vfun, primals, tangents);
|
||||
},
|
||||
"fun"_a,
|
||||
"primals"_a,
|
||||
"tangents"_a,
|
||||
R"pbdoc(
|
||||
Compute the Jacobian-vector product.
|
||||
|
||||
This computes the product of the Jacobian of a function ``fun`` evaluated
|
||||
at ``primals`` with the ``tangents``.
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of :class:`array`
|
||||
and returns a single :class:`array` or list of :class:`array`.
|
||||
primals (list(array)): A list of :class:`array` at which to
|
||||
evaluate the Jacobian.
|
||||
tangents (list(array)): A list of :class:`array` which are the
|
||||
"vector" in the Jacobian-vector product. The ``tangents`` should be the
|
||||
same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).
|
||||
|
||||
Returns:
|
||||
list(array): A list of the Jacobian-vector products which
|
||||
is the same in number, shape, and type as the inputs to ``fun``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"vjp",
|
||||
[](const py::function& fun,
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents) {
|
||||
auto vfun = [&fun](const std::vector<array>& primals) {
|
||||
py::args args = py::tuple(primals.size());
|
||||
for (int i = 0; i < primals.size(); ++i) {
|
||||
args[i] = primals[i];
|
||||
}
|
||||
auto out = fun(*args);
|
||||
if (py::isinstance<array>(out)) {
|
||||
return std::vector<array>{py::cast<array>(out)};
|
||||
} else {
|
||||
return py::cast<std::vector<array>>(out);
|
||||
}
|
||||
};
|
||||
return vjp(vfun, primals, cotangents);
|
||||
},
|
||||
"fun"_a,
|
||||
"primals"_a,
|
||||
"cotangents"_a,
|
||||
R"pbdoc(
|
||||
Compute the vector-Jacobian product.
|
||||
|
||||
Computes the product of the ``cotangents`` with the Jacobian of a
|
||||
function ``fun`` evaluated at ``primals``.
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of :class:`array`
|
||||
and returns a single :class:`array` or list of :class:`array`.
|
||||
primals (list(array)): A list of :class:`array` at which to
|
||||
evaluate the Jacobian.
|
||||
cotangents (list(array)): A list of :class:`array` which are the
|
||||
"vector" in the vector-Jacobian product. The ``cotangents`` should be the
|
||||
same in number, shape, and type as the outputs of ``fun``.
|
||||
|
||||
Returns:
|
||||
list(array): A list of the vector-Jacobian products which
|
||||
is the same in number, shape, and type as the outputs of ``fun``.
|
||||
)pbdoc");
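A combined sketch of the two bindings above. For f(x) = x * x the Jacobian at x is diag(2x), so pushing a tangent t forward gives 2 * x * t, and pulling a cotangent c back gives 2 * x * c; the values in the comments follow from that, not from running this commit.

import mlx.core as mx

f = lambda x: x * x

x = mx.array([1.0, 2.0, 3.0])
t = mx.array([1.0, 1.0, 1.0])

outs, jvps = mx.jvp(f, [x], [t])   # jvps[0] == 2 * x * t
outs, vjps = mx.vjp(f, [x], [t])   # vjps[0] == 2 * x * t (same here, f is elementwise)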
|
||||
m.def(
|
||||
"value_and_grad",
|
||||
[](const py::function& fun,
|
||||
const std::optional<IntOrVec>& argnums,
|
||||
const StrOrVec& argnames) {
|
||||
auto [argnums_vec, argnames_vec] =
|
||||
validate_argnums_argnames(argnums, argnames);
|
||||
return py::cpp_function(py_value_and_grad(
|
||||
fun, argnums_vec, argnames_vec, "[value_and_grad]", false));
|
||||
},
|
||||
"fun"_a,
|
||||
"argnums"_a = std::nullopt,
|
||||
"argnames"_a = std::vector<std::string>{},
|
||||
R"pbdoc(
|
||||
Returns a function which computes the value and gradient of ``fun``.
|
||||
|
||||
The function passed to :func:`value_and_grad` should return either
|
||||
a scalar loss or a tuple in which the first element is a scalar
|
||||
loss and the remaining elements can be anything.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
def mse(params, inputs, targets):
|
||||
outputs = forward(params, inputs)
|
||||
lvalue = (outputs - targets).square().mean()
|
||||
return lvalue
|
||||
|
||||
# Returns lvalue, dlvalue/dparams
|
||||
lvalue, grads = mx.value_and_grad(mse)
|
||||
|
||||
def lasso(params, inputs, targets, a=1.0, b=1.0):
|
||||
outputs = forward(params, inputs)
|
||||
mse = (outputs - targets).square().mean()
|
||||
l1 = mx.abs(outputs - targets).mean()
|
||||
|
||||
loss = a*mse + b*l1
|
||||
|
||||
return loss, mse, l1
|
||||
|
||||
(loss, mse, l1), grads = mx.value_and_grad(lasso)
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a scalar output :class:`array` or a tuple the first element
|
||||
of which should be a scalar :class:`array`.
|
||||
argnums (int or list(int), optional): Specify the index (or indices)
|
||||
of the positional arguments of ``fun`` to compute the gradient
|
||||
with respect to. If neither ``argnums`` nor ``argnames`` are
|
||||
provided ``argnums`` defaults to ``0`` indicating ``fun``'s first
|
||||
argument.
|
||||
argnames (str or list(str), optional): Specify keyword arguments of
|
||||
``fun`` to compute gradients with respect to. It defaults to ``[]`` so
no gradients are computed for keyword arguments by default.
|
||||
|
||||
Returns:
|
||||
function: A function which returns a tuple where the first element
|
||||
is the output of ``fun`` and the second element contains the gradients w.r.t.
the loss.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"grad",
|
||||
[](const py::function& fun,
|
||||
const std::optional<IntOrVec>& argnums,
|
||||
const StrOrVec& argnames) {
|
||||
auto [argnums_vec, argnames_vec] =
|
||||
validate_argnums_argnames(argnums, argnames);
|
||||
auto fn =
|
||||
py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true);
|
||||
return py::cpp_function(
|
||||
[fn](const py::args& args, const py::kwargs& kwargs) {
|
||||
return fn(args, kwargs).second;
|
||||
});
|
||||
},
|
||||
"fun"_a,
|
||||
"argnums"_a = std::nullopt,
|
||||
"argnames"_a = std::vector<std::string>{},
|
||||
R"pbdoc(
|
||||
Returns a function which computes the gradient of ``fun``.
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a scalar output :class:`array`.
|
||||
argnums (int or list(int), optional): Specify the index (or indices)
|
||||
of the positional arguments of ``fun`` to compute the gradient
|
||||
with respect to. If neither ``argnums`` nor ``argnames`` are
|
||||
provided ``argnums`` defaults to ``0`` indicating ``fun``'s first
|
||||
argument.
|
||||
argnames (str or list(str), optional): Specify keyword arguments of
|
||||
``fun`` to compute gradients with respect to. It defaults to ``[]`` so
no gradients are computed for keyword arguments by default.
|
||||
|
||||
Returns:
|
||||
function: A function which has the same input arguments as ``fun`` and
|
||||
returns the gradient(s).
|
||||
)pbdoc");
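The ``grad`` docstring above carries no example, so here is a minimal sketch of the calling convention with ``argnums`` and ``argnames``; the loss function is hypothetical and only illustrates the API.

import mlx.core as mx

def loss(w, x, scale=1.0):
    return scale * (w * x).sum()   # scalar output, as grad requires

w = mx.array([1.0, 2.0])
x = mx.array([3.0, 4.0])

dloss_dw = mx.grad(loss)(w, x)              # argnums defaults to 0 (first argument)
dloss_dx = mx.grad(loss, argnums=1)(w, x)   # gradient w.r.t. the second argument
# Per the gradient-packing logic above, keyword-only grads come back as
# (positional_grads, {"scale": ...}) with positional_grads set to None here.
kw_grads = mx.grad(loss, argnames="scale")(w, x, scale=2.0)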
|
||||
m.def(
|
||||
"vmap",
|
||||
[](const py::function& fun,
|
||||
const py::object& in_axes,
|
||||
const py::object& out_axes) {
|
||||
return py::cpp_function(py_vmap(fun, in_axes, out_axes));
|
||||
},
|
||||
"fun"_a,
|
||||
"in_axes"_a = 0,
|
||||
"out_axes"_a = 0,
|
||||
R"pbdoc(
|
||||
Returns a vectorized version of ``fun``.
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of
|
||||
:class:`array` or a tree of :class:`array` and returns
|
||||
a variable number of :class:`array` or a tree of :class:`array`.
|
||||
in_axes (int, optional): An integer or a valid prefix tree of the
|
||||
inputs to ``fun`` where each node specifies the vmapped axis. If
|
||||
the value is ``None`` then the corresponding input(s) are not vmapped.
|
||||
Defaults to ``0``.
|
||||
out_axes (int, optional): An integer or a valid prefix tree of the
|
||||
outputs of ``fun`` where each node specifies the vmapped axis. If
|
||||
the value is ``None`` then the corresponding output(s) are not vmapped.
|
||||
Defaults to ``0``.
|
||||
|
||||
Returns:
|
||||
function: The vectorized function.
|
||||
)pbdoc");
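A small sketch of ``vmap`` over the leading axis, mirroring the in_axes/out_axes description above.

import mlx.core as mx

def dot(a, b):
    return (a * b).sum()

a = mx.random.normal((8, 3))
b = mx.random.normal((8, 3))

batched_dot = mx.vmap(dot)       # in_axes=0, out_axes=0 by default
print(batched_dot(a, b).shape)   # (8,): one dot product per row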
|
||||
m.def(
|
||||
"simplify",
|
||||
[](const py::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args);
|
||||
simplify(arrays);
|
||||
},
|
||||
R"pbdoc(
|
||||
Simplify the graph that computes the arrays.
|
||||
|
||||
Run a few fast graph simplification operations to reuse computation and
|
||||
reduce memory consumption. This function is meant to be run every time
|
||||
so its overhead should be small, approximately 1ms for a graph with a
|
||||
few thousand nodes.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
def foo(x):
|
||||
y = x @ x
|
||||
z = x @ x
|
||||
return y + z
|
||||
|
||||
x = mx.ones((10, 10))
|
||||
y = foo(x)
|
||||
z = foo(x)
|
||||
|
||||
# Computes the matmul twice
|
||||
mx.eval(y)
|
||||
|
||||
# Computes the matmul once
|
||||
mx.simplify(z)
|
||||
mx.eval(z)
|
||||
|
||||
Args:
|
||||
args: Any number of arrays and/or trees of arrays to be simplified.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
[](py::object file, const py::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args);
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
std::ofstream out(py::cast<std::string>(file));
|
||||
export_to_dot(out, arrays);
|
||||
} else if (py::hasattr(file, "write")) {
|
||||
std::ostringstream out;
|
||||
export_to_dot(out, arrays);
|
||||
auto write = file.attr("write");
|
||||
write(out.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"export_to_dot accepts file-like objects or strings to be used as filenames");
|
||||
}
|
||||
},
|
||||
"file"_a);
|
||||
}
python/tests/mlx_tests.py (new file, 16 lines)
@@ -0,0 +1,16 @@
import os
import unittest

import mlx.core as mx


class MLXTestCase(unittest.TestCase):
    def setUp(self):
        self.default = mx.default_device()
        device = os.getenv("DEVICE", None)
        if device is not None:
            device = getattr(mx, device)
            mx.set_default_device(device)

    def tearDown(self):
        mx.set_default_device(self.default)
python/tests/test_blas.py (new file, 445 lines)
@@ -0,0 +1,445 @@
|
||||
import unittest
|
||||
from itertools import permutations
|
||||
|
||||
import math
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestBlas(mlx_tests.MLXTestCase):
|
||||
@property
|
||||
def dtypes(self):
|
||||
return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
|
||||
|
||||
def __gemm_test(
|
||||
self,
|
||||
shape_a,
|
||||
shape_b,
|
||||
np_dtype=np.float32,
|
||||
f_np_a=lambda x: x,
|
||||
f_np_b=lambda x: x,
|
||||
f_mx_a=lambda x: x,
|
||||
f_mx_b=lambda x: x,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=np.dtype(np_dtype).name, shape_a=shape_a, shape_b=shape_b
|
||||
):
|
||||
np.random.seed(42)
|
||||
scale = max(np.sum(shape_a), 128)
|
||||
a_np = np.random.normal(0.0, 1.0 / scale, shape_a).astype(np_dtype)
|
||||
b_np = np.random.normal(0.0, 1.0 / scale, shape_b).astype(np_dtype)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_np = f_np_a(a_np.astype(np.float32))
|
||||
b_np = f_np_b(b_np.astype(np.float32))
|
||||
a_mx = f_mx_a(a_mx)
|
||||
b_mx = f_mx_b(b_mx)
|
||||
|
||||
out_npy = a_np @ b_np
|
||||
out_mlx = a_mx @ b_mx
|
||||
|
||||
self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))
|
||||
self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))
|
||||
|
||||
def test_matmul_unaligned(self):
|
||||
if not mx.metal.is_available():
|
||||
return
|
||||
|
||||
for dtype in self.dtypes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
base_shapes = [4, 8, 16, 32, 64, 128]
|
||||
perturbations = [-2, -1, 0, 1, 2]
|
||||
|
||||
for dim in base_shapes:
|
||||
for p in perturbations:
|
||||
shape_a = (dim + p, dim + p)
|
||||
shape_b = (dim + p, dim + p)
|
||||
self.__gemm_test(shape_a, shape_b, np_dtype)
|
||||
|
||||
def test_matmul_shapes(self):
|
||||
if not mx.metal.is_available():
|
||||
return
|
||||
|
||||
shapes = [
|
||||
(1, 2, 1, 1),
|
||||
(1, 1, 2, 1),
|
||||
(3, 23, 457, 3),
|
||||
]
|
||||
|
||||
if mx.default_device() == mx.gpu:
|
||||
shapes += [
|
||||
(16, 768, 768, 128),
|
||||
]
|
||||
|
||||
for dtype in self.dtypes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
for B, M, N, K in shapes:
|
||||
|
||||
with self.subTest(transpose="nn"):
|
||||
shape_a = (B, M, K)
|
||||
shape_b = (B, K, N)
|
||||
self.__gemm_test(shape_a, shape_b, np_dtype)
|
||||
|
||||
with self.subTest(transpose="nt"):
|
||||
shape_a = (B, M, K)
|
||||
shape_b = (B, N, K)
|
||||
self.__gemm_test(
|
||||
shape_a,
|
||||
shape_b,
|
||||
np_dtype,
|
||||
f_np_b=lambda x: np.transpose(x, (0, 2, 1)),
|
||||
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||
)
|
||||
|
||||
with self.subTest(transpose="tn"):
|
||||
shape_a = (B, K, M)
|
||||
shape_b = (B, K, N)
|
||||
self.__gemm_test(
|
||||
shape_a,
|
||||
shape_b,
|
||||
np_dtype,
|
||||
f_np_a=lambda x: np.transpose(x, (0, 2, 1)),
|
||||
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||
)
|
||||
|
||||
with self.subTest(transpose="tt"):
|
||||
shape_a = (B, K, M)
|
||||
shape_b = (B, N, K)
|
||||
self.__gemm_test(
|
||||
shape_a,
|
||||
shape_b,
|
||||
np_dtype,
|
||||
f_np_a=lambda x: np.transpose(x, (0, 2, 1)),
|
||||
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||
f_np_b=lambda x: np.transpose(x, (0, 2, 1)),
|
||||
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||
)
|
||||
|
||||
def test_matmul(self):
|
||||
# Note: so far, matmul only works with floating-point types
|
||||
a = mx.array([[1.0, 2.0], [3.0, 4.0]])
|
||||
|
||||
b = mx.array([[0.0, -1.0], [-3.0, 3.0]])
|
||||
|
||||
expected = [[-6.0, 5.0], [-12.0, 9.0]]
|
||||
|
||||
self.assertEqual((a @ b).tolist(), expected)
|
||||
self.assertEqual(mx.matmul(a, b).tolist(), expected)
|
||||
|
||||
# Transposed matmul
|
||||
np.random.seed(0)
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
|
||||
c_npy = a_npy @ np.transpose(b_npy, (1, 0))
|
||||
d_npy = np.transpose(a_npy, (1, 0)) @ b_npy
|
||||
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))
|
||||
d_mlx = mx.transpose(a_mlx, (1, 0)) @ b_mlx
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6))
|
||||
|
||||
def test_matmul_dtypes(self):
|
||||
|
||||
for dt in self.dtypes:
|
||||
a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
|
||||
getattr(np, dt)
|
||||
)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
|
||||
getattr(np, dt)
|
||||
)
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
|
||||
c_npy = np.matmul(a_npy, b_npy, dtype=getattr(np, dt))
|
||||
c_mlx = a_mlx @ b_mlx
|
||||
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
def test_matmul_batched(self):
|
||||
np.random.seed(0)
|
||||
# Batched matmul
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
|
||||
c_npy = a_npy @ b_npy
|
||||
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
c_mlx = a_mlx @ b_mlx
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
# Batched and transposed matmul
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
c_npy = a_npy @ np.transpose(b_npy, (0, 2, 1))
|
||||
|
||||
b_mlx = mx.array(b_npy)
|
||||
c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 2, 1))
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
# Batched matmul with simple broadcast
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
|
||||
c_npy = a_npy @ b_npy
|
||||
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
c_mlx = a_mlx @ b_mlx
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
# Both operands broadcasted
|
||||
d_npy = np.broadcast_to(b_npy, (5, 16, 16))
|
||||
d_mlx = mx.broadcast_to(b_mlx, (5, 16, 16))
|
||||
|
||||
e_npy = d_npy @ d_npy
|
||||
e_mlx = d_mlx @ d_mlx
|
||||
|
||||
self.assertListEqual(list(e_npy.shape), list(e_mlx.shape))
|
||||
self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6))
|
||||
|
||||
# Batched and transposed matmul with simple broadcast
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
|
||||
c_npy = a_npy @ np.transpose(b_npy, (1, 0))
|
||||
c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
# Matmul with vector
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
|
||||
c_npy = a_npy @ b_npy
|
||||
c_mlx = a_mlx @ b_mlx
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
# Test Multiheaded attention style matmul
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
|
||||
a_npy = np.transpose(a_npy, (0, 2, 1, 3))
|
||||
b_npy = np.transpose(b_npy, (0, 2, 1, 3))
|
||||
a_mlx = mx.transpose(a_mlx, (0, 2, 1, 3))
|
||||
b_mlx = mx.transpose(b_mlx, (0, 2, 1, 3))
|
||||
|
||||
c_npy = a_npy @ np.transpose(b_npy, (0, 1, 3, 2))
|
||||
c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 1, 3, 2))
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
def __gemv_test(
|
||||
self,
|
||||
shape_mat,
|
||||
shape_vec,
|
||||
np_dtype=np.float32,
|
||||
mat_first=True,
|
||||
np_mat_f=lambda x: x,
|
||||
np_vec_f=lambda x: x,
|
||||
mlx_mat_f=lambda x: x,
|
||||
mlx_vec_f=lambda x: x,
|
||||
):
|
||||
with self.subTest(shape=shape_mat):
|
||||
np.random.seed(42)
|
||||
scale = max(np.sum(shape_mat), 32)
|
||||
mat_npy = np.random.normal(0.0, 1.0 / scale, shape_mat).astype(np_dtype)
|
||||
vec_npy = np.random.normal(0.0, 1.0 / scale, shape_vec).astype(np_dtype)
|
||||
|
||||
mat_mlx = mx.array(mat_npy)
|
||||
vec_mlx = mx.array(vec_npy)
|
||||
|
||||
mat_npy = np_mat_f(mat_npy)
|
||||
vec_npy = np_vec_f(vec_npy)
|
||||
mat_mlx = mlx_mat_f(mat_mlx)
|
||||
vec_mlx = mlx_vec_f(vec_mlx)
|
||||
|
||||
if mat_first:
|
||||
out_npy = mat_npy @ vec_npy
|
||||
out_mlx = mat_mlx @ vec_mlx
|
||||
else:
|
||||
out_npy = vec_npy @ mat_npy
|
||||
out_mlx = vec_mlx @ mat_mlx
|
||||
|
||||
self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))
|
||||
self.assertTrue(np.allclose(out_mlx, out_npy, atol=1e-5))
|
||||
|
||||
def test_matrix_vector(self):
|
||||
for dtype in self.dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
# Basic square matrix test
|
||||
self.__gemv_test(
|
||||
shape_mat=(64, 64), shape_vec=(64, 1), np_dtype=np_dtype
|
||||
)
|
||||
self.__gemv_test(
|
||||
shape_mat=(64, 64),
|
||||
shape_vec=(64, 1),
|
||||
np_dtype=np_dtype,
|
||||
mat_first=False,
|
||||
np_vec_f=lambda x: np.transpose(x, (1, 0)),
|
||||
mlx_vec_f=lambda x: mx.transpose(x, (1, 0)),
|
||||
)
|
||||
|
||||
# Vector matrix product with aligned and unaligned shapes
|
||||
for in_len_base, out_len_base in (
|
||||
(2, 2),
|
||||
(32, 32),
|
||||
(64, 64),
|
||||
(2048, 2048),
|
||||
):
|
||||
for mi in (-1, 0, 1):
|
||||
for mj in (-1, 0, 1):
|
||||
# Vec mat
|
||||
shape_mat = (in_len_base + mi, out_len_base + mj)
|
||||
shape_vec = (1, in_len_base + mi)
|
||||
self.__gemv_test(
|
||||
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
|
||||
)
|
||||
|
||||
# Mat vec
|
||||
shape_mat = (out_len_base + mj, in_len_base + mi)
|
||||
shape_vec = (in_len_base + mi, 1)
|
||||
self.__gemv_test(
|
||||
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
|
||||
)
|
||||
|
||||
def test_matrix_vector_batched(self):
|
||||
for dtype in self.dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
# Batched mat vec
|
||||
for shape_mat, shape_vec in (
|
||||
((32, 128, 64), (32, 64, 1)),
|
||||
((128, 64), (32, 64, 1)),
|
||||
((32, 128, 64), (64, 1)),
|
||||
):
|
||||
self.__gemv_test(
|
||||
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
|
||||
)
|
||||
|
||||
# Batched vec mat
|
||||
for shape_vec, shape_mat in (
|
||||
((32, 1, 128), (32, 128, 64)),
|
||||
((32, 1, 128), (128, 64)),
|
||||
((1, 128), (32, 128, 64)),
|
||||
):
|
||||
self.__gemv_test(
|
||||
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
|
||||
)
|
||||
|
||||
def test_matrix_vector_broadcast(self):
|
||||
for dtype in self.dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
# Different broadcasts mat vec
|
||||
for shape_mat, shape_vec in (
|
||||
((32, 64, 64), (32, 64, 1)),
|
||||
((64, 64), (32, 64, 1)),
|
||||
((32, 64, 64), (64, 1)),
|
||||
):
|
||||
self.__gemv_test(
|
||||
shape_mat=(64, 64),
|
||||
shape_vec=(64, 1),
|
||||
np_dtype=np_dtype,
|
||||
np_mat_f=(lambda mat_npy: np.broadcast_to(mat_npy, shape_mat)),
|
||||
np_vec_f=(lambda vec_npy: np.broadcast_to(vec_npy, shape_vec)),
|
||||
mlx_mat_f=(lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat)),
|
||||
mlx_vec_f=(lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec)),
|
||||
)
|
||||
|
||||
# Different broadcasts vec mat
|
||||
for shape_vec, shape_mat in (
|
||||
((32, 1, 64), (32, 64, 64)),
|
||||
((32, 1, 64), (64, 64)),
|
||||
((1, 64), (32, 64, 64)),
|
||||
):
|
||||
self.__gemv_test(
|
||||
shape_mat=(64, 64),
|
||||
shape_vec=(1, 64),
|
||||
np_dtype=np_dtype,
|
||||
mat_first=False,
|
||||
np_mat_f=lambda mat_npy: np.broadcast_to(mat_npy, shape_mat),
|
||||
np_vec_f=lambda vec_npy: np.broadcast_to(vec_npy, shape_vec),
|
||||
mlx_mat_f=lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat),
|
||||
mlx_vec_f=lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec),
|
||||
)
|
||||
|
||||
def test_matrix_vector_edgecases(self):
|
||||
for dtype in self.dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
for in_vec_len in np.arange(1, 5):
|
||||
for out_vec_len in np.arange(1, 5):
|
||||
for batch_size in np.arange(1, 5):
|
||||
with self.subTest(
|
||||
problem_shape=(batch_size, in_vec_len, out_vec_len)
|
||||
):
|
||||
# Matrix vector
|
||||
with self.subTest(transpose=False):
|
||||
a_npy = np.ones(
|
||||
(batch_size, out_vec_len, in_vec_len),
|
||||
dtype=np_dtype,
|
||||
)
|
||||
b_npy = np.ones(
|
||||
(batch_size, in_vec_len, 1), dtype=np_dtype
|
||||
)
|
||||
for i in range(batch_size):
|
||||
b_npy[i] *= i + 1.0
|
||||
|
||||
a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])
|
||||
c_npy = a_npy @ b_npy
|
||||
c_mlx = a_mlx @ b_mlx
|
||||
|
||||
self.assertListEqual(
|
||||
list(c_npy.shape), list(c_mlx.shape)
|
||||
)
|
||||
self.assertTrue(np.array_equal(c_mlx, c_npy))
|
||||
|
||||
# Vector matrix
|
||||
with self.subTest(transpose=True):
|
||||
a_npy = np.ones(
|
||||
(batch_size, out_vec_len, in_vec_len),
|
||||
dtype=np_dtype,
|
||||
)
|
||||
b_npy = np.ones(
|
||||
(batch_size, 1, out_vec_len), dtype=np_dtype
|
||||
)
|
||||
for i in range(batch_size):
|
||||
b_npy[i] *= i + 1.0
|
||||
|
||||
a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])
|
||||
c_npy = b_npy @ a_npy
|
||||
c_mlx = b_mlx @ a_mlx
|
||||
|
||||
self.assertListEqual(
|
||||
list(c_npy.shape), list(c_mlx.shape)
|
||||
)
|
||||
self.assertTrue(np.array_equal(c_mlx, c_npy))
python/tests/test_conv.py (new file, 445 lines)
@@ -0,0 +1,445 @@
|
||||
import unittest
|
||||
from itertools import permutations
|
||||
|
||||
import math
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
import mlx_tests
|
||||
|
||||
try:
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
has_torch = True
|
||||
except ImportError as e:
|
||||
has_torch = False
|
||||
|
||||
|
||||
class TestConv(mlx_tests.MLXTestCase):
|
||||
def test_numpy_conv(self):
|
||||
for dtype in (
|
||||
"float16",
|
||||
"float32",
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
for M, N, mode in (
|
||||
(1, 1, "full"),
|
||||
(25, 5, "full"),
|
||||
(24, 5, "same"),
|
||||
(24, 4, "same"),
|
||||
(24, 4, "valid"),
|
||||
(4, 24, "full"),
|
||||
(5, 25, "same"),
|
||||
(4, 25, "valid"),
|
||||
):
|
||||
with self.subTest(dtype=dtype, M=M, N=N, mode=mode):
|
||||
atol = 1e-6 if dtype == "float32" else 1e-5
|
||||
a_np = np.random.rand(M).astype(np_dtype)
|
||||
v_np = np.random.rand(N).astype(np_dtype)
|
||||
a_mx = mx.array(a_np)
|
||||
v_mx = mx.array(v_np)
|
||||
|
||||
c_np = np.convolve(a_np, v_np, mode=mode)
|
||||
c_mx = mx.convolve(a_mx, v_mx, mode=mode)
|
||||
|
||||
self.assertListEqual(list(c_mx.shape), list(c_np.shape))
|
||||
self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_1D(self):
|
||||
def run_conv1D(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
iH,
|
||||
kH,
|
||||
stride,
|
||||
padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
iH=iH,
|
||||
kH=kH,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_pt, wt_pt = map(
|
||||
lambda x: torch.from_numpy(x.transpose(0, 2, 1)), (in_np, wt_np)
|
||||
)
|
||||
|
||||
out_mx = mx.conv1d(
|
||||
in_mx,
|
||||
wt_mx,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.conv1d(
|
||||
in_pt,
|
||||
wt_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.transpose(out_pt, 2, 1)
|
||||
|
||||
self.assertListEqual(list(out_pt.shape), out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in (
|
||||
(1, 1, 1),
|
||||
(1, 6, 1),
|
||||
(1, 1, 6),
|
||||
(4, 32, 64),
|
||||
):
|
||||
for iH, kH, stride, padding in (
|
||||
(1, 1, 1, 0),
|
||||
(3, 3, 1, 0),
|
||||
(31, 5, 5, 2),
|
||||
):
|
||||
run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype)
|
||||
|
||||
# Strided inputs tests
|
||||
for tpose_in, tpose_wt in (
|
||||
((0, 2, 1), (0, 1, 2)),
|
||||
((0, 2, 1), (0, 2, 1)),
|
||||
):
|
||||
with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
|
||||
in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
|
||||
wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_mx_t = mx.transpose(in_mx, tpose_in)
|
||||
wt_mx_t = mx.transpose(wt_mx, tpose_wt)
|
||||
out_mx = mx.conv1d(in_mx_t, wt_mx_t)
|
||||
|
||||
in_pt, wt_pt = map(
|
||||
lambda x: torch.from_numpy(x.transpose(0, 2, 1)),
|
||||
(in_np.transpose(tpose_in), wt_np.transpose(tpose_wt)),
|
||||
)
|
||||
|
||||
out_pt = torch.conv1d(in_pt, wt_pt)
|
||||
out_pt = torch.transpose(out_pt, 2, 1)
|
||||
|
||||
self.assertListEqual(list(out_pt.shape), out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_1D_grad(self):
|
||||
def run_conv1D_grad(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
iH,
|
||||
kH,
|
||||
stride,
|
||||
padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
iH=iH,
|
||||
kH=kH,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
oH = 1 + ((iH + 2 * padding - dilation * (kH - 1) - 1) // stride)
|
||||
|
||||
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
|
||||
ct_np = np.random.normal(0, 1.0 / C, (N, oH, O)).astype(np_dtype)
|
||||
|
||||
in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
|
||||
in_pt, wt_pt, ct_pt = map(
|
||||
lambda x: torch.from_numpy(x.transpose(0, 2, 1)),
|
||||
(in_np, wt_np, ct_np),
|
||||
)
|
||||
|
||||
def f(a, b):
|
||||
return mx.conv1d(
|
||||
a,
|
||||
b,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
_, outs_mx = mx.vjp(
|
||||
f,
|
||||
[
|
||||
in_mx,
|
||||
wt_mx,
|
||||
],
|
||||
[
|
||||
ct_mx,
|
||||
],
|
||||
)
|
||||
pt_grad_in = F.grad.conv1d_input(
|
||||
in_pt.shape,
|
||||
wt_pt,
|
||||
ct_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
pt_grad_wt = F.grad.conv1d_weight(
|
||||
in_pt,
|
||||
wt_pt.shape,
|
||||
ct_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
pt_grad_in = torch.transpose(pt_grad_in, 2, 1).numpy()
|
||||
pt_grad_wt = torch.transpose(pt_grad_wt, 2, 1).numpy()
|
||||
|
||||
mx_grad_in, mx_grad_wt = outs_mx
|
||||
|
||||
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
|
||||
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
|
||||
|
||||
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
|
||||
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in (
|
||||
(1, 1, 1),
|
||||
(1, 6, 1),
|
||||
(1, 1, 6),
|
||||
(4, 32, 64),
|
||||
):
|
||||
for iH, kH, stride, padding in (
|
||||
(1, 1, 1, 0),
|
||||
(3, 3, 1, 0),
|
||||
(31, 5, 5, 2),
|
||||
):
|
||||
run_conv1D_grad(N, C, O, iH, kH, stride, padding, dtype=dtype)
|
||||
|
||||
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_2D(self):
        def run_conv2D(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                scale = 1.0 / math.sqrt(kH * kW * C)
                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt, wt_pt = map(
                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
                    (in_np, wt_np),
                )

                out_mx = mx.conv2d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv2d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)

                self.assertListEqual(list(out_pt.shape), list(out_mx.shape))
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1), (1, 1), (1, 1), (0, 0)),
                    ((3, 3), (3, 1), (1, 1), (0, 0)),
                    ((31, 31), (5, 5), (5, 5), (2, 2)),
                ):
                    run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype)

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_2D_grad(self):
        def run_conv2D_grad(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                scale = 1.0 / math.sqrt(kH * kW * C)

                oH = 1 + (
                    (iH + 2 * padding[0] - dilation[0] * (kH - 1) - 1) // stride[0]
                )
                oW = 1 + (
                    (iW + 2 * padding[1] - dilation[1] * (kW - 1) - 1) // stride[1]
                )

                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)
                ct_np = np.random.normal(0.0, scale, (N, oH, oW, O)).astype(np_dtype)

                in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
                in_pt, wt_pt, ct_pt = map(
                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
                    (in_np, wt_np, ct_np),
                )

                def f(a, b):
                    return mx.conv2d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(
                    f,
                    [
                        in_mx,
                        wt_mx,
                    ],
                    [
                        ct_mx,
                    ],
                )
                # Gradients of the 2-D convolution w.r.t. input and weight,
                # computed with torch's conv2d gradient helpers for comparison.
                pt_grad_in = F.grad.conv2d_input(
                    in_pt.shape,
                    wt_pt,
                    ct_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                pt_grad_wt = F.grad.conv2d_weight(
                    in_pt,
                    wt_pt.shape,
                    ct_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 1)).numpy()
                pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 1)).numpy()

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
                self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
                self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1), (1, 1), (1, 1), (0, 0)),
                    ((3, 3), (3, 1), (1, 1), (0, 0)),
                    ((31, 31), (5, 5), (5, 5), (2, 2)),
                ):
                    run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)

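# Output-size arithmetic used by the gradient tests above (a sketch of the
# standard convolution formula, not part of the original test code):
#
#   out_len = 1 + (in_len + 2 * padding - dilation * (kernel - 1) - 1) // stride
#
# e.g. in_len=31, kernel=5, stride=5, padding=2 gives out_len = 7, matching the
# (31, 5, 5, 2) cases exercised above.
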
if __name__ == "__main__":
    unittest.main()
157
python/tests/test_load.py
Normal file
@@ -0,0 +1,157 @@
import unittest
import os
import mlx.core as mx
import numpy as np
import tempfile

import mlx_tests


class TestLoad(mlx_tests.MLXTestCase):
    dtypes = [
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "int8",
        "int16",
        "int32",
        "int64",
        "float32",
        "float16",
        "complex64",
    ]

    @classmethod
    def setUpClass(cls):
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name

    @classmethod
    def tearDownClass(cls):
        cls.test_dir_fid.cleanup()

    def test_save_and_load(self):

        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(self.test_dir, f"mlx_{dt}_{i}.npy")
                        save_file_npy = os.path.join(self.test_dir, f"npy_{dt}_{i}.npy")

                        save_arr = np.random.uniform(0.0, 32.0, size=shape)
                        save_arr_npy = save_arr.astype(getattr(np, dt))
                        save_arr_mlx = mx.array(save_arr_npy)

                        mx.save(save_file_mlx, save_arr_mlx)
                        np.save(save_file_npy, save_arr_npy)

                        # Load array saved by mlx as mlx array
                        load_arr_mlx_mlx = mx.load(save_file_mlx)
                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))

                        # Load array saved by numpy as mlx array
                        load_arr_npy_mlx = mx.load(save_file_npy)
                        self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))

                        # Load array saved by mlx as numpy array
                        load_arr_mlx_npy = np.load(save_file_mlx)
                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))

    def test_save_and_load_fs(self):

        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.npy"
                        )
                        save_file_npy = os.path.join(
                            self.test_dir, f"npy_{dt}_{i}_fs.npy"
                        )

                        save_arr = np.random.uniform(0.0, 32.0, size=shape)
                        save_arr_npy = save_arr.astype(getattr(np, dt))
                        save_arr_mlx = mx.array(save_arr_npy)

                        with open(save_file_mlx, "wb") as f:
                            mx.save(f, save_arr_mlx)

                        np.save(save_file_npy, save_arr_npy)

                        # Load array saved by mlx as mlx array
                        with open(save_file_mlx, "rb") as f:
                            load_arr_mlx_mlx = mx.load(f)
                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))

                        # Load array saved by numpy as mlx array
                        with open(save_file_npy, "rb") as f:
                            load_arr_npy_mlx = mx.load(f)
                        self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))

                        # Load array saved by mlx as numpy array
                        load_arr_mlx_npy = np.load(save_file_mlx)
                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))

    def test_savez_and_loadz(self):
        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                shapes = [(6,), (6, 6), (4, 1, 3, 1, 2)]
                save_file_mlx_uncomp = os.path.join(
                    self.test_dir, f"mlx_{dt}_uncomp.npz"
                )
                save_file_npy_uncomp = os.path.join(
                    self.test_dir, f"npy_{dt}_uncomp.npz"
                )
                save_file_mlx_comp = os.path.join(self.test_dir, f"mlx_{dt}_comp.npz")
                save_file_npy_comp = os.path.join(self.test_dir, f"npy_{dt}_comp.npz")

                # Make a dictionary of multiple arrays
                save_arrs_npy = {
                    f"save_arr_{i}": np.random.uniform(
                        0.0, 32.0, size=shapes[i]
                    ).astype(getattr(np, dt))
                    for i in range(len(shapes))
                }
                save_arrs_mlx = {k: mx.array(v) for k, v in save_arrs_npy.items()}

                # Save as npz files
                np.savez(save_file_npy_uncomp, **save_arrs_npy)
                mx.savez(save_file_mlx_uncomp, **save_arrs_mlx)
                np.savez_compressed(save_file_npy_comp, **save_arrs_npy)
                mx.savez_compressed(save_file_mlx_comp, **save_arrs_mlx)

                for save_file_npy, save_file_mlx in (
                    (save_file_npy_uncomp, save_file_mlx_uncomp),
                    (save_file_npy_comp, save_file_mlx_comp),
                ):

                    # Load array saved by mlx as mlx array
                    load_arr_mlx_mlx = mx.load(save_file_mlx)
                    for k, v in load_arr_mlx_mlx.items():
                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))

                    # Load arrays saved by numpy as mlx arrays
                    load_arr_npy_mlx = mx.load(save_file_npy)
                    for k, v in load_arr_npy_mlx.items():
                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))

                    # Load array saved by mlx as numpy array
                    load_arr_mlx_npy = np.load(save_file_mlx)
                    for k, v in load_arr_mlx_npy.items():
                        self.assertTrue(np.array_equal(save_arrs_npy[k], v))


if __name__ == "__main__":
    unittest.main()
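
# Round trips exercised above, in short (illustrative sketch only):
#
#   mx.save("arr.npy", mx.array([1.0, 2.0]))                 # single array -> .npy
#   arr = mx.load("arr.npy")
#   mx.savez("arrs.npz", a=mx.zeros((2,)), b=mx.ones((3,)))  # several arrays -> .npz
#   arrs = mx.load("arrs.npz")                               # dict of mlx arrays
#
# Files written by np.save / np.savez load the same way, and .npy files written
# by mx.save can be read back with np.load.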
231
python/tests/test_nn.py
Normal file
@@ -0,0 +1,231 @@
import unittest

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten
import numpy as np
import os
import tempfile

import mlx_tests


class TestNN(mlx_tests.MLXTestCase):
    def test_linear(self):
        inputs = mx.zeros((10, 4))
        layer = nn.Linear(input_dims=4, output_dims=8)
        outputs = layer(inputs)
        self.assertEqual(tuple(outputs.shape), (10, 8))

    def test_cross_entropy(self):
        logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
        targets = mx.array([0, 1])
        losses = nn.losses.cross_entropy(logits, targets)
        self.assertTrue(mx.array_equal(losses, mx.zeros((2,))))

    def test_gelu(self):
        inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]

        # From: jax.nn.gelu(np.array(inputs), approximate=False)
        expected = np.array(
            [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
        )

        out = nn.GELU()(mx.array(inputs))
        self.assertTrue(np.allclose(out, expected))

        # Crudely check the approximations
        x = mx.arange(-6.0, 6.0, 12 / 100)
        y = nn.gelu(x)
        y_hat1 = nn.gelu_approx(x)
        y_hat2 = nn.gelu_fast_approx(x)
        self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
        self.assertLess(mx.abs(y - y_hat2).max(), 0.02)

    def test_group_norm(self):
        x = mx.arange(100, dtype=mx.float32)
        x = x.reshape(1, 10, 10, 1)
        x = mx.broadcast_to(x, (2, 10, 10, 4))
        x = mx.concatenate([x, 0.5 * x], axis=-1)

        # Group norm in groups last mode
        g = nn.GroupNorm(2, 8)
        y = g(x)
        means = y.reshape(2, -1, 2).mean(axis=1)
        var = y.reshape(2, -1, 2).var(axis=1)
        self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
        g.weight = g.weight * 2
        g.bias = g.bias + 3
        y = g(x)
        means = y.reshape(2, -1, 2).mean(axis=1)
        var = y.reshape(2, -1, 2).var(axis=1)
        self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))

        # Group norm in groups first mode
        g = nn.GroupNorm(2, 8, pytorch_compatible=True)
        y = g(x)
        means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
        var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
        self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
        g.weight = g.weight * 2
        g.bias = g.bias + 3
        y = g(x)
        means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
        var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
        self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))

    def test_conv1d(self):
        N = 5
        L = 12
        ks = 3
        C_in = 2
        C_out = 4
        x = mx.ones((N, L, C_in))
        c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks)
        c.weight = mx.ones_like(c.weight)
        y = c(x)
        self.assertEqual(y.shape, [N, L - ks + 1, C_out])
        self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32)))

        c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2)
        y = c(x)
        self.assertEqual(y.shape, [N, (L - ks + 1) // 2, C_out])
        self.assertTrue("bias" in c.parameters())

        c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
        self.assertTrue("bias" not in c.parameters())

    def test_conv2d(self):
        x = mx.ones((4, 8, 8, 3))
        c = nn.Conv2d(3, 1, 8)
        y = c(x)
        self.assertEqual(y.shape, [4, 1, 1, 1])
        c.weight = mx.ones_like(c.weight) / 8 / 8 / 3
        y = c(x)
        self.assertTrue(np.allclose(y[:, 0, 0, 0], x.mean(axis=(1, 2, 3))))

        # 3x3 conv no padding stride 1
        c = nn.Conv2d(3, 8, 3)
        y = c(x)
        self.assertEqual(y.shape, [4, 6, 6, 8])
        self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)

        # 3x3 conv padding 1 stride 1
        c = nn.Conv2d(3, 8, 3, padding=1)
        y = c(x)
        self.assertEqual(y.shape, [4, 8, 8, 8])
        self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4)
        self.assertLess(
            mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(),
            1e-4,
        )
        self.assertLess(
            mx.abs(y[:, 7, 7] - c.weight[:, :-1, :-1].sum(axis=(1, 2, 3))).max(),
            1e-4,
        )
        self.assertLess(
            mx.abs(y[:, 1:7, 7] - c.weight[:, :, :-1].sum(axis=(1, 2, 3))).max(),
            1e-4,
        )
        self.assertLess(
            mx.abs(y[:, 7, 1:7] - c.weight[:, :-1, :].sum(axis=(1, 2, 3))).max(),
            1e-4,
        )

        # 3x3 conv no padding stride 2
        c = nn.Conv2d(3, 8, 3, padding=0, stride=2)
        y = c(x)
        self.assertEqual(y.shape, [4, 3, 3, 8])
        self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)

    def test_sequential(self):
        x = mx.ones((10, 2))
        m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))
        y = m(x)
        self.assertEqual(y.shape, [10, 1])
        params = m.parameters()
        self.assertTrue("layers" in params)
        self.assertEqual(len(params["layers"]), 3)
        self.assertTrue("weight" in params["layers"][0])
        self.assertEqual(len(params["layers"][1]), 0)
        self.assertTrue("weight" in params["layers"][2])

        m.layers[1] = nn.relu
        y2 = m(x)
        self.assertTrue(mx.array_equal(y, y2))

    def test_module_utilities(self):
        m = nn.Sequential(
            nn.Sequential(nn.Linear(2, 10), nn.relu),
            nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
            nn.Linear(10, 1),
            mx.sigmoid,
        )

        children = m.children()
        self.assertTrue(isinstance(children, dict))
        self.assertEqual(len(children), 1)
        self.assertTrue(isinstance(children["layers"], list))
        self.assertEqual(len(children["layers"]), 4)
        self.assertEqual(children["layers"][3], {})
        flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
        self.assertEqual(len(flat_children), 3)

        leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
        self.assertEqual(len(leaves), 4)
        self.assertEqual(leaves[0][0], "layers.0.layers.0")
        self.assertEqual(leaves[1][0], "layers.1.layers.0")
        self.assertEqual(leaves[2][0], "layers.1.layers.1")
        self.assertEqual(leaves[3][0], "layers.2")
        self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
        self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
        self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
        self.assertTrue(leaves[3][1] is m.layers[2])

        m.eval()

        def assert_not_training(k, m):
            self.assertFalse(m.training)

        m.apply_to_modules(assert_not_training)

        m.train()

        def assert_training(k, m):
            self.assertTrue(m.training)

        m.apply_to_modules(assert_training)

    def test_sin_pe(self):
        m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
        x = mx.arange(10)
        y = m(x)

        self.assertEqual(y.shape, [10, 16])
        similarities = y @ y.T
        self.assertLess(
            mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
        )

    def test_io(self):
        def make_model():
            return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))

        m = make_model()
        tdir = tempfile.TemporaryDirectory()
        file = os.path.join(tdir.name, "model.npz")
        m.save_weights(file)
        m_load = make_model()
        m_load.load_weights(file)
        tdir.cleanup()

        eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
        self.assertTrue(all(tree_flatten(eq_tree)))


if __name__ == "__main__":
    unittest.main()
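
# Shape conventions checked above (illustrative sketch): the nn convolution
# layers are channels-last, e.g.
#
#   c = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
#   y = c(mx.ones((4, 8, 8, 3)))   # NHWC in -> NHWC out, shape (4, 8, 8, 8)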
29
python/tests/test_optimizers.py
Normal file
@@ -0,0 +1,29 @@
import unittest

import mlx.core as mx
import mlx.optimizers as opt
import mlx.utils

import mlx_tests


class TestOptimizers(mlx_tests.MLXTestCase):
    def test_optimizers(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)

        for optim in [opt.SGD(0.1), opt.Adam(0.1)]:
            update = optim.apply_gradients(grads, params)
            mx.eval(update)
            equal_shape = mlx.utils.tree_map(
                lambda x, y: x.shape == y.shape, params, update
            )
            all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
            self.assertTrue(all_equal)


if __name__ == "__main__":
    unittest.main()
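
# Typical usage of the pattern tested above (a sketch; assumes a loss_fn taking
# the parameter tree as its first argument, which is not part of this file):
#
#   grads = mx.grad(loss_fn)(params, batch)
#   params = optimizer.apply_gradients(grads, params)   # returns updated params
#   mx.eval(params)                                      # force the lazy update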
192
python/tests/test_random.py
Normal file
@@ -0,0 +1,192 @@
import unittest

import mlx.core as mx

import mlx_tests


class TestRandom(mlx_tests.MLXTestCase):
    def test_global_rng(self):
        mx.random.seed(3)
        a = mx.random.uniform()
        b = mx.random.uniform()

        mx.random.seed(3)
        x = mx.random.uniform()
        y = mx.random.uniform()

        self.assertEqual(a.item(), x.item())
        self.assertEqual(y.item(), b.item())

    def test_key(self):
        k1 = mx.random.key(0)
        k2 = mx.random.key(0)
        self.assertTrue(mx.array_equal(k1, k2))

        k2 = mx.random.key(1)
        self.assertFalse(mx.array_equal(k1, k2))

    def test_key_split(self):
        key = mx.random.key(0)

        k1, k2 = mx.random.split(key)
        self.assertFalse(mx.array_equal(k1, k2))

        r1, r2 = mx.random.split(key)
        self.assertTrue(mx.array_equal(k1, r1))
        self.assertTrue(mx.array_equal(k2, r2))

        keys = mx.random.split(key, 10)
        self.assertEqual(keys.shape, [10, 2])

    def test_uniform(self):
        key = mx.random.key(0)
        a = mx.random.uniform(key=key)
        self.assertEqual(a.shape, [])
        self.assertEqual(a.dtype, mx.float32)

        b = mx.random.uniform(key=key)
        self.assertEqual(a.item(), b.item())

        a = mx.random.uniform(shape=(2, 3))
        self.assertEqual(a.shape, [2, 3])

        a = mx.random.uniform(shape=(1000,), low=-1, high=5)
        self.assertTrue(mx.all(a > -1).item() and mx.all(a < 5).item())

        a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
        self.assertTrue(mx.all(a > -1).item() and mx.all(a < 5).item())

    def test_normal(self):
        key = mx.random.key(0)
        a = mx.random.normal(key=key)
        self.assertEqual(a.shape, [])
        self.assertEqual(a.dtype, mx.float32)

        b = mx.random.normal(key=key)
        self.assertEqual(a.item(), b.item())

        a = mx.random.normal(shape=(2, 3))
        self.assertEqual(a.shape, [2, 3])

        # Generate in float16 or bfloat16
        for t in [mx.float16, mx.bfloat16]:
            a = mx.random.normal(dtype=t)
            self.assertEqual(a.dtype, t)

    def test_randint(self):
        a = mx.random.randint(0, 1, [])
        self.assertEqual(a.shape, [])
        self.assertEqual(a.dtype, mx.int32)

        shape = [88]
        low = mx.array(3)
        high = mx.array(15)

        key = mx.random.key(0)
        a = mx.random.randint(low, high, shape, key=key)
        self.assertEqual(a.shape, shape)
        self.assertEqual(a.dtype, mx.int32)

        # Check using the same key yields the same value
        b = mx.random.randint(low, high, shape, key=key)
        self.assertListEqual(a.tolist(), b.tolist())

        shape = [3, 4]
        low = mx.reshape(mx.array([0] * 3), [3, 1])
        high = mx.reshape(mx.array([12, 13, 14, 15]), [1, 4])

        a = mx.random.randint(low, high, shape)
        self.assertEqual(a.shape, shape)

        a = mx.random.randint(-10, 10, [1000, 1000])
        self.assertTrue(mx.all(-10 <= a).item() and mx.all(a < 10).item())

        a = mx.random.randint(10, -10, [1000, 1000])
        self.assertTrue(mx.all(a == 10).item())

    def test_bernoulli(self):
        a = mx.random.bernoulli()
        self.assertEqual(a.shape, [])
        self.assertEqual(a.dtype, mx.bool_)

        a = mx.random.bernoulli(mx.array(0.5), [5])
        self.assertEqual(a.shape, [5])

        a = mx.random.bernoulli(mx.array([2.0, -2.0]))
        self.assertEqual(a.tolist(), [True, False])
        self.assertEqual(a.shape, [2])

        p = mx.array([0.1, 0.2, 0.3])
        mx.reshape(p, [1, 3])
        x = mx.random.bernoulli(p, [4, 3])
        self.assertEqual(x.shape, [4, 3])

        with self.assertRaises(ValueError):
            mx.random.bernoulli(p, [2])  # Bad shape

        with self.assertRaises(ValueError):
            mx.random.bernoulli(0, [2])  # Bad type

    def test_truncated_normal(self):
        a = mx.random.truncated_normal(-2.0, 2.0)
        self.assertEqual(a.size, 1)
        self.assertEqual(a.dtype, mx.float32)

        a = mx.random.truncated_normal(mx.array([]), mx.array([]))
        self.assertEqual(a.dtype, mx.float32)
        self.assertEqual(a.size, 0)

        lower = mx.reshape(mx.array([-2.0, 0.0]), [1, 2])
        upper = mx.reshape(mx.array([0.0, 1.0, 2.0]), [3, 1])
        a = mx.random.truncated_normal(lower, upper)

        self.assertEqual(a.shape, [3, 2])
        self.assertTrue(mx.all(lower <= a).item() and mx.all(a <= upper).item())

        a = mx.random.truncated_normal(2.0, -2.0)
        self.assertTrue(mx.all(a == 2.0).item())

        a = mx.random.truncated_normal(-3.0, 3.0, [542, 399])
        self.assertEqual(a.shape, [542, 399])

        lower = mx.array([-2.0, -1.0])
        higher = mx.array([1.0, 2.0, 3.0])
        with self.assertRaises(ValueError):
            mx.random.truncated_normal(lower, higher)  # Bad shape

    def test_gumbel(self):
        samples = mx.random.gumbel(shape=(100, 100))
        self.assertEqual(samples.shape, [100, 100])
        self.assertEqual(samples.dtype, mx.float32)
        mean = 0.5772
        # Std deviation of the sample mean is small (<0.02),
        # so this test is pretty conservative
        self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2)

    def test_categorical(self):
        logits = mx.zeros((10, 20))
        self.assertEqual(mx.random.categorical(logits, -1).shape, [10])
        self.assertEqual(mx.random.categorical(logits, 0).shape, [20])
        self.assertEqual(mx.random.categorical(logits, 1).shape, [10])

        out = mx.random.categorical(logits)
        self.assertEqual(out.shape, [10])
        self.assertEqual(out.dtype, mx.uint32)
        self.assertTrue(mx.max(out).item() < 20)

        out = mx.random.categorical(logits, 0, [5, 20])
        self.assertEqual(out.shape, [5, 20])
        self.assertTrue(mx.max(out).item() < 10)

        out = mx.random.categorical(logits, 1, num_samples=7)
        self.assertEqual(out.shape, [10, 7])
        out = mx.random.categorical(logits, 0, num_samples=7)
        self.assertEqual(out.shape, [20, 7])

        with self.assertRaises(ValueError):
            mx.random.categorical(logits, shape=[10, 5], num_samples=5)


if __name__ == "__main__":
    unittest.main()
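
# Key-based sampling, as exercised above (sketch): the same key reproduces the
# same draw, and mx.random.split derives independent subkeys from a parent key.
#
#   key = mx.random.key(0)
#   k1, k2 = mx.random.split(key)
#   a = mx.random.uniform(key=k1)
#   b = mx.random.normal(shape=(2, 3), key=k2)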
26
python/tests/test_tree.py
Normal file
@@ -0,0 +1,26 @@
import unittest

import mlx.core as mx
import mlx.utils

import mlx_tests


class TestTreeUtils(mlx_tests.MLXTestCase):
    def test_tree_map(self):
        tree = {"a": 0, "b": 1, "c": 2}
        tree = mlx.utils.tree_map(lambda x: x + 1, tree)

        expected_tree = {"a": 1, "b": 2, "c": 3}
        self.assertEqual(tree, expected_tree)

    def test_tree_flatten(self):
        tree = [{"a": 1, "b": 2}, "c"]
        vals = (1, 2, "c")
        flat_tree = mlx.utils.tree_flatten(tree)
        self.assertEqual(list(zip(*flat_tree))[1], vals)
        self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)


if __name__ == "__main__":
    unittest.main()
167
python/tests/test_vmap.py
Normal file
@@ -0,0 +1,167 @@
import unittest

import mlx.core as mx

import mlx_tests


class TestVmap(mlx_tests.MLXTestCase):
    def test_basics(self):
        # Can't vmap over scalars
        with self.assertRaises(ValueError):
            mx.vmap(mx.exp)(mx.array(1.0))

        # Invalid input
        with self.assertRaises(ValueError):
            mx.vmap(mx.exp)("hello")

        # Invalid axes
        with self.assertRaises(ValueError):
            mx.vmap(mx.exp, in_axes="hello")(mx.array([0, 1]))

        with self.assertRaises(ValueError):
            mx.vmap(mx.exp, in_axes=2)(mx.array([0, 1]))

        with self.assertRaises(ValueError):
            mx.vmap(mx.exp, out_axes="hello")(mx.array([0, 1]))

        with self.assertRaises(ValueError):
            mx.vmap(mx.exp, out_axes=2)(mx.array([0, 1]))

    def test_unary(self):
        ops = [
            "abs",
            "cos",
            "erf",
            "erfinv",
            "exp",
            "log",
            "log1p",
            "log2",
            "log10",
            "logical_not",
            "negative",
            "reciprocal",
            "rsqrt",
            "sigmoid",
            "sign",
            "sin",
            "sqrt",
            "square",
        ]
ops = ["erfinv"]
|
||||
        for opname in ops:
            with self.subTest(op=opname):
                op = getattr(mx, opname)
                x = mx.arange(5)
                y = mx.vmap(op)(x)
                self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))

                x = mx.arange(8).reshape(2, 4)
                y = mx.vmap(op)(x)
                self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))

                y = mx.vmap(op, in_axes=1, out_axes=1)(x)
                self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))

    def test_binary(self):
        ops = [
            "add",
            "divide",
            "equal",
            "greater",
            "greater_equal",
            "less",
            "less_equal",
            "logaddexp",
            "maximum",
            "minimum",
            "multiply",
            "power",
            "subtract",
        ]
        for opname in ops:
            with self.subTest(op=opname):
                op = getattr(mx, opname)
                x = mx.random.uniform(shape=(5,))
                y = mx.random.uniform(shape=(5,))
                out = mx.vmap(op)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y)))

                x = mx.random.uniform(shape=(2, 4))
                y = mx.random.uniform(shape=(2, 4))
                out = mx.vmap(op)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y)))

                out = mx.vmap(op, in_axes=(0, 0), out_axes=0)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y)))

                y = mx.random.uniform(shape=(4, 2))
                out = mx.vmap(op, in_axes=(0, 1), out_axes=0)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y.T)))

                out = mx.vmap(op, in_axes=(0, 1), out_axes=1)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y.T).T))

    def test_tree(self):
        def my_fun(tree):
            return (tree["a"] + tree["b"][0]) * tree["b"][1]

        tree = {
            "a": mx.random.uniform(shape=(2, 4)),
            "b": (
                mx.random.uniform(shape=(2, 4)),
                mx.random.uniform(shape=(2, 4)),
            ),
        }
        out = mx.vmap(my_fun)(tree)
        expected = my_fun(tree)
        self.assertTrue(mx.array_equal(out, my_fun(tree)))

        with self.assertRaises(ValueError):
            mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)

        with self.assertRaises(ValueError):
            mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree)

        out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree)
        self.assertTrue(mx.array_equal(out, my_fun(tree)))

        out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree)
        self.assertTrue(mx.array_equal(out, my_fun(tree)))

        tree = {
            "a": mx.random.uniform(shape=(2, 4)),
            "b": (
                mx.random.uniform(shape=(4, 2)),
                mx.random.uniform(shape=(4, 2)),
            ),
        }
        out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree)
        expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
        self.assertTrue(mx.array_equal(out, expected))

        def my_fun(x, y):
            return {"a": x + y, "b": x * y}

        x = mx.random.uniform(shape=(2, 4))
        y = mx.random.uniform(shape=(2, 4))
        out = mx.vmap(my_fun, in_axes=0, out_axes=0)(x, y)
        expected = my_fun(x, y)
        self.assertTrue(mx.array_equal(out["a"], expected["a"]))
        self.assertTrue(mx.array_equal(out["b"], expected["b"]))

        with self.assertRaises(ValueError):
            mx.vmap(my_fun, in_axes=0, out_axes=(0, 1))(x, y)

        with self.assertRaises(ValueError):
            mx.vmap(my_fun, in_axes=0, out_axes={"a": 0, "c": 1})(x, y)

        out = mx.vmap(my_fun, in_axes=0, out_axes={"a": 1, "b": 0})(x, y)
        expected = my_fun(x, y)
        self.assertTrue(mx.array_equal(out["a"].T, expected["a"]))
        self.assertTrue(mx.array_equal(out["b"], expected["b"]))


if __name__ == "__main__":
    unittest.main()
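
# in_axes / out_axes semantics exercised above (sketch): each entry selects the
# mapped axis for the corresponding input or output, so mapping x over axis 0
# and y over axis 1 is equivalent to applying the op to x and y.T, e.g.
#
#   f = mx.vmap(mx.add, in_axes=(0, 1), out_axes=0)
#   x = mx.ones((2, 4))
#   y = mx.ones((4, 2))
#   z = f(x, y)        # same as mx.add(x, y.T); z.shape == (2, 4)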