awni's commit files

This commit is contained in:
Awni Hannun
2023-11-29 10:30:41 -08:00
parent e411fcae68
commit 8ca7f9e8e9
130 changed files with 30159 additions and 0 deletions

38
tests/CMakeLists.txt Normal file
View File

@@ -0,0 +1,38 @@
# Build script for the `tests` executable.
# NOTE(review): FetchContent_* requires include(FetchContent); that include is
# not visible here and is presumably done by a parent CMakeLists - confirm.

# Fetch the doctest framework at a fixed commit hash.
FetchContent_Declare(
doctest
GIT_REPOSITORY "https://github.com/onqtam/doctest"
GIT_TAG "b7c21ec5ceeadb4951b00396fc1e4642dd347e5f"
)
FetchContent_MakeAvailable(doctest)
# The executable is created from tests.cpp; the per-area test files are
# attached below with target_sources.
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
# Metal tests are only added when MLX_BUILD_METAL is set. When it is not,
# METAL_TEST_SOURCES is not set by this script and the reference in
# target_sources expands to nothing.
if (MLX_BUILD_METAL)
set(
METAL_TEST_SOURCES
metal_tests.cpp
)
endif()
target_sources(tests PRIVATE
allocator_tests.cpp
array_tests.cpp
arg_reduce_tests.cpp
autograd_tests.cpp
blas_tests.cpp
creations_tests.cpp
device_tests.cpp
eval_tests.cpp
fft_tests.cpp
graph_optimize_tests.cpp
load_tests.cpp
ops_tests.cpp
random_tests.cpp
scheduler_tests.cpp
utils_tests.cpp
vmap_tests.cpp
${METAL_TEST_SOURCES}
)
# Link against the mlx library target and doctest.
target_link_libraries(tests PRIVATE mlx doctest)
# Register the executable with CTest under the name "tests".
add_test(NAME tests COMMAND tests)

205
tests/arg_reduce_tests.cpp Normal file
View File

@@ -0,0 +1,205 @@
#include <iostream>
#include "doctest/doctest.h"
#include "mlx/mlx.h"
#include "mlx/primitives.h"
using namespace mlx::core;
// Runs an ArgReduce primitive over `x` on the default stream of device `d`
// and checks every output element against `expected_output`.
//
//   d               device whose default stream executes the reduction
//   x               input array
//   r               reduction kind (e.g. ArgReduce::ArgMin / ArgReduce::ArgMax)
//   out_shape       shape given to the uint32 output array
//   axis            axis passed to the ArgReduce primitive
//   expected_output expected indices, one per output element, in the order
//                   they appear in the output buffer
void test_arg_reduce_small(
    Device d,
    const array& x,
    ArgReduce::ReduceType r,
    std::vector<int> out_shape,
    int axis,
    std::vector<int> expected_output) {
  auto s = default_stream(d);
  auto y =
      array(out_shape, uint32, std::make_unique<ArgReduce>(s, r, axis), {x});
  y.eval();

  // Stop here if the caller supplied the wrong number of expected values;
  // otherwise the loop below would read past the end of `expected_output`.
  REQUIRE_EQ(y.size(), expected_output.size());

  const uint32_t* ydata = y.data<uint32_t>();
  for (size_t i = 0; i < expected_output.size(); ++i) {
    CHECK_EQ(expected_output[i], ydata[i]);
  }
}
// Runs the same ArgReduce on the default CPU stream and on the default GPU
// stream and checks that the two outputs are equal.
//
//   x          input array
//   r          reduction kind
//   out_shape  shape given to both uint32 output arrays
//   axis       axis passed to the ArgReduce primitive
void test_arg_reduce_against_cpu(
    const array& x,
    ArgReduce::ReduceType r,
    std::vector<int> out_shape,
    int axis) {
  // Builds the reduction output for a single device.
  auto reduce_on = [&](Device device) {
    return array(
        out_shape,
        uint32,
        std::make_unique<ArgReduce>(default_stream(device), r, axis),
        {x});
  };

  auto cpu_result = reduce_on(Device::cpu);
  auto gpu_result = reduce_on(Device::gpu);
  cpu_result.eval();
  gpu_result.eval();

  CHECK(array_equal(cpu_result, gpu_result).item<bool>());
}
// Checks ArgMin and ArgMax along every axis of a small {2, 3, 4} int array,
// first on the CPU and then, when Metal is available, on the GPU.
TEST_CASE("test arg reduce small") {
  auto x = array(
      {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
       0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
      {2, 3, 4});
  x.eval();

  // The expected indices are the same for both devices, so one helper runs
  // the six reductions for whichever device is requested.
  auto check_on = [&x](Device d) {
    test_arg_reduce_small(
        d, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3});
    test_arg_reduce_small(
        d, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 2, 0, 1, 1, 2});
    test_arg_reduce_small(
        d,
        x,
        ArgReduce::ArgMin,
        {3, 4},
        0,
        {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
    test_arg_reduce_small(
        d, x, ArgReduce::ArgMax, {2, 3}, 2, {3, 0, 1, 3, 0, 1});
    test_arg_reduce_small(
        d, x, ArgReduce::ArgMax, {2, 4}, 1, {1, 2, 2, 0, 1, 2, 2, 0});
    test_arg_reduce_small(
        d,
        x,
        ArgReduce::ArgMax,
        {3, 4},
        0,
        {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
  };

  check_on(Device::cpu);

  if (!metal::is_available()) {
    // MESSAGE prints immediately. INFO is only reported when a later
    // assertion in the same scope fails, so it would never be shown here.
    MESSAGE("Skipping arg reduction gpu tests");
    return;
  }

  check_on(Device::gpu);
}
// Compares GPU ArgMin/ArgMax against the CPU result on random inputs: a
// {127, 92, 55} array reduced along each axis and a 1234-element vector
// reduced to a scalar. Skipped when Metal is unavailable.
TEST_CASE("test arg reduce against cpu") {
  if (!metal::is_available()) {
    // MESSAGE prints immediately. INFO is only reported when a later
    // assertion in the same scope fails, so it would never be shown here.
    MESSAGE("Skipping arg reduction gpu tests");
    return;
  }

  auto x = random::uniform(array(0.0), array(1.0), {127, 92, 55});
  x.eval();
  for (auto r : {ArgReduce::ArgMin, ArgReduce::ArgMax}) {
    test_arg_reduce_against_cpu(x, r, {127, 92}, 2);
    test_arg_reduce_against_cpu(x, r, {127, 55}, 1);
    test_arg_reduce_against_cpu(x, r, {92, 55}, 0);
  }

  auto y = random::uniform(array(0.0), array(1.0), {1234});
  y.eval();
  for (auto r : {ArgReduce::ArgMin, ArgReduce::ArgMax}) {
    test_arg_reduce_against_cpu(y, r, {}, 0);
  }
}
// Runs an ArgReduce primitive on a fixed {2, 3, 4} input on the default
// stream of device `d` and checks every output element against
// `expected_output`.
//
// NOTE(review): despite the name, the fixture below is built from int
// literals, and no test case visible in this file calls this helper (the
// bool test case uses test_arg_reduce_small). Confirm whether it is still
// needed.
//
//   d               device whose default stream executes the reduction
//   r               reduction kind
//   out_shape       shape given to the uint32 output array
//   axis            axis passed to the ArgReduce primitive
//   expected_output expected indices, one per output element
void test_arg_reduce_small_bool(
    Device d,
    ArgReduce::ReduceType r,
    std::vector<int> out_shape,
    int axis,
    std::vector<int> expected_output) {
  auto s = default_stream(d);
  auto x = array(
      {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
       0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
      {2, 3, 4});
  x.eval();
  auto y =
      array(out_shape, uint32, std::make_unique<ArgReduce>(s, r, axis), {x});
  y.eval();

  // Stop here if the caller supplied the wrong number of expected values;
  // otherwise the loop below would read past the end of `expected_output`.
  REQUIRE_EQ(y.size(), expected_output.size());

  const uint32_t* ydata = y.data<uint32_t>();
  for (size_t i = 0; i < expected_output.size(); ++i) {
    CHECK_EQ(expected_output[i], ydata[i]);
  }
}
// Checks GPU ArgMin and ArgMax along every axis of a {2, 3, 4} bool array.
// Skipped when Metal is unavailable.
TEST_CASE("test arg reduce bool") {
  if (!metal::is_available()) {
    // MESSAGE prints immediately. INFO is only reported when a later
    // assertion in the same scope fails, so it would never be shown here.
    MESSAGE("Skipping arg reduction gpu tests");
    return;
  }

  auto x = array(
      {false, true,  true,  false, false, false, false, true,
       true,  false, true,  true,  false, true,  true,  false,
       false, false, false, true,  true,  false, true,  true},
      {2, 3, 4});
  x.eval();

  // Index of the first minimum along each axis.
  test_arg_reduce_small(
      Device::gpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 0, 1, 0, 0, 1});
  test_arg_reduce_small(
      Device::gpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 0, 0, 1, 1, 0});
  test_arg_reduce_small(
      Device::gpu,
      x,
      ArgReduce::ArgMin,
      {3, 4},
      0,
      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});

  // Index of the first maximum along each axis.
  test_arg_reduce_small(
      Device::gpu, x, ArgReduce::ArgMax, {2, 3}, 2, {1, 3, 0, 1, 3, 0});
  test_arg_reduce_small(
      Device::gpu, x, ArgReduce::ArgMax, {2, 4}, 1, {2, 0, 0, 1, 2, 0, 0, 1});
  test_arg_reduce_small(
      Device::gpu,
      x,
      ArgReduce::ArgMax,
      {3, 4},
      0,
      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
}
// argmin/argmax on a scalar yield index 0; on an empty array they throw.
TEST_CASE("test arg reduce edge cases") {
  // Scalar input
  {
    auto min_index = argmin(array(1.0));
    CHECK_EQ(min_index.item<uint32_t>(), 0);
  }
  {
    auto max_index = argmax(array(1.0));
    CHECK_EQ(max_index.item<uint32_t>(), 0);
  }

  // Empty input
  CHECK_THROWS(argmin(array({})));
  CHECK_THROWS(argmax(array({})));
}
// Checks ArgMin on a transposed view, i.e. an input whose axes were permuted
// from the order the data was created in.
TEST_CASE("test arg reduce irregular strides") {
  auto x = array(
      {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
       0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
      {2, 3, 4});
  x = transpose(x, {2, 0, 1});
  x.eval();
  test_arg_reduce_small(
      Device::cpu, x, ArgReduce::ArgMin, {4, 2}, 2, {0, 0, 1, 1, 1, 1, 2, 2});

  if (!metal::is_available()) {
    // MESSAGE prints immediately. INFO is only reported when a later
    // assertion in the same scope fails, so it would never be shown here.
    MESSAGE("Skipping arg reduction gpu tests");
    return;
  }

  // The guard above previously protected nothing; run the same reduction on
  // the GPU so the transposed-input path is covered on both devices.
  test_arg_reduce_small(
      Device::gpu, x, ArgReduce::ArgMin, {4, 2}, 2, {0, 0, 1, 1, 1, 1, 2, 2});
}

108
tests/blas_tests.cpp Normal file
View File

@@ -0,0 +1,108 @@
#include <numeric>
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
// Exercises matmul: rejected operand shapes, result shape/dtype, 1-D
// operands, transposed operands, broadcasting and batched inputs.
TEST_CASE("test matmul") {
  // matmul(lhs, rhs) must throw std::invalid_argument.
  auto expect_invalid = [](const array& lhs, const array& rhs) {
    CHECK_THROWS_AS(matmul(lhs, rhs), std::invalid_argument);
  };
  // `actual` must compare equal to `reference`.
  auto expect_equal = [](const array& actual, const array& reference) {
    CHECK(array_equal(actual, reference).item<bool>());
  };

  // Scalar left operand
  expect_invalid(array(1), array({1.0}));

  // Two single-element vectors give a float32 scalar
  {
    auto product = matmul(array({1.0}), array({1.0}));
    CHECK_EQ(product.shape(), std::vector<int>{});
    CHECK_EQ(product.size(), 1);
    CHECK_EQ(product.dtype(), float32);
    CHECK_EQ(product.item<float>(), 1.0f);
  }

  // Incompatible shapes
  expect_invalid(ones({2, 4}), ones({2}));
  expect_invalid(ones({2, 4}), ones({3, 2}));
  expect_invalid(ones({2, 4}), ones({4, 3, 2}));
  expect_invalid(ones({2}), ones({4, 2}));
  expect_invalid(ones({2, 3}), ones({4, 2}));
  expect_invalid(ones({2, 4, 3}), ones({4, 2}));

  // Plain 2-D product, then the same with an int32 left operand
  expect_equal(matmul(ones({2, 4}), ones({4, 2})), full({2, 2}, 4.0f));
  expect_equal(
      matmul(ones({2, 4}, int32), ones({4, 2}, float32)), full({2, 2}, 4.0f));

  // Check single dimensions
  expect_equal(matmul(ones({4}), ones({4, 2})), full({2}, 4.0f));
  expect_equal(matmul(ones({2, 4}), ones({4})), full({2}, 4.0f));
  expect_equal(matmul(ones({4}), ones({4})), full({}, 4.0f));

  // Test transposed arrays
  {
    auto row = array({1.0f, 1.0f, 1.0f, 1.0f}, {1, 4});
    auto col = array({1.0f, 1.0f, 1.0f, 1.0f}, {4, 1});
    expect_equal(matmul(transpose(row), transpose(col)), ones({4, 4}));
  }
  {
    auto lhs = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    auto rhs = array({1.0f, 2.0f, 1.0f, 2.0f}, {2, 2});
    expect_equal(
        matmul(transpose(lhs), rhs),
        array({4.0f, 8.0f, 6.0f, 12.0f}, {2, 2}));
    expect_equal(
        matmul(lhs, transpose(rhs)),
        array({5.0f, 5.0f, 11.0f, 11.0f}, {2, 2}));
    expect_equal(
        matmul(transpose(lhs), transpose(rhs)),
        array({7.0f, 7.0f, 10.0f, 10.0f}, {2, 2}));
  }

  // Test broadcasting for both arrays
  expect_equal(matmul(ones({5, 4, 2}), ones({2, 3})), full({5, 4, 3}, 2.0f));
  expect_equal(
      matmul(ones({5, 1, 4, 2}), ones({1, 7, 2, 3})),
      full({5, 7, 4, 3}, 2.0f));

  // Test batched matmul with transpose
  {
    auto lhs = ones({2, 2, 4});
    auto rhs = ones({2, 4, 2});
    expect_equal(
        matmul(transpose(lhs, {0, 2, 1}), transpose(rhs, {0, 2, 1})),
        full({2, 4, 4}, 2.0f));
  }
}

438
tests/metal_tests.cpp Normal file
View File

@@ -0,0 +1,438 @@
#include <array>
#include "doctest/doctest.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/mlx.h"
using namespace mlx::core;
// Dtypes iterated over by the "all types" loops in the test cases of this
// file.
static const std::array<Dtype, 5> types =
{bool_, uint32, int32, int64, float32};
// Verifies that Metal reports as available and that the GPU device object
// can be obtained.
TEST_CASE("test metal device") {
  // Make sure the device and library can load
  CHECK(metal::is_available());
  // Only obtaining the device matters here; the reference itself is unused.
  static_cast<void>(metal::device(Device::gpu));
}
// arange on the GPU must match arange on the CPU for every dtype in `types`
// except bool_, for both an integer and a fractional step.
TEST_CASE("test metal arange") {
  for (auto dtype : types) {
    if (dtype == bool_) {
      continue;
    }

    // Integer step
    {
      auto from_cpu = arange(1, 100, 2, dtype, Device::cpu);
      auto from_gpu = arange(1, 100, 2, dtype, Device::gpu);
      CHECK(array_equal(from_gpu, from_cpu, Device::cpu).item<bool>());
    }

    // Fractional step
    {
      auto from_cpu = arange(1, 5, 0.25, dtype, Device::cpu);
      auto from_gpu = arange(1, 5, 0.25, dtype, Device::gpu);
      CHECK(array_equal(from_gpu, from_cpu, Device::cpu).item<bool>());
    }
  }
}
// full/zeros/ones on the GPU: every dtype in `types`, broadcasting of the
// fill value, and the zeros/ones shorthands.
TEST_CASE("test metal full") {
  // Every dtype: GPU result must match the CPU result
  for (auto dtype : types) {
    auto from_cpu = full({4, 4}, 2, dtype, Device::cpu);
    auto from_gpu = full({4, 4}, 2, dtype, Device::gpu);
    CHECK(array_equal(from_gpu, from_cpu, Device::cpu).item<bool>());
  }

  // Check broadcasting works
  {
    // Column-shaped fill value
    auto from_column = full({2, 2}, array({3, 4}, {2, 1}), Device::gpu);
    CHECK(array_equal(from_column, array({3, 3, 4, 4}, {2, 2}), Device::cpu)
              .item<bool>());
  }
  {
    // Row-shaped fill value
    auto from_row = full({2, 2}, array({3, 4}, {1, 2}), Device::gpu);
    CHECK(array_equal(from_row, array({3, 4, 3, 4}, {2, 2}), Device::cpu)
              .item<bool>());
  }

  // Check zeros and ones
  {
    auto gpu_zeros = zeros({2, 2}, float32, Device::gpu);
    auto want = array({0.0, 0.0, 0.0, 0.0}, {2, 2});
    CHECK(array_equal(gpu_zeros, want, Device::cpu).item<bool>());
  }
  {
    auto gpu_ones = ones({2, 2}, float32, Device::gpu);
    auto want = array({1.0, 1.0, 1.0, 1.0}, {2, 2});
    CHECK(array_equal(gpu_ones, want, Device::cpu).item<bool>());
  }
}
// astype on the GPU must match astype on the CPU for every dtype in `types`,
// for a freshly created vector and for a reshaped-then-transposed view of it.
TEST_CASE("test metal astype") {
  // Casts `input` to each dtype on both devices and compares the results.
  auto check_all_types = [](const array& input) {
    for (auto dtype : types) {
      auto from_cpu = astype(input, dtype, Device::cpu);
      auto from_gpu = astype(input, dtype, Device::gpu);
      CHECK(array_equal(from_gpu, from_cpu, Device::cpu).item<bool>());
    }
  };

  array flat = array({-4, -3, -2, -1, 0, 1, 2, 3});
  // Check all types work
  check_all_types(flat);

  array permuted = transpose(reshape(flat, {2, 2, 2}), {1, 2, 0});
  check_all_types(permuted);
}
// reshape with Device::gpu must match reshape without an explicit device,
// for a freshly created vector and for a transposed view.
TEST_CASE("test metal reshape") {
  // Reshapes `input` both ways and compares the results.
  auto check_reshape = [](const array& input, const std::vector<int>& shape) {
    auto from_default = reshape(input, shape);
    auto from_gpu = reshape(input, shape, Device::gpu);
    CHECK(array_equal(from_gpu, from_default, Device::cpu).item<bool>());
  };

  array flat = array({0, 1, 2, 3, 4, 5, 6, 7});
  check_reshape(flat, {2, 2, 2});

  array permuted = transpose(reshape(flat, {2, 2, 2}), {1, 2, 0});
  check_reshape(permuted, {4, 2});
  check_reshape(permuted, {8});
}
// GPU reductions: all/any, sum/prod, and sums over a subset of axes compared
// against the CPU.
TEST_CASE("test metal reduce") {
// all/any on a scalar and on an empty bool array
{
array a(true);
CHECK_EQ(all(a, Device::gpu).item<bool>(), true);
CHECK_EQ(any(a, Device::gpu).item<bool>(), true);
a = array(std::initializer_list<bool>{});
CHECK_EQ(all(a, Device::gpu).item<bool>(), true);
CHECK_EQ(any(a, Device::gpu).item<bool>(), false);
}
// all: 33 ones, then the same with only the last element cleared
{
std::vector<int> vals(33, 1);
array a(vals.data(), {33});
CHECK_EQ(all(a, Device::gpu).item<bool>(), true);
vals[32] = 0;
// Re-wrap the buffer so the array is built from the modified values.
a = array(vals.data(), {33});
CHECK_EQ(all(a, Device::gpu).item<bool>(), false);
}
// any: 33 zeros, then the same with only the last element set
{
std::vector<int> vals(33, 0);
array a(vals.data(), {33});
CHECK_EQ(any(a, Device::gpu).item<bool>(), false);
vals[32] = 1;
a = array(vals.data(), {33});
CHECK_EQ(any(a, Device::gpu).item<bool>(), true);
}
// all/any on 1 << 14 elements, all zero and then with three elements set
{
std::vector<int> vals(1 << 14, 0);
array a(vals.data(), {1 << 14});
CHECK_EQ(all(a, Device::gpu).item<bool>(), false);
CHECK_EQ(any(a, Device::gpu).item<bool>(), false);
vals[4] = 1;
vals[999] = 1;
vals[2000] = 1;
a = array(vals.data(), {1 << 14});
CHECK_EQ(all(a, Device::gpu).item<bool>(), false);
CHECK_EQ(any(a, Device::gpu).item<bool>(), true);
}
// sum and prod
{
// The bool cases read the sum as uint32_t and the product as bool.
array a = array({true, false, true});
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 2);
CHECK_EQ(prod(a, Device::gpu).item<bool>(), false);
a = array({true, true, true});
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 3);
CHECK_EQ(prod(a, Device::gpu).item<bool>(), true);
a = full({2, 2, 2}, 2.0f);
CHECK_EQ(sum(a, Device::gpu).item<float>(), 16.0f);
CHECK_EQ(prod(a, Device::gpu).item<float>(), 256.0f);
a = full({500, 2, 2}, 1u);
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 2000);
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1u);
a = full({500, 2, 2}, 1);
CHECK_EQ(sum(a, Device::gpu).item<int32_t>(), 2000);
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
}
// reducing only some axes and irregular layouts
{
// Sum over a scalar broadcast to {2, 2, 2}.
array a(1.0f);
a = broadcast_to(a, {2, 2, 2});
CHECK_EQ(sum(a, Device::gpu).item<float>(), 8.0f);
a = ones({2, 4, 8, 16});
// One axis at a time
for (auto ax : {0, 1, 2, 3}) {
auto out_gpu = sum(a, ax, false, Device::gpu);
auto out_cpu = sum(a, ax, false, Device::cpu);
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
}
// Axis 0 together with one other axis
for (auto ax : {1, 2, 3}) {
auto out_gpu = sum(a, {0, ax}, false, Device::gpu);
auto out_cpu = sum(a, {0, ax}, false, Device::cpu);
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
}
// Axes 0 and 1 together with one other axis
for (auto ax : {2, 3}) {
auto out_gpu = sum(a, {0, 1, ax}, false, Device::gpu);
auto out_cpu = sum(a, {0, 1, ax}, false, Device::cpu);
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
}
}
}
// GPU binary operations: add across operand layouts and dtypes, then
// subtract, multiply, divide, maximum/minimum, comparisons and logaddexp.
// NOTE(review): this test uses std::isinf, std::isnan and
// std::numeric_limits, but the includes at the top of this file do not list
// <cmath> or <limits>; they are presumably pulled in transitively - consider
// including them directly.
TEST_CASE("test metal binary ops") {
// scalar-scalar
{
array a(2.0f);
array b(4.0f);
auto out = add(a, b, Device::gpu);
CHECK_EQ(out.item<float>(), 6.0f);
}
// scalar-vector and vector-scalar
{
array a(2.0f);
array b({2.0f, 4.0f, 6.0f});
auto out = add(a, b, Device::gpu);
auto expected = array({4.0f, 6.0f, 8.0f});
CHECK(array_equal(out, expected, Device::cpu).item<bool>());
out = add(b, a, Device::gpu);
CHECK(array_equal(out, expected, Device::cpu).item<bool>());
}
// vector-vector
{
array a({0.0f, 1.0f, 2.0f});
array b({3.0f, 4.0f, 5.0f});
auto out = add(a, b, Device::gpu);
auto expected = array({3.0f, 5.0f, 7.0f});
CHECK(array_equal(out, expected, Device::cpu).item<bool>());
}
// general
{
// Both operands are transposed, each with a different axis permutation.
array a({0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}, {2, 2, 2});
array b({0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}, {2, 2, 2});
a = transpose(a, {0, 2, 1});
b = transpose(b, {1, 0, 2});
auto out_gpu = add(a, b, Device::gpu);
auto out_cpu = add(a, b, Device::cpu);
auto expected =
array({0.0f, 3.0f, 5.0f, 8.0f, 6.0f, 9.0f, 11.0f, 14.0f}, {2, 2, 2});
// The GPU result is checked against the CPU result and against the
// hand-written values.
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
CHECK(array_equal(out_gpu, expected, Device::cpu).item<bool>());
}
// Check all types work
for (auto t : types) {
auto a = astype(array({0, 1, 2}), t);
auto b = astype(array({3, 4, 5}), t);
auto out_cpu = add(a, b, Device::cpu);
auto out_gpu = add(a, b, Device::gpu);
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
}
// Check subtraction
{
auto a = array({3, 2, 1});
auto b = array({1, 1, 1});
auto out = subtract(a, b, Device::gpu);
CHECK(array_equal(out, array({2, 1, 0}), Device::cpu).item<bool>());
}
// Check multiplication
{
auto a = array({1, 2, 3});
auto b = array({2, 2, 2});
auto out = multiply(a, b, Device::gpu);
CHECK(array_equal(out, array({2, 4, 6}), Device::cpu).item<bool>());
}
// Check division
{
auto x = array(1.0f);
auto y = array(1.0f);
CHECK_EQ(divide(x, y, Device::gpu).item<float>(), 1.0f);
x = array(1.0f);
y = array(0.5);
CHECK_EQ(divide(x, y, Device::gpu).item<float>(), 2.0f);
// 1 / 0 must be infinite and 0 / 0 must be NaN.
x = array(1.0f);
y = array(0.0f);
CHECK(std::isinf(divide(x, y, Device::gpu).item<float>()));
x = array(0.0f);
y = array(0.0f);
CHECK(std::isnan(divide(x, y, Device::gpu).item<float>()));
}
// Check maximum and minimum
{
auto x = array(1.0f);
auto y = array(0.0f);
CHECK_EQ(maximum(x, y, Device::gpu).item<float>(), 1.0f);
CHECK_EQ(minimum(x, y, Device::gpu).item<float>(), 0.0f);
y = array(2.0f);
CHECK_EQ(maximum(x, y, Device::gpu).item<float>(), 2.0f);
CHECK_EQ(minimum(x, y, Device::gpu).item<float>(), 1.0f);
}
// Check equal
{
array x(1.0f);
array y(1.0f);
CHECK(equal(x, y, Device::gpu).item<bool>());
x = array(0.0f);
CHECK(!equal(x, y, Device::gpu).item<bool>());
}
// Greater and less
{
// x > y
array x(1.0f);
array y(0.0f);
CHECK(greater(x, y, Device::gpu).item<bool>());
CHECK(greater_equal(x, y, Device::gpu).item<bool>());
CHECK(!greater(y, x, Device::gpu).item<bool>());
CHECK(!greater_equal(y, x, Device::gpu).item<bool>());
// x == y
y = array(1.0f);
CHECK(!greater(x, y, Device::gpu).item<bool>());
CHECK(greater_equal(x, y, Device::gpu).item<bool>());
// x < y
x = array(0.0f);
y = array(1.0f);
CHECK(less(x, y, Device::gpu).item<bool>());
CHECK(less_equal(x, y, Device::gpu).item<bool>());
CHECK(!less(y, x, Device::gpu).item<bool>());
CHECK(!less_equal(y, x, Device::gpu).item<bool>());
// x == y
y = array(0.0f);
CHECK(!less(x, y, Device::gpu).item<bool>());
CHECK(less_equal(x, y, Device::gpu).item<bool>());
}
// Check logaddexp
{
// Infinite operands: (+inf, 2) -> +inf, (-inf, 2) -> 2, (-inf, -inf) -> -inf.
constexpr float inf = std::numeric_limits<float>::infinity();
array x(inf);
array y(2.0f);
auto out = logaddexp(x, y, Device::gpu);
CHECK_EQ(out.item<float>(), inf);
x = array(-inf);
out = logaddexp(x, y, Device::gpu);
CHECK_EQ(out.item<float>(), 2.0f);
y = array(-inf);
out = logaddexp(x, y, Device::gpu);
CHECK_EQ(out.item<float>(), -inf);
}
}
// GPU unary operations: abs on a plain vector and on slices, negative for
// every non-bool dtype in `types`, and log1p at a few special inputs.
TEST_CASE("test metal unary ops") {
  // contiguous
  {
    array input({-1.0f, 0.0f, 1.0f});
    auto want = array({1.0f, 0.0f, 1.0f});
    CHECK(array_equal(abs(input, Device::gpu), want, Device::cpu).item<bool>());
  }

  // general
  {
    array base({-1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 1.0f, 3.0f, -3.0f});
    {
      // Every second element
      auto strided = slice(base, {0}, {8}, {2});
      auto want = array({1.0f, 1.0f, 1.0f, 3.0f});
      CHECK(array_equal(abs(strided, Device::gpu), want, Device::cpu)
                .item<bool>());
    }
    {
      // Elements 4 up to (not including) 8
      auto tail = slice(base, {4}, {8});
      auto want = array({1.0f, 1.0f, 3.0f, 3.0f});
      CHECK(
          array_equal(abs(tail, Device::gpu), want, Device::cpu).item<bool>());
    }
  }

  // Test negative
  {
    array one(1.0f);
    CHECK_EQ(negative(one, Device::gpu).item<float>(), -1.0f);
  }

  // Check all types work
  for (auto dtype : types) {
    if (dtype == bool_) {
      continue;
    }
    auto input = astype(array({1}), dtype);
    auto from_cpu = negative(input, Device::cpu);
    auto from_gpu = negative(input, Device::gpu);
    CHECK(array_equal(from_gpu, from_cpu, Device::cpu).item<bool>());
  }

  // Test log1p
  {
    constexpr float infinity = std::numeric_limits<float>::infinity();
    {
      array input(-1.0f);
      CHECK_EQ(log1p(input, Device::gpu).item<float>(), -infinity);
    }
    {
      array input(0.0f);
      CHECK_EQ(log1p(input, Device::gpu).item<float>(), 0.0f);
    }
    {
      array input(1e-9f);
      CHECK_EQ(log1p(input, Device::gpu).item<float>(), 1e-9f);
    }
    {
      array input(-2.0f);
      CHECK(std::isnan(log1p(input, Device::gpu).item<float>()));
    }
  }
}
// random::bits on the GPU must produce fixed values for fixed keys and
// identical values when the same key is reused.
TEST_CASE("test metal random") {
  // Seed 0: known scalar, repeatable with the same key
  {
    auto seed_key = random::key(0);
    auto first = random::bits({}, 4, seed_key, Device::gpu);
    auto second = random::bits({}, 4, seed_key, Device::gpu);
    CHECK_EQ(first.item<uint32_t>(), 1797259609u);
    CHECK_EQ(first.item<uint32_t>(), second.item<uint32_t>());
  }

  // Seed 1: known scalar
  {
    auto seed_key = random::key(1);
    auto drawn = random::bits({}, 4, seed_key, Device::gpu);
    CHECK_EQ(drawn.item<uint32_t>(), 507451445u);
  }

  // Seed 0: known {3, 1} output
  {
    auto seed_key = random::key(0);
    auto drawn = random::bits({3, 1}, 4, seed_key, Device::gpu);
    auto want = array({4146024105u, 1351547692u, 2718843009u}, {3, 1});
    CHECK(array_equal(drawn, want, Device::cpu).item<bool>());
  }
}
// matmul on the GPU with all-ones operands: plain, batched, and batched with
// broadcasting. Every expected result is filled with 2.0f.
TEST_CASE("test metal matmul") {
  // Multiplies two all-ones arrays on the GPU and expects an array of shape
  // `out_shape` filled with 2.0f.
  auto check_ones_product = [](const std::vector<int>& lhs_shape,
                               const std::vector<int>& rhs_shape,
                               const std::vector<int>& out_shape) {
    auto lhs = ones(lhs_shape);
    auto rhs = ones(rhs_shape);
    auto product = matmul(lhs, rhs, Device::gpu);
    CHECK(array_equal(product, full(out_shape, 2.0f), Device::cpu)
              .item<bool>());
  };

  check_ones_product({2, 2}, {2, 2}, {2, 2});
  // Batched matmul
  check_ones_product({3, 2, 2}, {3, 2, 2}, {3, 2, 2});
  // Broadcast batched matmul
  check_ones_product({1, 3, 2, 2}, {3, 1, 2, 2}, {3, 3, 2, 2});
}

1926
tests/ops_tests.cpp Normal file

File diff suppressed because it is too large Load Diff

545
tests/random_tests.cpp Normal file
View File

@@ -0,0 +1,545 @@
#include <numeric>
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
// random::key(seed) must yield a two-element key: seeds 0 and 1 give {0, 0}
// and {0, 1}; seeds (1 << 32) and (1 << 32) + 1 give {1, 0} and {1, 1}.
TEST_CASE("test random key") {
  // Checks that the key built from `seed` equals `words`.
  auto expect_key = [](int64_t seed, const array& words) {
    auto key = random::key(seed);
    CHECK(array_equal(key, words).item<bool>());
  };

  expect_key(0, array({0, 0}));
  expect_key(1, array({0, 1}));

  const int64_t bit_32 = static_cast<int64_t>(1) << 32;
  expect_key(bit_32, array({1, 0}));
  expect_key(bit_32 + 1, array({1, 1}));
}
// Re-seeding the global generator with the same seed must reproduce the same
// sequence of draws.
TEST_CASE("test global rng") {
  // First pass
  random::seed(4);
  auto first_draw = random::bits({});
  auto second_draw = random::bits({});

  // Second pass after re-seeding with the same value
  random::seed(4);
  auto first_replay = random::bits({});
  auto second_replay = random::bits({});

  CHECK_EQ(first_draw.item<uint32_t>(), first_replay.item<uint32_t>());
  CHECK_EQ(second_draw.item<uint32_t>(), second_replay.item<uint32_t>());
}
// random::split on the key for seed 0 must return fixed key values, both for
// the two-way split and for a split into 3 keys.
TEST_CASE("test random split") {
  // Two-way split
  {
    auto [next_key, sub_key] = random::split(random::key(0));
    CHECK(array_equal(next_key, array({4146024105u, 967050713u})).item<bool>());
    CHECK(
        array_equal(sub_key, array({2718843009u, 1272950319u})).item<bool>());
  }

  // Split into three keys, returned as one {3, 2} array
  {
    auto split_keys = random::split(random::key(0), 3);
    auto want = array(
        {2467461003u,
         428148500u,
         3186719485u,
         3840466878u,
         2562233961u,
         1946702221u},
        {3, 2});
    CHECK(array_equal(split_keys, want).item<bool>());
  }
}
// random::bits: output shape/dtype, key validation, fixed expected values for
// several seeds, shapes and widths, and behaviour under vmap.
TEST_CASE("test random bits") {
  // Test shapes, types, and sizes
  {
    auto key = random::key(0);
    auto x = random::bits({}, key);
    CHECK_EQ(x.size(), 1);
    CHECK_EQ(x.dtype(), uint32);

    x = random::bits({0}, key);
    CHECK(array_equal(x, array({})).item<bool>());

    // Check wrong key type or shape
    // NOTE(review): these checks call random::uniform rather than
    // random::bits, although this test case is about bits - confirm whether
    // that is intended.
    key = array({0, 0});
    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
    key = array({0, 0}, {1, 2});
    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
    key = array({0u, 0u, 0u}, {3, 1});
    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
    key = array({0u, 0u}, {2, 1});
    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
  }

  // Expected bits in the following tests were generated from
  // Jax's Threefry 2x32 implementation using the following in
  // python:
  //
  // ```
  // import jax
  // import jax.prng
  // shape = (SET THIS)
  // seed = (SET THIS)
  // width = (SET THIS)
  // key = jax.random.PRNGKey(seed)
  // print(jax.prng.threefry_prng_impl.random_bits(key, width, shape))
  // ```

  {
    auto key = random::key(0);
    auto x = random::bits({}, key);
    auto y = random::bits({}, key);
    CHECK_EQ(x.item<uint32_t>(), 1797259609u);
    CHECK_EQ(x.item<uint32_t>(), y.item<uint32_t>());

    // Widths 2 and 1 are read back as uint16_t and uint8_t.
    x = random::bits({}, 2, key);
    CHECK_EQ(x.item<uint16_t>(), 345);

    x = random::bits({}, 1, key);
    CHECK_EQ(x.item<uint8_t>(), 89);
  }

  {
    auto key = random::key(1);
    auto x = random::bits({}, key);
    CHECK_EQ(x.item<uint32_t>(), 507451445u);

    x = random::bits({}, 2, key);
    CHECK_EQ(x.item<uint16_t>(), 6197);

    x = random::bits({}, 1, key);
    CHECK_EQ(x.item<uint8_t>(), 53);

    // Widths 0, 5 and -1 are rejected.
    CHECK_THROWS(random::bits({}, 0, key));
    CHECK_THROWS(random::bits({}, 5, key));
    CHECK_THROWS(random::bits({}, -1, key));
  }

  {
    auto key = random::key(0);
    auto x = random::bits({3, 1}, key);
    auto expected = array({4146024105u, 1351547692u, 2718843009u}, {3, 1});
    CHECK(array_equal(x, expected).item<bool>());

    // Width 2 with 5, 6, 7 and 8 requested elements.
    x = random::bits({5}, 2, key);
    expected = array({20137, 63263, 64300, 20622, 16513}, uint16);
    CHECK(array_equal(x, expected).item<bool>());

    expected = array({20137, 63263, 64300, 20622, 16513, 41486}, uint16);
    x = random::bits({6}, 2, key);
    CHECK(array_equal(x, expected).item<bool>());

    expected = array({20137, 63263, 1497, 14756, 16513, 41486, 44591}, uint16);
    x = random::bits({7}, 2, key);
    CHECK(array_equal(x, expected).item<bool>());

    x = random::bits({8}, 2, key);
    expected =
        array({20137, 63263, 1497, 14756, 16513, 41486, 44591, 19423}, uint16);
    CHECK(array_equal(x, expected).item<bool>());
  }

  {
    auto key = array({0u, 0u, 1u, 1u}, {2, 2});
    auto shape = std::vector<int>{3};
    // `shape` is captured by reference so that the reassignment further down
    // changes what `fn` requests.
    auto fn = [&shape](array k) { return random::bits(shape, k); };

    auto expected = array(
        {4146024105u,
         1351547692u,
         2718843009u,
         3725146706u,
         1802982961u,
         1349634643u},
        {2, 3});
    CHECK(array_equal(vmap(fn)(key), expected).item<bool>());

    expected = array(
        {2441914641u,
         1110694964u,
         3819641963u,
         2441914641u,
         1110694964u,
         3819641963u},
        {2, 3});
    CHECK(array_equal(vmap(fn, 1)(key), expected).item<bool>());

    // Vmap twice
    key = array(
        {0u,
         0u,
         1u,
         1u,
         2u,
         2u,
         3u,
         3u,
         4u,
         4u,
         5u,
         5u},
        {3, 2, 2});
    shape = {2};
    auto out = vmap(vmap(fn))(key);
    expected = array(
        {928981903u,
         3453687069u,
         3606183818u,
         460005496u,
         2799733733u,
         856293553u,
         4081856343u,
         3445925136u,
         2775548010u,
         1430281703u,
         305173070u,
         2615843348u},
        {3, 2, 2});
    CHECK(array_equal(out, expected).item<bool>());

    out = vmap(vmap(fn, 1), 0)(key);
    expected = array(
        {1948878966u,
         4237131848u,
         1948878966u,
         4237131848u,
         2531170506u,
         1858648356u,
         2531170506u,
         1858648356u,
         740561898u,
         4234094099u,
         740561898u,
         4234094099u},
        {3, 2, 2});
    CHECK(array_equal(out, expected).item<bool>());
  }

  // Vmap smaller type
  {
    auto key = array({0u, 0u, 1u, 1u}, {2, 2});
    auto fn = [](array k) { return random::bits({5}, 2, k); };

    // Each row of the vmapped output must match a direct call with the
    // corresponding row of `key`. (An `expected` array that was built here
    // but never compared against anything has been removed.)
    auto out = vmap(fn)(key);
    auto x1 = random::bits({5}, 2, take(key, array(0), 0));
    auto x2 = random::bits({5}, 2, take(key, array(1), 0));

    CHECK(array_equal(take(out, array(0), 0), x1).item<bool>());
    CHECK(array_equal(take(out, array(1), 0), x2).item<bool>());
  }
}
// random::uniform: output shape/dtype, argument validation, fixed expected
// values for known keys, behaviour under vmap, and bounds.
TEST_CASE("test random uniform") {
// Test shapes, types, and sizes
{
auto x = random::uniform({});
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float32);
// float16 is only exercised when is_available reports it.
if (is_available(float16)) {
x = random::uniform({}, float16);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float16);
}
x = random::uniform({0});
CHECK(array_equal(x, array({})).item<bool>());
// Non float type throws
CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument);
// Check broadcasting
x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
CHECK_EQ(x.shape(), std::vector<int>{3, 3});
// Bounds that do not fit the requested shape are rejected.
CHECK_THROWS_AS(
random::uniform(zeros({3, 3}), 1.0, {1, 3}), std::invalid_argument);
CHECK_THROWS_AS(
random::uniform(zeros({3, 3}), 1.0, {2, 3}), std::invalid_argument);
CHECK_THROWS_AS(
random::uniform(zeros({3, 1}), ones({1, 3}), {1, 3}),
std::invalid_argument);
// Check wrong key type or shape
auto key = array({0, 0});
CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
key = array({0, 0}, {1, 2});
CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
key = array({0u, 0u, 0u}, {3, 1});
CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
key = array({0u, 0u}, {2, 1});
CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
}
// Expected bits in the following tests were generated from
// Jax's Threefry 2x32 implementation using the following in
// python:
//
// ```
// import jax
// import jax.prng
// shape = (SET THIS)
// seed = (SET THIS)
// key = jax.random.PRNGKey(seed)
// print(jax.prng.threefry_prng_impl.random_bits(key, 32, shape))
// ```
//
// Converts 32 random bits to the float the tests below expect, by dividing
// by UINT32_MAX.
constexpr auto to_float = [](uint32_t n) {
return static_cast<float>(n) / UINT32_MAX;
};
// Seed 0: known scalar, repeatable with the same key
{
auto key = random::key(0);
auto x = random::uniform({}, key);
auto y = random::uniform({}, key);
auto expected = to_float(1797259609);
CHECK_EQ(x.item<float>(), expected);
CHECK_EQ(x.item<float>(), y.item<float>());
}
// Seed 1: known scalar
{
auto key = random::key(1);
auto x = random::uniform({}, key);
auto expected = to_float(507451445);
CHECK_EQ(x.item<float>(), expected);
}
// Seed 0: known {3, 1} output
{
auto key = random::key(0);
auto x = random::uniform({3, 1}, key);
auto expected = array(
{to_float(4146024105), to_float(1351547692), to_float(2718843009)},
{3, 1});
CHECK(array_equal(x, expected).item<bool>());
}
// Check vmap
{
// Only the output shape is checked here, not the values.
auto key = random::key(0);
auto fun = [](array k, array low) {
return random::uniform(low, 1, {3}, float32, k);
};
auto out = vmap(fun, -1)(key, zeros({2, 3}));
CHECK_EQ(out.shape(), std::vector<int>{2, 3});
key = zeros({2, 2}, uint32);
out = vmap(fun)(key, zeros({2, 3}));
CHECK_EQ(out.shape(), std::vector<int>{2, 3});
}
// Check bounds are respected
{
// Every sample must lie in [-1, 1).
auto key = random::key(128291);
auto out = random::uniform(array(-1.0f), array(1.0f), {100}, float32, key);
CHECK(all(less(out, array(1.0f))).item<bool>());
CHECK(all(greater_equal(out, array(-1.0f))).item<bool>());
}
}
// random::normal: output shape/dtype, argument validation, and a check that
// 100 samples from a fixed key are all finite.
TEST_CASE("test random normal") {
  // Test shapes, types, and sizes
  {
    auto x = random::normal({});
    CHECK_EQ(x.size(), 1);
    CHECK_EQ(x.dtype(), float32);

    // Empty shape. This previously called random::uniform, so the empty-shape
    // path of random::normal itself was never exercised.
    x = random::normal({0});
    CHECK(array_equal(x, array({})).item<bool>());

    // Non float type throws
    CHECK_THROWS_AS(random::normal({}, int32), std::invalid_argument);

    // Check wrong key type or shape
    auto key = array({0, 0});
    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
    key = array({0, 0}, {1, 2});
    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
    key = array({0u, 0u, 0u}, {3, 1});
    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
    key = array({0u, 0u}, {2, 1});
    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
  }

  // All samples are finite
  {
    constexpr float inf = std::numeric_limits<float>::infinity();
    auto key = random::key(128291);
    auto out = random::normal({100}, key);
    CHECK(all(less(abs(out), array(inf))).item<bool>());
  }
}
// random::randint: dtype validation, output size/dtype for several integer
// dtypes, bounds, the high < low case, and key validation.
TEST_CASE("test random randint") {
  // A floating point output dtype is rejected
  CHECK_THROWS_AS(
      random::randint(array(3), array(5), {1}, float32), std::invalid_argument);

  // Scalar draws report the requested dtype
  {
    auto sample = random::randint(0, 10, {}, uint32);
    CHECK_EQ(sample.size(), 1);
    CHECK_EQ(sample.dtype(), uint32);
  }
  {
    auto sample = random::randint(0, 2, {}, bool_);
    CHECK_EQ(sample.size(), 1);
    CHECK_EQ(sample.dtype(), bool_);
  }
  {
    auto sample = random::randint(0, 2, {}, int32);
    CHECK_EQ(sample.size(), 1);
    CHECK_EQ(sample.dtype(), int32);
  }
  {
    auto sample = random::randint(0, 2, {}, int64);
    CHECK_EQ(sample.size(), 1);
    CHECK_EQ(sample.dtype(), int64);
  }

  // Check all in bounds
  auto low = -10.0;
  auto high = 20.0;
  {
    auto samples = random::randint(low, high, {1000, 1000});
    CHECK(
        (all(low <= samples).item<bool>() && all(samples < high).item<bool>()));
  }

  // Check high < low => all equals to low
  low = 20.0;
  high = -10.0;
  {
    auto samples = random::randint(low, high, {3, 3});
    CHECK(all(equal(samples, array(low))).item<bool>());
  }

  // Check wrong key type or shape
  auto bad_key = array({0, 0}, {1, 2});
  CHECK_THROWS_AS(
      random::randint(low, high, {}, float32, bad_key), std::invalid_argument);
}
// random::bernoulli: output size/dtype, parameter dtype validation,
// out-of-range probabilities, requested shapes, broadcasting and key
// validation.
TEST_CASE("test random bernoulli") {
  // Default call gives a single bool
  {
    auto sample = random::bernoulli();
    CHECK_EQ(sample.size(), 1);
    CHECK_EQ(sample.dtype(), bool_);
  }

  // Bernoulli parameter can have floating point type
  if (is_available(float16)) {
    auto sample = random::bernoulli(array(0.5, float16));
    CHECK_EQ(sample.size(), 1);
    CHECK_EQ(sample.dtype(), bool_);
  }

  // An integer parameter is rejected
  CHECK_THROWS(random::bernoulli(array(1, int32)));

  // Negative numbers allowed in Jax
  {
    auto never = random::bernoulli(array(-1.0));
    CHECK_FALSE(never.item<bool>());
  }
  {
    auto always = random::bernoulli(array(5.0));
    CHECK(always.item<bool>());
  }

  // Return array with correct shape
  {
    auto grid = random::bernoulli(0.5, {3, 3});
    CHECK_EQ(grid.shape(), std::vector<int>({3, 3}));
  }

  // Try with p = {}
  {
    auto nothing = random::bernoulli(array({}));
    CHECK_EQ(nothing.size(), 0);
  }

  // Try broadcasting
  {
    auto row_probs = array({0.1, 0.2, 0.3});
    row_probs = reshape(row_probs, {1, 3});
    auto grid = random::bernoulli(row_probs, {4, 3});
    CHECK_EQ(grid.shape(), std::vector<int>({4, 3}));
  }

  CHECK_THROWS_AS(random::bernoulli(array({}), {3, 3}), std::invalid_argument);

  // Ask for the wrong shape => throws
  {
    auto probs = array({0.1, 0.2, 0.3});
    CHECK_THROWS_AS(random::bernoulli(probs, {2}), std::invalid_argument);
  }

  // Check wrong key type or shape
  auto bad_key = array({0, 0}, {1, 2});
  CHECK_THROWS_AS(
      random::bernoulli(array(0.5), bad_key), std::invalid_argument);
}
// random::truncated_normal: output size/dtype, requested shapes, empty
// bounds, broadcasting, bounds, the high < low case, and argument validation.
TEST_CASE("Test truncated normal") {
  // Scalar bounds give a single float32 sample
  {
    auto sample = random::truncated_normal(array(-2.0), array(2.0));
    CHECK_EQ(sample.size(), 1);
    CHECK_EQ(sample.dtype(), float32);
  }

  if (is_available(float16)) {
    auto sample =
        random::truncated_normal(array(-2.0), array(2.0), {}, float16);
    CHECK_EQ(sample.size(), 1);
    CHECK_EQ(sample.dtype(), float16);
  }

  // Requested shape
  {
    auto grid = random::truncated_normal(array(-2.0), array(2.0), {3, 4});
    CHECK_EQ(grid.shape(), std::vector<int>({3, 4}));
  }

  // Empty array
  {
    auto nothing = random::truncated_normal(array({}), array({}));
    CHECK_EQ(nothing.size(), 0);
  }

  // Broadcast
  auto lower = reshape(array({-2.0, -3.0}), {1, 2});
  auto higher = reshape(array({0.0, 3.0, 1.5}), {3, 1});
  {
    auto samples = random::truncated_normal(lower, higher);

    // All in bounds
    CHECK_EQ(samples.shape(), std::vector<int>({3, 2}));
    CHECK((all(samples <= higher).item<bool>() &&
           all(lower <= samples).item<bool>()));
  }

  // high < low => all equal to low
  {
    auto collapsed = random::truncated_normal(array(2.0), array(-2.0));
    CHECK(all(collapsed == array(2.0)).item<bool>());
  }

  // Non broadcastable => throws
  CHECK_THROWS_AS(
      random::truncated_normal(lower, higher, {4, 2}), std::invalid_argument);

  // Wrong key type or shape
  auto bad_key = array({0, 0}, {1, 2});
  CHECK_THROWS_AS(
      random::truncated_normal(
          array(-2.0), array(2.0), {1, 1}, float32, bad_key),
      std::invalid_argument);
}
// Covers random::categorical: axis and shape validation, output shapes and
// dtype, value ranges, infinite logits, and the num-samples form.
TEST_CASE("test categorical") {
  using random::categorical;
  auto flat_logits = zeros({10, 20});

  // Axes outside the 2-D logits must throw.
  CHECK_THROWS(categorical(flat_logits, 2));
  CHECK_THROWS(categorical(flat_logits, -3));

  // Invalid requested shapes
  CHECK_THROWS(categorical(flat_logits, 1, std::vector<int>{1}));
  CHECK_THROWS(categorical(flat_logits, 1, std::vector<int>{11}));
  CHECK_THROWS(categorical(flat_logits, 1, {10, 1}));

  // The sampled axis is removed from the output shape.
  CHECK_EQ(categorical(flat_logits, -1).shape(), std::vector<int>{10});
  CHECK_EQ(categorical(flat_logits, 0).shape(), std::vector<int>{20});
  CHECK_EQ(categorical(flat_logits, 1).shape(), std::vector<int>{10});

  // Default axis: uint32 indices below the size of the last axis.
  auto picks = categorical(flat_logits);
  CHECK_EQ(picks.shape(), std::vector<int>{10});
  CHECK_EQ(picks.dtype(), uint32);
  CHECK(max(picks).item<uint32_t>() < 20);

  // Explicit output shape when sampling along axis 0.
  picks = categorical(flat_logits, 0, {5, 20});
  CHECK_EQ(picks.shape(), std::vector<int>{5, 20});
  CHECK(max(picks).item<uint32_t>() < 10);

  // A +inf logit must always be chosen; a lone finite logit among -inf
  // entries must always be chosen.
  const float inf = std::numeric_limits<float>::infinity();
  auto with_pos_inf = array({1.0f, -2.0f, inf, 4.0f, 3.0f});
  CHECK_EQ(categorical(with_pos_inf).item<uint32_t>(), 2);
  auto single_finite = array({-inf, -2.0f, -inf, -inf});
  CHECK_EQ(categorical(single_finite).item<uint32_t>(), 1);

  // With an integer third argument of 7 the sampled axis is dropped and a
  // trailing axis of size 7 is appended.
  auto cube_logits = zeros({5, 4, 3});
  CHECK_EQ(
      categorical(cube_logits, -1, 7).shape(), std::vector<int>{5, 4, 7});
  CHECK_EQ(
      categorical(cube_logits, -2, 7).shape(), std::vector<int>{5, 3, 7});
  CHECK_EQ(
      categorical(cube_logits, -3, 7).shape(), std::vector<int>{4, 3, 7});
}

22
tests/tests.cpp Normal file
View File

@@ -0,0 +1,22 @@
#define DOCTEST_CONFIG_IMPLEMENT
#include "doctest/doctest.h"
#include <cstdlib>
#include "mlx/mlx.h"
using namespace mlx::core;
// Test runner entry point. Picks the default device before running doctest:
// DEVICE=cpu in the environment selects the CPU; otherwise the GPU is
// selected when Metal reports itself available. Returns doctest's result.
int main(int argc, char** argv) {
  const char* requested = std::getenv("DEVICE");
  const bool cpu_requested =
      requested != nullptr && std::string(requested) == "cpu";

  if (cpu_requested) {
    set_default_device(Device::cpu);
  } else if (metal::is_available()) {
    set_default_device(Device::gpu);
  }

  doctest::Context context;
  context.applyCommandLine(argc, argv);
  return context.run();
}

248
tests/vmap_tests.cpp Normal file
View File

@@ -0,0 +1,248 @@
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
// Checks vmap applied to reshape, broadcast_to, transpose and add, and to
// functions that capture arrays, for a range of input/output axis arguments.
// NOTE(review): judging by the expected shapes below, the second vmap
// argument is the input axis (or axes) and the third the output axis (or
// axes), with -1 meaning "not mapped" -- confirm against the vmap declaration.
TEST_CASE("test simple vmap") {
// vmap reshape
{
// Default axes: a {3, 4} input reshaped to {2, 2} per slice must give {3, 2, 2}.
auto vfun = vmap([](array input) { return reshape(input, {2, 2}); });
auto x = zeros({3, 4});
CHECK(array_equal(vfun(x), zeros({3, 2, 2})).item<bool>());
// {2, 2, 2} input: each leading slice is flattened to 4 elements.
x = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2});
vfun = vmap([](array input) { return reshape(input, {4}); });
auto expected = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 4});
CHECK(array_equal(vfun(x), expected).item<bool>());
// Second argument 1: the expected rows are x[:, 0, :] and x[:, 1, :] flattened.
vfun = vmap([](array input) { return reshape(input, {4}); }, 1);
expected = array({0, 1, 4, 5, 2, 3, 6, 7}, {2, 4});
CHECK(array_equal(vfun(x), expected).item<bool>());
// Second and third arguments 1: the same flattened slices are expected as
// the columns of a {4, 2} result.
vfun = vmap([](array input) { return reshape(input, {4}); }, 1, 1);
expected = array({0, 2, 1, 3, 4, 6, 5, 7}, {4, 2});
CHECK(array_equal(vfun(x), expected).item<bool>());
}
// vmap broadcast
{
auto fun = [](array input) { return broadcast_to(input, {4, 2}); };
// Pairing -1 with a non-negative axis must throw when vmap is called.
CHECK_THROWS_AS(vmap(fun, 0, -1), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, -1, 0), std::invalid_argument);
// -1 for both axes: the result is just the {4, 2} broadcast of the input.
auto vfun = vmap(fun, -1, -1);
auto x = zeros({2});
CHECK(array_equal(vfun(x), zeros({4, 2})).item<bool>());
// {3, 2} input with default axes: the size-3 axis leads the {4, 2} broadcast.
vfun = vmap(fun);
x = zeros({3, 2});
CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item<bool>());
// Third argument 1 or 2: the size-3 axis is expected at that output position.
vfun = vmap(fun, 0, 1);
CHECK(array_equal(vfun(x), zeros({4, 3, 2})).item<bool>());
vfun = vmap(fun, 0, 2);
CHECK(array_equal(vfun(x), zeros({4, 2, 3})).item<bool>());
// A {2, 3} input with input axis 0 must throw; presumably the remaining
// size-3 slices cannot be broadcast to {4, 2}.
vfun = vmap(fun, 0, 2);
x = zeros({2, 3});
CHECK_THROWS_AS(vfun(x), std::invalid_argument);
// The same {2, 3} input with input axis 1 yields the same shapes as the
// {3, 2} cases above.
x = zeros({2, 3});
vfun = vmap(fun, 1);
CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item<bool>());
vfun = vmap(fun, 1, 1);
CHECK(array_equal(vfun(x), zeros({4, 3, 2})).item<bool>());
vfun = vmap(fun, 1, 2);
CHECK(array_equal(vfun(x), zeros({4, 2, 3})).item<bool>());
}
// vmap transpose
{
auto fun = [](array input) { return transpose(input); };
auto vfun = vmap(fun);
// {3, 2} input with default axes: the result is expected to equal x unchanged.
auto x = array({0, 1, 2, 3, 4, 5}, {3, 2});
CHECK(array_equal(vfun(x), x).item<bool>());
// With output axis 1 the result is expected to equal transpose(x).
vfun = vmap(fun, 0, 1);
CHECK(array_equal(vfun(x), transpose(x)).item<bool>());
// {2, 2, 2} input: with the same input and output axis, that axis stays in
// place and the other two are expected to be swapped.
x = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2});
vfun = vmap(fun);
CHECK(array_equal(vfun(x), transpose(x, {0, 2, 1})).item<bool>());
vfun = vmap(fun, 1, 1);
CHECK(array_equal(vfun(x), transpose(x, {2, 1, 0})).item<bool>());
vfun = vmap(fun, 2, 2);
CHECK(array_equal(vfun(x), transpose(x, {1, 0, 2})).item<bool>());
// vmap twice: on a {2, 2, 2, 2} input only the last two axes are expected
// to be swapped.
x = array(
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, {2, 2, 2, 2});
vfun = vmap(vmap(fun));
CHECK(array_equal(vfun(x), transpose(x, {0, 1, 3, 2})).item<bool>());
}
// vmap add
{
// Multi-input form: the function takes and returns a vector of arrays.
auto fun = [](std::vector<array> inputs) {
auto out = add(inputs[0], inputs[1]);
return std::vector<array>{out};
};
auto vfun = vmap(fun);
array x({1.0, 2.0}, {2, 1});
array y({2.0, 3.0}, {2, 1});
auto out = vfun({x, y})[0];
CHECK(array_equal(out, array({3.0, 5.0}, {2, 1})).item<bool>());
// Axes {2, 0}: x has size 3 on axis 2 and y has size 3 on axis 0; the
// expected result is {3, 2, 2} filled with 2.0.
x = ones({2, 1, 3});
y = ones({3, 2});
vfun = vmap(fun, {2, 0});
out = vfun({x, y})[0];
CHECK(array_equal(out, full({3, 2, 2}, 2.0)).item<bool>());
// -1 for one input: the scalar is added to every slice of the other input.
x = array(1.);
y = ones({3, 2});
vfun = vmap(fun, {-1, 0});
out = vfun({x, y})[0];
CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());
x = ones({3, 2});
y = array(1.);
vfun = vmap(fun, {0, -1});
out = vfun({x, y})[0];
CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());
// These input/output axis combinations must throw when vmap is called,
// before the returned function is ever applied.
CHECK_THROWS_AS(vmap(fun, {-1, -1}, {0}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {-1, 0}, {-1}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {0, -1}, {-1}), std::invalid_argument);
// All axes -1: two scalars are simply added.
x = array(1.);
y = array(1.);
vfun = vmap(fun, {-1, -1}, {-1});
out = vfun({x, y})[0];
CHECK(array_equal(out, array(2.)).item<bool>());
// Nested vmap with default axes must match the plain elementwise sum.
x = ones({3, 2, 1});
y = ones({3, 2, 1});
vfun = vmap(vmap(fun));
out = vfun({x, y})[0];
CHECK(array_equal(out, x + y).item<bool>());
}
// vmap with capturing closure
{
// x is built from chained add calls and captured by the lambda; it is not
// an argument of the mapped function.
auto x = add(add(ones({2}), zeros({2})), zeros({2}));
auto fun = [x](const array& input) { return add(input, x); };
auto vfun = vmap(fun);
auto y = ones({3, 2});
CHECK(array_equal(vfun(y), full({3, 2}, 2.0f)).item<bool>());
}
{
// The lambda ignores inputs[0] and adds the captured z (= x + x) to
// inputs[1], which is the only input given a non-negative axis.
auto x = ones({4});
auto z = x + x;
auto vfun = vmap(
[z](std::vector<array> inputs) {
return std::vector<array>{add(z, inputs[1])};
},
{-1, 0});
auto y = ones({3, 4});
CHECK(array_equal(vfun({x, y})[0], full({3, 4}, 3.0)).item<bool>());
}
}
// Checks how vmap behaves when the mapped function calls eval internally.
TEST_CASE("test vmap with eval") {
  // Adds 1 to the first input and 2 to the second, evaluates the first
  // intermediate, then returns their sum.
  auto eval_inside = [](std::vector<array> inputs) {
    auto lhs = inputs[0] + 1;
    auto rhs = inputs[1] + 2;
    eval(lhs);
    return std::vector<array>{add(lhs, rhs)};
  };

  // Both inputs mapped: the call must throw.
  auto mapped = vmap(eval_inside);
  array batched_x({1.0, 2.0}, {2, 1});
  array batched_y({2.0, 3.0}, {2, 1});
  CHECK_THROWS(mapped({batched_x, batched_y}));

  // Ok to eval functions of non-vmapped input
  auto scalar_x = array(1.0);
  auto mapped_second_only = vmap(eval_inside, {-1, 0});
  CHECK(array_equal(
            mapped_second_only({scalar_x, batched_y})[0],
            array({6.0f, 7.0f}, {2, 1}))
            .item<bool>());

  // Not ok to eval function of vmapped input even with retain graph
  auto eval_retaining = [](std::vector<array> inputs) {
    auto lhs = inputs[0] + 1;
    auto rhs = inputs[1] + 2;
    eval({lhs}, true);
    return std::vector<array>{add(lhs, rhs)};
  };
  auto batched_again = array({1.0, 2.0}, {2, 1});
  CHECK_THROWS(vmap(eval_retaining)({batched_again, batched_y}));
}
// Checks vmap over elementwise equality, with both inputs mapped and with
// only the first input mapped.
TEST_CASE("test vmap comparison ops") {
  // vmap equal
  {
    auto eq_fn = [](std::vector<array> inputs) {
      return std::vector<array>{equal(inputs[0], inputs[1])};
    };

    // Both inputs mapped with default axes: zeros compare equal everywhere.
    auto mapped = vmap(eq_fn);
    auto lhs = zeros({2, 3}, float32);
    auto rhs = zeros({2, 3}, float32);
    auto result = mapped({lhs, rhs})[0];
    CHECK(all(result).item<bool>());

    // Only the first input mapped: a {3} array is compared with each row.
    mapped = vmap(eq_fn, {0, -1});
    lhs = zeros({2, 3}, float32);
    rhs = zeros({3}, float32);
    result = mapped({lhs, rhs})[0];
    CHECK(all(result).item<bool>());

    // Same axes with a first row of zeros and a second row of ones: only the
    // first row is expected to match the zero array.
    mapped = vmap(eq_fn, {0, -1});
    lhs = array({0, 0, 0, 1, 1, 1}, {2, 3});
    rhs = zeros({3}, float32);
    result = mapped({lhs, rhs})[0];
    auto row_matches = array({true, true, true, false, false, false}, {2, 3});
    CHECK(array_equal(result, row_matches).item<bool>());
  }
}
// Checks vmap over astype and full.
TEST_CASE("test vmap creation ops") {
  // vmap astype
  {
    auto to_int32 = [](array in) { return astype(in, int32); };
    auto floats = zeros({2, 3}, float32);
    auto converted = vmap(to_int32)(floats);
    // The result must be int32 and equal to int32 zeros of the same shape.
    CHECK_EQ(converted.dtype(), int32);
    CHECK(array_equal(converted, zeros({2, 3}, int32)).item<bool>());
  }
  // vmap full
  {
    auto fill_pair = [](array in) { return full({2}, in); };

    // Each of the three values is expected as a row of two copies.
    auto repeated_rows = array({1, 1, 2, 2, 3, 3}, {3, 2});
    auto values = array({1, 2, 3});
    auto filled = vmap(fill_pair)(values);
    CHECK(array_equal(filled, repeated_rows).item<bool>());

    // A {3, 1} input gives the same expected result.
    values = array({1, 2, 3}, {3, 1});
    filled = vmap(fill_pair)(values);
    CHECK(array_equal(filled, repeated_rows).item<bool>());

    // A {1, 3} input with default axes must throw.
    values = array({1, 2, 3}, {1, 3});
    CHECK_THROWS_AS(vmap(fill_pair)(values), std::invalid_argument);

    // With input and output axis 1, the values are expected as columns.
    filled = vmap(fill_pair, 1, 1)(values);
    auto repeated_cols = array({1, 2, 3, 1, 2, 3}, {2, 3});
    CHECK(array_equal(filled, repeated_cols).item<bool>());
  }
}