mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 23:15:09 +08:00
awni's commit files
This commit is contained in:
38
tests/CMakeLists.txt
Normal file
38
tests/CMakeLists.txt
Normal file
@@ -0,0 +1,38 @@
|
||||
FetchContent_Declare(
|
||||
doctest
|
||||
GIT_REPOSITORY "https://github.com/onqtam/doctest"
|
||||
GIT_TAG "b7c21ec5ceeadb4951b00396fc1e4642dd347e5f"
|
||||
)
|
||||
FetchContent_MakeAvailable(doctest)
|
||||
|
||||
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
|
||||
|
||||
if (MLX_BUILD_METAL)
|
||||
set(
|
||||
METAL_TEST_SOURCES
|
||||
metal_tests.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
target_sources(tests PRIVATE
|
||||
allocator_tests.cpp
|
||||
array_tests.cpp
|
||||
arg_reduce_tests.cpp
|
||||
autograd_tests.cpp
|
||||
blas_tests.cpp
|
||||
creations_tests.cpp
|
||||
device_tests.cpp
|
||||
eval_tests.cpp
|
||||
fft_tests.cpp
|
||||
graph_optimize_tests.cpp
|
||||
load_tests.cpp
|
||||
ops_tests.cpp
|
||||
random_tests.cpp
|
||||
scheduler_tests.cpp
|
||||
utils_tests.cpp
|
||||
vmap_tests.cpp
|
||||
${METAL_TEST_SOURCES}
|
||||
)
|
||||
|
||||
target_link_libraries(tests PRIVATE mlx doctest)
|
||||
add_test(NAME tests COMMAND tests)
|
||||
205
tests/arg_reduce_tests.cpp
Normal file
205
tests/arg_reduce_tests.cpp
Normal file
@@ -0,0 +1,205 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
void test_arg_reduce_small(
|
||||
Device d,
|
||||
const array& x,
|
||||
ArgReduce::ReduceType r,
|
||||
std::vector<int> out_shape,
|
||||
int axis,
|
||||
std::vector<int> expected_output) {
|
||||
auto s = default_stream(d);
|
||||
auto y =
|
||||
array(out_shape, uint32, std::make_unique<ArgReduce>(s, r, axis), {x});
|
||||
y.eval();
|
||||
const uint32_t* ydata = y.data<uint32_t>();
|
||||
for (int i = 0; i < y.size(); i++) {
|
||||
CHECK_EQ(expected_output[i], ydata[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void test_arg_reduce_against_cpu(
|
||||
const array& x,
|
||||
ArgReduce::ReduceType r,
|
||||
std::vector<int> out_shape,
|
||||
int axis) {
|
||||
auto y1 = array(
|
||||
out_shape,
|
||||
uint32,
|
||||
std::make_unique<ArgReduce>(default_stream(Device::cpu), r, axis),
|
||||
{x});
|
||||
auto y2 = array(
|
||||
out_shape,
|
||||
uint32,
|
||||
std::make_unique<ArgReduce>(default_stream(Device::gpu), r, axis),
|
||||
{x});
|
||||
y1.eval();
|
||||
y2.eval();
|
||||
CHECK(array_equal(y1, y2).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test arg reduce small") {
|
||||
auto x = array(
|
||||
{0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
|
||||
0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
|
||||
{2, 3, 4});
|
||||
x.eval();
|
||||
test_arg_reduce_small(
|
||||
Device::cpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3});
|
||||
test_arg_reduce_small(
|
||||
Device::cpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 2, 0, 1, 1, 2});
|
||||
test_arg_reduce_small(
|
||||
Device::cpu,
|
||||
x,
|
||||
ArgReduce::ArgMin,
|
||||
{3, 4},
|
||||
0,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
test_arg_reduce_small(
|
||||
Device::cpu, x, ArgReduce::ArgMax, {2, 3}, 2, {3, 0, 1, 3, 0, 1});
|
||||
test_arg_reduce_small(
|
||||
Device::cpu, x, ArgReduce::ArgMax, {2, 4}, 1, {1, 2, 2, 0, 1, 2, 2, 0});
|
||||
test_arg_reduce_small(
|
||||
Device::cpu,
|
||||
x,
|
||||
ArgReduce::ArgMax,
|
||||
{3, 4},
|
||||
0,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
|
||||
if (!metal::is_available()) {
|
||||
INFO("Skiping arg reduction gpu tests");
|
||||
return;
|
||||
}
|
||||
|
||||
test_arg_reduce_small(
|
||||
Device::gpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3});
|
||||
test_arg_reduce_small(
|
||||
Device::gpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 2, 0, 1, 1, 2});
|
||||
test_arg_reduce_small(
|
||||
Device::gpu,
|
||||
x,
|
||||
ArgReduce::ArgMin,
|
||||
{3, 4},
|
||||
0,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
test_arg_reduce_small(
|
||||
Device::gpu, x, ArgReduce::ArgMax, {2, 3}, 2, {3, 0, 1, 3, 0, 1});
|
||||
test_arg_reduce_small(
|
||||
Device::gpu, x, ArgReduce::ArgMax, {2, 4}, 1, {1, 2, 2, 0, 1, 2, 2, 0});
|
||||
test_arg_reduce_small(
|
||||
Device::gpu,
|
||||
x,
|
||||
ArgReduce::ArgMax,
|
||||
{3, 4},
|
||||
0,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
}
|
||||
|
||||
TEST_CASE("test arg reduce against cpu") {
|
||||
if (!metal::is_available()) {
|
||||
INFO("Skiping arg reduction gpu tests");
|
||||
return;
|
||||
}
|
||||
|
||||
auto x = random::uniform(array(0.0), array(1.0), {127, 92, 55});
|
||||
x.eval();
|
||||
test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {127, 92}, 2);
|
||||
test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {127, 55}, 1);
|
||||
test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {92, 55}, 0);
|
||||
test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {127, 92}, 2);
|
||||
test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {127, 55}, 1);
|
||||
test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {92, 55}, 0);
|
||||
|
||||
auto y = random::uniform(array(0.0), array(1.0), {1234});
|
||||
y.eval();
|
||||
test_arg_reduce_against_cpu(y, ArgReduce::ArgMin, {}, 0);
|
||||
test_arg_reduce_against_cpu(y, ArgReduce::ArgMax, {}, 0);
|
||||
}
|
||||
|
||||
void test_arg_reduce_small_bool(
|
||||
Device d,
|
||||
ArgReduce::ReduceType r,
|
||||
std::vector<int> out_shape,
|
||||
int axis,
|
||||
std::vector<int> expected_output) {
|
||||
auto s = default_stream(d);
|
||||
auto x = array(
|
||||
{0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
|
||||
0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
|
||||
{2, 3, 4});
|
||||
x.eval();
|
||||
auto y =
|
||||
array(out_shape, uint32, std::make_unique<ArgReduce>(s, r, axis), {x});
|
||||
y.eval();
|
||||
const uint32_t* ydata = y.data<uint32_t>();
|
||||
for (int i = 0; i < y.size(); i++) {
|
||||
CHECK_EQ(expected_output[i], ydata[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test arg reduce bool") {
|
||||
if (!metal::is_available()) {
|
||||
INFO("Skiping arg reduction gpu tests");
|
||||
return;
|
||||
}
|
||||
auto x = array(
|
||||
{false, true, true, false, false, false, false, true,
|
||||
true, false, true, true, false, true, true, false,
|
||||
false, false, false, true, true, false, true, true},
|
||||
{2, 3, 4});
|
||||
x.eval();
|
||||
test_arg_reduce_small(
|
||||
Device::gpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 0, 1, 0, 0, 1});
|
||||
test_arg_reduce_small(
|
||||
Device::gpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 0, 0, 1, 1, 0});
|
||||
test_arg_reduce_small(
|
||||
Device::gpu,
|
||||
x,
|
||||
ArgReduce::ArgMin,
|
||||
{3, 4},
|
||||
0,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
test_arg_reduce_small(
|
||||
Device::gpu, x, ArgReduce::ArgMax, {2, 3}, 2, {1, 3, 0, 1, 3, 0});
|
||||
test_arg_reduce_small(
|
||||
Device::gpu, x, ArgReduce::ArgMax, {2, 4}, 1, {2, 0, 0, 1, 2, 0, 0, 1});
|
||||
test_arg_reduce_small(
|
||||
Device::gpu,
|
||||
x,
|
||||
ArgReduce::ArgMax,
|
||||
{3, 4},
|
||||
0,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
}
|
||||
|
||||
TEST_CASE("test arg reduce edge cases") {
|
||||
auto a = argmin(array(1.0));
|
||||
CHECK_EQ(a.item<uint32_t>(), 0);
|
||||
auto b = argmax(array(1.0));
|
||||
CHECK_EQ(b.item<uint32_t>(), 0);
|
||||
CHECK_THROWS(argmin(array({})));
|
||||
CHECK_THROWS(argmax(array({})));
|
||||
}
|
||||
|
||||
TEST_CASE("test arg reduce irregular strides") {
|
||||
auto x = array(
|
||||
{0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
|
||||
0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
|
||||
{2, 3, 4});
|
||||
x = transpose(x, {2, 0, 1});
|
||||
x.eval();
|
||||
test_arg_reduce_small(
|
||||
Device::cpu, x, ArgReduce::ArgMin, {4, 2}, 2, {0, 0, 1, 1, 1, 1, 2, 2});
|
||||
|
||||
if (!metal::is_available()) {
|
||||
INFO("Skiping arg reduction gpu tests");
|
||||
return;
|
||||
}
|
||||
}
|
||||
108
tests/blas_tests.cpp
Normal file
108
tests/blas_tests.cpp
Normal file
@@ -0,0 +1,108 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
TEST_CASE("test matmul") {
|
||||
auto a = array(1);
|
||||
auto b = array({1.0});
|
||||
CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
|
||||
|
||||
a = array({1.0});
|
||||
b = array({1.0});
|
||||
auto out = matmul(a, b);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{});
|
||||
CHECK_EQ(out.size(), 1);
|
||||
CHECK_EQ(out.dtype(), float32);
|
||||
CHECK_EQ(out.item<float>(), 1.0f);
|
||||
|
||||
a = ones({2, 4});
|
||||
b = ones({2});
|
||||
CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
|
||||
|
||||
a = ones({2, 4});
|
||||
b = ones({3, 2});
|
||||
CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
|
||||
|
||||
a = ones({2, 4});
|
||||
b = ones({4, 3, 2});
|
||||
CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
|
||||
|
||||
a = ones({2});
|
||||
b = ones({4, 2});
|
||||
CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
|
||||
|
||||
a = ones({2, 3});
|
||||
b = ones({4, 2});
|
||||
CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
|
||||
|
||||
a = ones({2, 4, 3});
|
||||
b = ones({4, 2});
|
||||
CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
|
||||
|
||||
a = ones({2, 4});
|
||||
b = ones({4, 2});
|
||||
out = matmul(a, b);
|
||||
CHECK(array_equal(out, full({2, 2}, 4.0f)).item<bool>());
|
||||
|
||||
a = ones({2, 4}, int32);
|
||||
b = ones({4, 2}, float32);
|
||||
out = matmul(a, b);
|
||||
CHECK(array_equal(out, full({2, 2}, 4.0f)).item<bool>());
|
||||
|
||||
// Check single dimensions
|
||||
a = ones({4});
|
||||
b = ones({4, 2});
|
||||
out = matmul(a, b);
|
||||
CHECK(array_equal(out, full({2}, 4.0f)).item<bool>());
|
||||
|
||||
a = ones({2, 4});
|
||||
b = ones({4});
|
||||
out = matmul(a, b);
|
||||
CHECK(array_equal(out, full({2}, 4.0f)).item<bool>());
|
||||
|
||||
a = ones({4});
|
||||
b = ones({4});
|
||||
out = matmul(a, b);
|
||||
CHECK(array_equal(out, full({}, 4.0f)).item<bool>());
|
||||
|
||||
// Test transposed arrays
|
||||
a = array({1.0f, 1.0f, 1.0f, 1.0f}, {1, 4});
|
||||
b = array({1.0f, 1.0f, 1.0f, 1.0f}, {4, 1});
|
||||
out = matmul(transpose(a), transpose(b));
|
||||
CHECK(array_equal(out, ones({4, 4})).item<bool>());
|
||||
|
||||
a = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
|
||||
b = array({1.0f, 2.0f, 1.0f, 2.0f}, {2, 2});
|
||||
out = matmul(transpose(a), b);
|
||||
CHECK(
|
||||
array_equal(out, array({4.0f, 8.0f, 6.0f, 12.0f}, {2, 2})).item<bool>());
|
||||
|
||||
out = matmul(a, transpose(b));
|
||||
CHECK(
|
||||
array_equal(out, array({5.0f, 5.0f, 11.0f, 11.0f}, {2, 2})).item<bool>());
|
||||
|
||||
out = matmul(transpose(a), transpose(b));
|
||||
CHECK(
|
||||
array_equal(out, array({7.0f, 7.0f, 10.0f, 10.0f}, {2, 2})).item<bool>());
|
||||
|
||||
// Test broadcasting for both arrays
|
||||
a = ones({5, 4, 2});
|
||||
b = ones({2, 3});
|
||||
out = matmul(a, b);
|
||||
CHECK(array_equal(out, full({5, 4, 3}, 2.0f)).item<bool>());
|
||||
|
||||
a = ones({5, 1, 4, 2});
|
||||
b = ones({1, 7, 2, 3});
|
||||
out = matmul(a, b);
|
||||
CHECK(array_equal(out, full({5, 7, 4, 3}, 2.0f)).item<bool>());
|
||||
|
||||
// Test batched matmul with transpose
|
||||
a = ones({2, 2, 4});
|
||||
b = ones({2, 4, 2});
|
||||
out = matmul(transpose(a, {0, 2, 1}), transpose(b, {0, 2, 1}));
|
||||
CHECK(array_equal(out, full({2, 4, 4}, 2.0f)).item<bool>());
|
||||
}
|
||||
438
tests/metal_tests.cpp
Normal file
438
tests/metal_tests.cpp
Normal file
@@ -0,0 +1,438 @@
|
||||
#include <array>
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
static const std::array<Dtype, 5> types =
|
||||
{bool_, uint32, int32, int64, float32};
|
||||
|
||||
TEST_CASE("test metal device") {
|
||||
// Make sure the device and library can load
|
||||
CHECK(metal::is_available());
|
||||
auto& device = metal::device(Device::gpu);
|
||||
}
|
||||
|
||||
TEST_CASE("test metal arange") {
|
||||
for (auto t : types) {
|
||||
if (t == bool_) {
|
||||
continue;
|
||||
}
|
||||
auto out_cpu = arange(1, 100, 2, t, Device::cpu);
|
||||
auto out_gpu = arange(1, 100, 2, t, Device::gpu);
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
|
||||
out_cpu = arange(1, 5, 0.25, t, Device::cpu);
|
||||
out_gpu = arange(1, 5, 0.25, t, Device::gpu);
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test metal full") {
|
||||
for (auto t : types) {
|
||||
auto out_cpu = full({4, 4}, 2, t, Device::cpu);
|
||||
auto out_gpu = full({4, 4}, 2, t, Device::gpu);
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
// Check broadcasting works
|
||||
{
|
||||
auto x = full({2, 2}, array({3, 4}, {2, 1}), Device::gpu);
|
||||
CHECK(
|
||||
array_equal(x, array({3, 3, 4, 4}, {2, 2}), Device::cpu).item<bool>());
|
||||
x = full({2, 2}, array({3, 4}, {1, 2}), Device::gpu);
|
||||
CHECK(
|
||||
array_equal(x, array({3, 4, 3, 4}, {2, 2}), Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
// Check zeros and ones
|
||||
{
|
||||
auto x = zeros({2, 2}, float32, Device::gpu);
|
||||
auto y = array({0.0, 0.0, 0.0, 0.0}, {2, 2});
|
||||
CHECK(array_equal(x, y, Device::cpu).item<bool>());
|
||||
|
||||
x = ones({2, 2}, float32, Device::gpu);
|
||||
y = array({1.0, 1.0, 1.0, 1.0}, {2, 2});
|
||||
CHECK(array_equal(x, y, Device::cpu).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test metal astype") {
|
||||
array x = array({-4, -3, -2, -1, 0, 1, 2, 3});
|
||||
// Check all types work
|
||||
for (auto t : types) {
|
||||
auto out_cpu = astype(x, t, Device::cpu);
|
||||
auto out_gpu = astype(x, t, Device::gpu);
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
x = transpose(reshape(x, {2, 2, 2}), {1, 2, 0});
|
||||
for (auto t : types) {
|
||||
auto out_cpu = astype(x, t, Device::cpu);
|
||||
auto out_gpu = astype(x, t, Device::gpu);
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test metal reshape") {
|
||||
array x = array({0, 1, 2, 3, 4, 5, 6, 7});
|
||||
auto out_cpu = reshape(x, {2, 2, 2});
|
||||
auto out_gpu = reshape(x, {2, 2, 2}, Device::gpu);
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
|
||||
x = transpose(reshape(x, {2, 2, 2}), {1, 2, 0});
|
||||
out_cpu = reshape(x, {4, 2});
|
||||
out_gpu = reshape(x, {4, 2}, Device::gpu);
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
|
||||
out_cpu = reshape(x, {8});
|
||||
out_gpu = reshape(x, {8}, Device::gpu);
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test metal reduce") {
|
||||
{
|
||||
array a(true);
|
||||
CHECK_EQ(all(a, Device::gpu).item<bool>(), true);
|
||||
CHECK_EQ(any(a, Device::gpu).item<bool>(), true);
|
||||
|
||||
a = array(std::initializer_list<bool>{});
|
||||
CHECK_EQ(all(a, Device::gpu).item<bool>(), true);
|
||||
CHECK_EQ(any(a, Device::gpu).item<bool>(), false);
|
||||
}
|
||||
|
||||
{
|
||||
std::vector<int> vals(33, 1);
|
||||
array a(vals.data(), {33});
|
||||
CHECK_EQ(all(a, Device::gpu).item<bool>(), true);
|
||||
|
||||
vals[32] = 0;
|
||||
a = array(vals.data(), {33});
|
||||
CHECK_EQ(all(a, Device::gpu).item<bool>(), false);
|
||||
}
|
||||
|
||||
{
|
||||
std::vector<int> vals(33, 0);
|
||||
array a(vals.data(), {33});
|
||||
CHECK_EQ(any(a, Device::gpu).item<bool>(), false);
|
||||
|
||||
vals[32] = 1;
|
||||
a = array(vals.data(), {33});
|
||||
CHECK_EQ(any(a, Device::gpu).item<bool>(), true);
|
||||
}
|
||||
|
||||
{
|
||||
std::vector<int> vals(1 << 14, 0);
|
||||
array a(vals.data(), {1 << 14});
|
||||
CHECK_EQ(all(a, Device::gpu).item<bool>(), false);
|
||||
CHECK_EQ(any(a, Device::gpu).item<bool>(), false);
|
||||
|
||||
vals[4] = 1;
|
||||
vals[999] = 1;
|
||||
vals[2000] = 1;
|
||||
a = array(vals.data(), {1 << 14});
|
||||
CHECK_EQ(all(a, Device::gpu).item<bool>(), false);
|
||||
CHECK_EQ(any(a, Device::gpu).item<bool>(), true);
|
||||
}
|
||||
|
||||
// sum and prod
|
||||
{
|
||||
array a = array({true, false, true});
|
||||
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 2);
|
||||
CHECK_EQ(prod(a, Device::gpu).item<bool>(), false);
|
||||
|
||||
a = array({true, true, true});
|
||||
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 3);
|
||||
CHECK_EQ(prod(a, Device::gpu).item<bool>(), true);
|
||||
|
||||
a = full({2, 2, 2}, 2.0f);
|
||||
CHECK_EQ(sum(a, Device::gpu).item<float>(), 16.0f);
|
||||
CHECK_EQ(prod(a, Device::gpu).item<float>(), 256.0f);
|
||||
|
||||
a = full({500, 2, 2}, 1u);
|
||||
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 2000);
|
||||
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1u);
|
||||
|
||||
a = full({500, 2, 2}, 1);
|
||||
CHECK_EQ(sum(a, Device::gpu).item<int32_t>(), 2000);
|
||||
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
|
||||
}
|
||||
|
||||
// reducing only some axes and irregular layouts
|
||||
{
|
||||
array a(1.0f);
|
||||
a = broadcast_to(a, {2, 2, 2});
|
||||
CHECK_EQ(sum(a, Device::gpu).item<float>(), 8.0f);
|
||||
|
||||
a = ones({2, 4, 8, 16});
|
||||
for (auto ax : {0, 1, 2, 3}) {
|
||||
auto out_gpu = sum(a, ax, false, Device::gpu);
|
||||
auto out_cpu = sum(a, ax, false, Device::cpu);
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
for (auto ax : {1, 2, 3}) {
|
||||
auto out_gpu = sum(a, {0, ax}, false, Device::gpu);
|
||||
auto out_cpu = sum(a, {0, ax}, false, Device::cpu);
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
}
|
||||
for (auto ax : {2, 3}) {
|
||||
auto out_gpu = sum(a, {0, 1, ax}, false, Device::gpu);
|
||||
auto out_cpu = sum(a, {0, 1, ax}, false, Device::cpu);
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test metal binary ops") {
|
||||
// scalar-scalar
|
||||
{
|
||||
array a(2.0f);
|
||||
array b(4.0f);
|
||||
auto out = add(a, b, Device::gpu);
|
||||
CHECK_EQ(out.item<float>(), 6.0f);
|
||||
}
|
||||
|
||||
// scalar-vector and vector-scalar
|
||||
{
|
||||
array a(2.0f);
|
||||
array b({2.0f, 4.0f, 6.0f});
|
||||
auto out = add(a, b, Device::gpu);
|
||||
auto expected = array({4.0f, 6.0f, 8.0f});
|
||||
CHECK(array_equal(out, expected, Device::cpu).item<bool>());
|
||||
out = add(b, a, Device::gpu);
|
||||
CHECK(array_equal(out, expected, Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
// vector-vector
|
||||
{
|
||||
array a({0.0f, 1.0f, 2.0f});
|
||||
array b({3.0f, 4.0f, 5.0f});
|
||||
auto out = add(a, b, Device::gpu);
|
||||
auto expected = array({3.0f, 5.0f, 7.0f});
|
||||
CHECK(array_equal(out, expected, Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
// general
|
||||
{
|
||||
array a({0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}, {2, 2, 2});
|
||||
array b({0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}, {2, 2, 2});
|
||||
a = transpose(a, {0, 2, 1});
|
||||
b = transpose(b, {1, 0, 2});
|
||||
auto out_gpu = add(a, b, Device::gpu);
|
||||
auto out_cpu = add(a, b, Device::cpu);
|
||||
auto expected =
|
||||
array({0.0f, 3.0f, 5.0f, 8.0f, 6.0f, 9.0f, 11.0f, 14.0f}, {2, 2, 2});
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
CHECK(array_equal(out_gpu, expected, Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
// Check all types work
|
||||
for (auto t : types) {
|
||||
auto a = astype(array({0, 1, 2}), t);
|
||||
auto b = astype(array({3, 4, 5}), t);
|
||||
auto out_cpu = add(a, b, Device::cpu);
|
||||
auto out_gpu = add(a, b, Device::gpu);
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
// Check subtraction
|
||||
{
|
||||
auto a = array({3, 2, 1});
|
||||
auto b = array({1, 1, 1});
|
||||
auto out = subtract(a, b, Device::gpu);
|
||||
CHECK(array_equal(out, array({2, 1, 0}), Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
// Check multiplication
|
||||
{
|
||||
auto a = array({1, 2, 3});
|
||||
auto b = array({2, 2, 2});
|
||||
auto out = multiply(a, b, Device::gpu);
|
||||
CHECK(array_equal(out, array({2, 4, 6}), Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
// Check division
|
||||
{
|
||||
auto x = array(1.0f);
|
||||
auto y = array(1.0f);
|
||||
CHECK_EQ(divide(x, y, Device::gpu).item<float>(), 1.0f);
|
||||
|
||||
x = array(1.0f);
|
||||
y = array(0.5);
|
||||
CHECK_EQ(divide(x, y, Device::gpu).item<float>(), 2.0f);
|
||||
|
||||
x = array(1.0f);
|
||||
y = array(0.0f);
|
||||
CHECK(std::isinf(divide(x, y, Device::gpu).item<float>()));
|
||||
|
||||
x = array(0.0f);
|
||||
y = array(0.0f);
|
||||
CHECK(std::isnan(divide(x, y, Device::gpu).item<float>()));
|
||||
}
|
||||
|
||||
// Check maximum and minimum
|
||||
{
|
||||
auto x = array(1.0f);
|
||||
auto y = array(0.0f);
|
||||
CHECK_EQ(maximum(x, y, Device::gpu).item<float>(), 1.0f);
|
||||
CHECK_EQ(minimum(x, y, Device::gpu).item<float>(), 0.0f);
|
||||
y = array(2.0f);
|
||||
CHECK_EQ(maximum(x, y, Device::gpu).item<float>(), 2.0f);
|
||||
CHECK_EQ(minimum(x, y, Device::gpu).item<float>(), 1.0f);
|
||||
}
|
||||
|
||||
// Check equal
|
||||
{
|
||||
array x(1.0f);
|
||||
array y(1.0f);
|
||||
CHECK(equal(x, y, Device::gpu).item<bool>());
|
||||
x = array(0.0f);
|
||||
CHECK(!equal(x, y, Device::gpu).item<bool>());
|
||||
}
|
||||
|
||||
// Greater and less
|
||||
{
|
||||
array x(1.0f);
|
||||
array y(0.0f);
|
||||
CHECK(greater(x, y, Device::gpu).item<bool>());
|
||||
CHECK(greater_equal(x, y, Device::gpu).item<bool>());
|
||||
CHECK(!greater(y, x, Device::gpu).item<bool>());
|
||||
CHECK(!greater_equal(y, x, Device::gpu).item<bool>());
|
||||
y = array(1.0f);
|
||||
CHECK(!greater(x, y, Device::gpu).item<bool>());
|
||||
CHECK(greater_equal(x, y, Device::gpu).item<bool>());
|
||||
|
||||
x = array(0.0f);
|
||||
y = array(1.0f);
|
||||
CHECK(less(x, y, Device::gpu).item<bool>());
|
||||
CHECK(less_equal(x, y, Device::gpu).item<bool>());
|
||||
CHECK(!less(y, x, Device::gpu).item<bool>());
|
||||
CHECK(!less_equal(y, x, Device::gpu).item<bool>());
|
||||
y = array(0.0f);
|
||||
CHECK(!less(x, y, Device::gpu).item<bool>());
|
||||
CHECK(less_equal(x, y, Device::gpu).item<bool>());
|
||||
}
|
||||
|
||||
// Check logaddexp
|
||||
{
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
array x(inf);
|
||||
array y(2.0f);
|
||||
auto out = logaddexp(x, y, Device::gpu);
|
||||
CHECK_EQ(out.item<float>(), inf);
|
||||
|
||||
x = array(-inf);
|
||||
out = logaddexp(x, y, Device::gpu);
|
||||
CHECK_EQ(out.item<float>(), 2.0f);
|
||||
|
||||
y = array(-inf);
|
||||
out = logaddexp(x, y, Device::gpu);
|
||||
CHECK_EQ(out.item<float>(), -inf);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test metal unary ops") {
|
||||
// contiguous
|
||||
{
|
||||
array x({-1.0f, 0.0f, 1.0f});
|
||||
auto expected = array({1.0f, 0.0f, 1.0f});
|
||||
CHECK(array_equal(abs(x, Device::gpu), expected, Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
// general
|
||||
{
|
||||
array x({-1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 1.0f, 3.0f, -3.0f});
|
||||
auto y = slice(x, {0}, {8}, {2});
|
||||
auto expected = array({1.0f, 1.0f, 1.0f, 3.0f});
|
||||
CHECK(array_equal(abs(y, Device::gpu), expected, Device::cpu).item<bool>());
|
||||
|
||||
y = slice(x, {4}, {8});
|
||||
expected = array({1.0f, 1.0f, 3.0f, 3.0f});
|
||||
CHECK(array_equal(abs(y, Device::gpu), expected, Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
// Test negative
|
||||
{
|
||||
array x(1.0f);
|
||||
CHECK_EQ(negative(x, Device::gpu).item<float>(), -1.0f);
|
||||
}
|
||||
|
||||
// Check all types work
|
||||
for (auto t : types) {
|
||||
if (t == bool_) {
|
||||
continue;
|
||||
}
|
||||
auto in = astype(array({1}), t);
|
||||
auto out_cpu = negative(in, Device::cpu);
|
||||
auto out_gpu = negative(in, Device::gpu);
|
||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
// Test log1p
|
||||
{
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
array x(-1.0f);
|
||||
CHECK_EQ(log1p(x, Device::gpu).item<float>(), -inf);
|
||||
|
||||
x = array(0.0f);
|
||||
CHECK_EQ(log1p(x, Device::gpu).item<float>(), 0.0f);
|
||||
|
||||
x = array(1e-9f);
|
||||
CHECK_EQ(log1p(x, Device::gpu).item<float>(), 1e-9f);
|
||||
|
||||
x = array(-2.0f);
|
||||
CHECK(std::isnan(log1p(x, Device::gpu).item<float>()));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test metal random") {
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto x = random::bits({}, 4, key, Device::gpu);
|
||||
auto y = random::bits({}, 4, key, Device::gpu);
|
||||
CHECK_EQ(x.item<uint32_t>(), 1797259609u);
|
||||
CHECK_EQ(x.item<uint32_t>(), y.item<uint32_t>());
|
||||
}
|
||||
|
||||
{
|
||||
auto key = random::key(1);
|
||||
auto x = random::bits({}, 4, key, Device::gpu);
|
||||
CHECK_EQ(x.item<uint32_t>(), 507451445u);
|
||||
}
|
||||
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto x = random::bits({3, 1}, 4, key, Device::gpu);
|
||||
auto expected = array({4146024105u, 1351547692u, 2718843009u}, {3, 1});
|
||||
CHECK(array_equal(x, expected, Device::cpu).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test metal matmul") {
|
||||
{
|
||||
auto a = ones({2, 2});
|
||||
auto b = ones({2, 2});
|
||||
auto out = matmul(a, b, Device::gpu);
|
||||
CHECK(array_equal(out, full({2, 2}, 2.0f), Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
// Batched matmul
|
||||
{
|
||||
auto a = ones({3, 2, 2});
|
||||
auto b = ones({3, 2, 2});
|
||||
auto out = matmul(a, b, Device::gpu);
|
||||
CHECK(array_equal(out, full({3, 2, 2}, 2.0f), Device::cpu).item<bool>());
|
||||
}
|
||||
|
||||
// Broadcast batched matmul
|
||||
{
|
||||
auto a = ones({1, 3, 2, 2});
|
||||
auto b = ones({3, 1, 2, 2});
|
||||
auto out = matmul(a, b, Device::gpu);
|
||||
CHECK(array_equal(out, full({3, 3, 2, 2}, 2.0f), Device::cpu).item<bool>());
|
||||
}
|
||||
}
|
||||
1926
tests/ops_tests.cpp
Normal file
1926
tests/ops_tests.cpp
Normal file
File diff suppressed because it is too large
Load Diff
545
tests/random_tests.cpp
Normal file
545
tests/random_tests.cpp
Normal file
@@ -0,0 +1,545 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
TEST_CASE("test random key") {
|
||||
auto key = random::key(0);
|
||||
CHECK(array_equal(key, array({0, 0})).item<bool>());
|
||||
|
||||
key = random::key(1);
|
||||
CHECK(array_equal(key, array({0, 1})).item<bool>());
|
||||
|
||||
int64_t seed = static_cast<int64_t>(1) << 32;
|
||||
key = random::key(seed);
|
||||
CHECK(array_equal(key, array({1, 0})).item<bool>());
|
||||
|
||||
key = random::key(seed + 1);
|
||||
CHECK(array_equal(key, array({1, 1})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test global rng") {
|
||||
random::seed(4);
|
||||
auto x = random::bits({});
|
||||
auto y = random::bits({});
|
||||
|
||||
random::seed(4);
|
||||
auto a = random::bits({});
|
||||
auto b = random::bits({});
|
||||
|
||||
CHECK_EQ(x.item<uint32_t>(), a.item<uint32_t>());
|
||||
CHECK_EQ(y.item<uint32_t>(), b.item<uint32_t>());
|
||||
}
|
||||
|
||||
TEST_CASE("test random split") {
|
||||
auto [key, subkey] = random::split(random::key(0));
|
||||
CHECK(array_equal(key, array({4146024105u, 967050713u})).item<bool>());
|
||||
CHECK(array_equal(subkey, array({2718843009u, 1272950319u})).item<bool>());
|
||||
|
||||
auto keys = random::split(random::key(0), 3);
|
||||
auto expected = array(
|
||||
{2467461003u,
|
||||
428148500u,
|
||||
3186719485u,
|
||||
3840466878u,
|
||||
2562233961u,
|
||||
1946702221u},
|
||||
{3, 2});
|
||||
CHECK(array_equal(keys, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test random bits") {
|
||||
// Test shapes, types, and sizes
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto x = random::bits({}, key);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), uint32);
|
||||
|
||||
x = random::bits({0}, key);
|
||||
CHECK(array_equal(x, array({})).item<bool>());
|
||||
|
||||
// Check wrong key type or shape
|
||||
key = array({0, 0});
|
||||
CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
|
||||
key = array({0, 0}, {1, 2});
|
||||
CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
|
||||
key = array({0u, 0u, 0u}, {3, 1});
|
||||
CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
|
||||
key = array({0u, 0u}, {2, 1});
|
||||
CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
|
||||
}
|
||||
|
||||
// Expected bits in the following tests were generated from
|
||||
// Jax's Threefry 2x32 implementation using the following in
|
||||
// python:
|
||||
//
|
||||
// ```
|
||||
// import jax
|
||||
// import jax.prng
|
||||
// shape = (SET THIS)
|
||||
// seed = (SET THIS)
|
||||
// width = (SET THIS)
|
||||
// key = jax.random.PRNGKey(seed)
|
||||
// print(jax.prng.threefry_prng_impl.random_bits(key, width, shape))
|
||||
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto x = random::bits({}, key);
|
||||
auto y = random::bits({}, key);
|
||||
CHECK_EQ(x.item<uint32_t>(), 1797259609u);
|
||||
CHECK_EQ(x.item<uint32_t>(), y.item<uint32_t>());
|
||||
|
||||
x = random::bits({}, 2, key);
|
||||
CHECK_EQ(x.item<uint16_t>(), 345);
|
||||
|
||||
x = random::bits({}, 1, key);
|
||||
CHECK_EQ(x.item<uint8_t>(), 89);
|
||||
}
|
||||
|
||||
{
|
||||
auto key = random::key(1);
|
||||
auto x = random::bits({}, key);
|
||||
CHECK_EQ(x.item<uint32_t>(), 507451445u);
|
||||
|
||||
x = random::bits({}, 2, key);
|
||||
CHECK_EQ(x.item<uint16_t>(), 6197);
|
||||
|
||||
x = random::bits({}, 1, key);
|
||||
CHECK_EQ(x.item<uint8_t>(), 53);
|
||||
|
||||
CHECK_THROWS(random::bits({}, 0, key));
|
||||
CHECK_THROWS(random::bits({}, 5, key));
|
||||
CHECK_THROWS(random::bits({}, -1, key));
|
||||
}
|
||||
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto x = random::bits({3, 1}, key);
|
||||
auto expected = array({4146024105u, 1351547692u, 2718843009u}, {3, 1});
|
||||
CHECK(array_equal(x, expected).item<bool>());
|
||||
|
||||
x = random::bits({5}, 2, key);
|
||||
expected = array({20137, 63263, 64300, 20622, 16513}, uint16);
|
||||
CHECK(array_equal(x, expected).item<bool>());
|
||||
expected = array({20137, 63263, 64300, 20622, 16513, 41486}, uint16);
|
||||
x = random::bits({6}, 2, key);
|
||||
CHECK(array_equal(x, expected).item<bool>());
|
||||
expected = array({20137, 63263, 1497, 14756, 16513, 41486, 44591}, uint16);
|
||||
x = random::bits({7}, 2, key);
|
||||
CHECK(array_equal(x, expected).item<bool>());
|
||||
x = random::bits({8}, 2, key);
|
||||
expected =
|
||||
array({20137, 63263, 1497, 14756, 16513, 41486, 44591, 19423}, uint16);
|
||||
CHECK(array_equal(x, expected).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
auto key = array({0u, 0u, 1u, 1u}, {2, 2});
|
||||
auto shape = std::vector<int>{3};
|
||||
auto fn = [&shape](array k) { return random::bits(shape, k); };
|
||||
|
||||
auto expected = array(
|
||||
{4146024105u,
|
||||
1351547692u,
|
||||
2718843009u,
|
||||
3725146706u,
|
||||
1802982961u,
|
||||
1349634643u},
|
||||
{2, 3});
|
||||
CHECK(array_equal(vmap(fn)(key), expected).item<bool>());
|
||||
expected = array(
|
||||
{2441914641u,
|
||||
1110694964u,
|
||||
3819641963u,
|
||||
2441914641u,
|
||||
1110694964u,
|
||||
3819641963u},
|
||||
{2, 3});
|
||||
CHECK(array_equal(vmap(fn, 1)(key), expected).item<bool>());
|
||||
|
||||
// Vmap twice
|
||||
key = array(
|
||||
{0u,
|
||||
0u,
|
||||
1u,
|
||||
1u,
|
||||
2u,
|
||||
2u,
|
||||
|
||||
3u,
|
||||
3u,
|
||||
4u,
|
||||
4u,
|
||||
5u,
|
||||
5u},
|
||||
{3, 2, 2});
|
||||
shape = {2};
|
||||
auto out = vmap(vmap(fn))(key);
|
||||
expected = array(
|
||||
{928981903u,
|
||||
3453687069u,
|
||||
3606183818u,
|
||||
460005496u,
|
||||
|
||||
2799733733u,
|
||||
856293553u,
|
||||
4081856343u,
|
||||
3445925136u,
|
||||
|
||||
2775548010u,
|
||||
1430281703u,
|
||||
305173070u,
|
||||
2615843348u},
|
||||
{3, 2, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
|
||||
out = vmap(vmap(fn, 1), 0)(key);
|
||||
expected = array(
|
||||
{1948878966u,
|
||||
4237131848u,
|
||||
1948878966u,
|
||||
4237131848u,
|
||||
|
||||
2531170506u,
|
||||
1858648356u,
|
||||
2531170506u,
|
||||
1858648356u,
|
||||
|
||||
740561898u,
|
||||
4234094099u,
|
||||
740561898u,
|
||||
4234094099u},
|
||||
{3, 2, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
// Vmap smaller type
|
||||
{
|
||||
auto key = array({0u, 0u, 1u, 1u}, {2, 2});
|
||||
auto fn = [](array k) { return random::bits({5}, 2, k); };
|
||||
|
||||
auto expected = array(
|
||||
{4146024105u,
|
||||
1351547692u,
|
||||
2718843009u,
|
||||
3725146706u,
|
||||
1802982961u,
|
||||
1349634643u},
|
||||
{2, 3});
|
||||
auto out = vmap(fn)(key);
|
||||
auto x1 = random::bits({5}, 2, take(key, array(0), 0));
|
||||
auto x2 = random::bits({5}, 2, take(key, array(1), 0));
|
||||
|
||||
CHECK(array_equal(take(out, array(0), 0), x1).item<bool>());
|
||||
CHECK(array_equal(take(out, array(1), 0), x2).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test random uniform") {
|
||||
// Test shapes, types, and sizes
|
||||
{
|
||||
auto x = random::uniform({});
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
|
||||
if (is_available(float16)) {
|
||||
x = random::uniform({}, float16);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), float16);
|
||||
}
|
||||
|
||||
x = random::uniform({0});
|
||||
CHECK(array_equal(x, array({})).item<bool>());
|
||||
|
||||
// Non float type throws
|
||||
CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument);
|
||||
|
||||
// Check broadcasting
|
||||
x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
|
||||
CHECK_EQ(x.shape(), std::vector<int>{3, 3});
|
||||
CHECK_THROWS_AS(
|
||||
random::uniform(zeros({3, 3}), 1.0, {1, 3}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(
|
||||
random::uniform(zeros({3, 3}), 1.0, {2, 3}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(
|
||||
random::uniform(zeros({3, 1}), ones({1, 3}), {1, 3}),
|
||||
std::invalid_argument);
|
||||
|
||||
// Check wrong key type or shape
|
||||
auto key = array({0, 0});
|
||||
CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
|
||||
key = array({0, 0}, {1, 2});
|
||||
CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
|
||||
key = array({0u, 0u, 0u}, {3, 1});
|
||||
CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
|
||||
key = array({0u, 0u}, {2, 1});
|
||||
CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
|
||||
}
|
||||
|
||||
// Expected bits in the following tests were generated from
|
||||
// Jax's Threefry 2x32 implementation using the following in
|
||||
// python:
|
||||
//
|
||||
// ```
|
||||
// import jax
|
||||
// import jax.prng
|
||||
// shape = (SET THIS)
|
||||
// seed = (SET THIS)
|
||||
// key = jax.random.PRNGKey(seed)
|
||||
// print(jax.prng.threefry_prng_impl.random_bits(key, 32, shape))
|
||||
|
||||
constexpr auto to_float = [](uint32_t n) {
|
||||
return static_cast<float>(n) / UINT32_MAX;
|
||||
};
|
||||
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto x = random::uniform({}, key);
|
||||
auto y = random::uniform({}, key);
|
||||
auto expected = to_float(1797259609);
|
||||
CHECK_EQ(x.item<float>(), expected);
|
||||
CHECK_EQ(x.item<float>(), y.item<float>());
|
||||
}
|
||||
|
||||
{
|
||||
auto key = random::key(1);
|
||||
auto x = random::uniform({}, key);
|
||||
auto expected = to_float(507451445);
|
||||
CHECK_EQ(x.item<float>(), expected);
|
||||
}
|
||||
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto x = random::uniform({3, 1}, key);
|
||||
auto expected = array(
|
||||
{to_float(4146024105), to_float(1351547692), to_float(2718843009)},
|
||||
{3, 1});
|
||||
CHECK(array_equal(x, expected).item<bool>());
|
||||
}
|
||||
|
||||
// Check vmap
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto fun = [](array k, array low) {
|
||||
return random::uniform(low, 1, {3}, float32, k);
|
||||
};
|
||||
auto out = vmap(fun, -1)(key, zeros({2, 3}));
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3});
|
||||
|
||||
key = zeros({2, 2}, uint32);
|
||||
out = vmap(fun)(key, zeros({2, 3}));
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3});
|
||||
}
|
||||
|
||||
// Check bounds are respected
|
||||
{
|
||||
auto key = random::key(128291);
|
||||
auto out = random::uniform(array(-1.0f), array(1.0f), {100}, float32, key);
|
||||
CHECK(all(less(out, array(1.0f))).item<bool>());
|
||||
CHECK(all(greater_equal(out, array(-1.0f))).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test random normal") {
|
||||
// Test shapes, types, and sizes
|
||||
{
|
||||
auto x = random::normal({});
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
|
||||
x = random::uniform({0});
|
||||
CHECK(array_equal(x, array({})).item<bool>());
|
||||
|
||||
// Non float type throws
|
||||
CHECK_THROWS_AS(random::normal({}, int32), std::invalid_argument);
|
||||
|
||||
// Check wrong key type or shape
|
||||
auto key = array({0, 0});
|
||||
CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
|
||||
key = array({0, 0}, {1, 2});
|
||||
CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
|
||||
key = array({0u, 0u, 0u}, {3, 1});
|
||||
CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
|
||||
key = array({0u, 0u}, {2, 1});
|
||||
CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
|
||||
}
|
||||
|
||||
{
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
auto key = random::key(128291);
|
||||
auto out = random::normal({100}, key);
|
||||
CHECK(all(less(abs(out), array(inf))).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test random randint") {
|
||||
CHECK_THROWS_AS(
|
||||
random::randint(array(3), array(5), {1}, float32), std::invalid_argument);
|
||||
|
||||
auto x = random::randint(0, 10, {}, uint32);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), uint32);
|
||||
|
||||
x = random::randint(0, 2, {}, bool_);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), bool_);
|
||||
|
||||
x = random::randint(0, 2, {}, int32);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), int32);
|
||||
|
||||
x = random::randint(0, 2, {}, int64);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), int64);
|
||||
|
||||
// Check all in bounds
|
||||
auto low = -10.0;
|
||||
auto high = 20.0;
|
||||
x = random::randint(low, high, {1000, 1000});
|
||||
CHECK((all(low <= x).item<bool>() && all(x < high).item<bool>()));
|
||||
|
||||
// Check high < low => all equals to low
|
||||
low = 20.0;
|
||||
high = -10.0;
|
||||
x = random::randint(low, high, {3, 3});
|
||||
CHECK(all(equal(x, array(low))).item<bool>());
|
||||
|
||||
// Check wrong key type or shape
|
||||
auto key = array({0, 0}, {1, 2});
|
||||
CHECK_THROWS_AS(
|
||||
random::randint(low, high, {}, float32, key), std::invalid_argument);
|
||||
}
|
||||
|
||||
TEST_CASE("test random bernoulli") {
|
||||
auto x = random::bernoulli();
|
||||
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), bool_);
|
||||
|
||||
// Bernoulli parameter can have floating point type
|
||||
if (is_available(float16)) {
|
||||
x = random::bernoulli(array(0.5, float16));
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), bool_);
|
||||
}
|
||||
|
||||
CHECK_THROWS(random::bernoulli(array(1, int32)));
|
||||
|
||||
// Negative numbers allowed in Jax
|
||||
x = random::bernoulli(array(-1.0));
|
||||
CHECK_FALSE(x.item<bool>());
|
||||
|
||||
x = random::bernoulli(array(5.0));
|
||||
CHECK(x.item<bool>());
|
||||
|
||||
// Return array with correct shape
|
||||
x = random::bernoulli(0.5, {3, 3});
|
||||
CHECK_EQ(x.shape(), std::vector<int>({3, 3}));
|
||||
|
||||
// Try with p = {}
|
||||
x = random::bernoulli(array({}));
|
||||
CHECK_EQ(x.size(), 0);
|
||||
|
||||
// Try broadcasting
|
||||
auto p = array({0.1, 0.2, 0.3});
|
||||
p = reshape(p, {1, 3});
|
||||
x = random::bernoulli(p, {4, 3});
|
||||
CHECK_EQ(x.shape(), std::vector<int>({4, 3}));
|
||||
|
||||
CHECK_THROWS_AS(random::bernoulli(array({}), {3, 3}), std::invalid_argument);
|
||||
|
||||
p = array({0.1, 0.2, 0.3});
|
||||
// Ask for the wrong shape => throws
|
||||
CHECK_THROWS_AS(random::bernoulli(p, {2}), std::invalid_argument);
|
||||
|
||||
// Check wrong key type or shape
|
||||
auto key = array({0, 0}, {1, 2});
|
||||
CHECK_THROWS_AS(random::bernoulli(array(0.5), key), std::invalid_argument);
|
||||
}
|
||||
|
||||
TEST_CASE("Test truncated normal") {
|
||||
auto x = random::truncated_normal(array(-2.0), array(2.0));
|
||||
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
|
||||
if (is_available(float16)) {
|
||||
x = random::truncated_normal(array(-2.0), array(2.0), {}, float16);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), float16);
|
||||
}
|
||||
|
||||
// Requested shape
|
||||
x = random::truncated_normal(array(-2.0), array(2.0), {3, 4});
|
||||
CHECK_EQ(x.shape(), std::vector<int>({3, 4}));
|
||||
|
||||
// Empty array
|
||||
x = random::truncated_normal(array({}), array({}));
|
||||
CHECK_EQ(x.size(), 0);
|
||||
|
||||
// Broadcast
|
||||
auto lower = reshape(array({-2.0, -3.0}), {1, 2});
|
||||
auto higher = reshape(array({0.0, 3.0, 1.5}), {3, 1});
|
||||
x = random::truncated_normal(lower, higher);
|
||||
|
||||
// All in bounds
|
||||
CHECK_EQ(x.shape(), std::vector<int>({3, 2}));
|
||||
CHECK((all(x <= higher).item<bool>() && all(lower <= x).item<bool>()));
|
||||
|
||||
// high < low => all equal to low
|
||||
x = random::truncated_normal(array(2.0), array(-2.0));
|
||||
CHECK(all(x == array(2.0)).item<bool>());
|
||||
|
||||
// Non broadcastable => throws
|
||||
CHECK_THROWS_AS(
|
||||
random::truncated_normal(lower, higher, {4, 2}), std::invalid_argument);
|
||||
|
||||
auto key = array({0, 0}, {1, 2});
|
||||
CHECK_THROWS_AS(
|
||||
random::truncated_normal(array(-2.0), array(2.0), {1, 1}, float32, key),
|
||||
std::invalid_argument);
|
||||
}
|
||||
|
||||
TEST_CASE("test categorical") {
|
||||
auto logits = zeros({10, 20});
|
||||
|
||||
using random::categorical;
|
||||
|
||||
// Invalid axes
|
||||
CHECK_THROWS(categorical(logits, 2));
|
||||
CHECK_THROWS(categorical(logits, -3));
|
||||
|
||||
// Invalid requested shapes
|
||||
CHECK_THROWS(categorical(logits, 1, std::vector<int>{1}));
|
||||
CHECK_THROWS(categorical(logits, 1, std::vector<int>{11}));
|
||||
CHECK_THROWS(categorical(logits, 1, {10, 1}));
|
||||
|
||||
CHECK_EQ(categorical(logits, -1).shape(), std::vector<int>{10});
|
||||
CHECK_EQ(categorical(logits, 0).shape(), std::vector<int>{20});
|
||||
CHECK_EQ(categorical(logits, 1).shape(), std::vector<int>{10});
|
||||
|
||||
auto out = categorical(logits);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{10});
|
||||
CHECK_EQ(out.dtype(), uint32);
|
||||
CHECK(max(out).item<uint32_t>() < 20);
|
||||
|
||||
out = categorical(logits, 0, {5, 20});
|
||||
CHECK_EQ(out.shape(), std::vector<int>{5, 20});
|
||||
CHECK(max(out).item<uint32_t>() < 10);
|
||||
|
||||
float inf = std::numeric_limits<float>::infinity();
|
||||
logits = array({1.0f, -2.0f, inf, 4.0f, 3.0f});
|
||||
CHECK_EQ(categorical(logits).item<uint32_t>(), 2);
|
||||
|
||||
logits = array({-inf, -2.0f, -inf, -inf});
|
||||
CHECK_EQ(categorical(logits).item<uint32_t>(), 1);
|
||||
|
||||
logits = zeros({5, 4, 3});
|
||||
CHECK_EQ(categorical(logits, -1, 7).shape(), std::vector<int>{5, 4, 7});
|
||||
CHECK_EQ(categorical(logits, -2, 7).shape(), std::vector<int>{5, 3, 7});
|
||||
CHECK_EQ(categorical(logits, -3, 7).shape(), std::vector<int>{4, 3, 7});
|
||||
}
|
||||
22
tests/tests.cpp
Normal file
22
tests/tests.cpp
Normal file
@@ -0,0 +1,22 @@
|
||||
#define DOCTEST_CONFIG_IMPLEMENT
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
doctest::Context context;
|
||||
|
||||
const char* device = std::getenv("DEVICE");
|
||||
if (device != nullptr && std::string(device) == "cpu") {
|
||||
set_default_device(Device::cpu);
|
||||
} else if (metal::is_available()) {
|
||||
set_default_device(Device::gpu);
|
||||
}
|
||||
|
||||
context.applyCommandLine(argc, argv);
|
||||
return context.run();
|
||||
}
|
||||
248
tests/vmap_tests.cpp
Normal file
248
tests/vmap_tests.cpp
Normal file
@@ -0,0 +1,248 @@
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
TEST_CASE("test simple vmap") {
|
||||
// vmap reshape
|
||||
{
|
||||
auto vfun = vmap([](array input) { return reshape(input, {2, 2}); });
|
||||
auto x = zeros({3, 4});
|
||||
CHECK(array_equal(vfun(x), zeros({3, 2, 2})).item<bool>());
|
||||
|
||||
x = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2});
|
||||
vfun = vmap([](array input) { return reshape(input, {4}); });
|
||||
auto expected = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 4});
|
||||
CHECK(array_equal(vfun(x), expected).item<bool>());
|
||||
|
||||
vfun = vmap([](array input) { return reshape(input, {4}); }, 1);
|
||||
expected = array({0, 1, 4, 5, 2, 3, 6, 7}, {2, 4});
|
||||
CHECK(array_equal(vfun(x), expected).item<bool>());
|
||||
|
||||
vfun = vmap([](array input) { return reshape(input, {4}); }, 1, 1);
|
||||
expected = array({0, 2, 1, 3, 4, 6, 5, 7}, {4, 2});
|
||||
CHECK(array_equal(vfun(x), expected).item<bool>());
|
||||
}
|
||||
|
||||
// vmap broadcast
|
||||
{
|
||||
auto fun = [](array input) { return broadcast_to(input, {4, 2}); };
|
||||
|
||||
CHECK_THROWS_AS(vmap(fun, 0, -1), std::invalid_argument);
|
||||
CHECK_THROWS_AS(vmap(fun, -1, 0), std::invalid_argument);
|
||||
|
||||
auto vfun = vmap(fun, -1, -1);
|
||||
auto x = zeros({2});
|
||||
CHECK(array_equal(vfun(x), zeros({4, 2})).item<bool>());
|
||||
|
||||
vfun = vmap(fun);
|
||||
x = zeros({3, 2});
|
||||
CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item<bool>());
|
||||
|
||||
vfun = vmap(fun, 0, 1);
|
||||
CHECK(array_equal(vfun(x), zeros({4, 3, 2})).item<bool>());
|
||||
|
||||
vfun = vmap(fun, 0, 2);
|
||||
CHECK(array_equal(vfun(x), zeros({4, 2, 3})).item<bool>());
|
||||
|
||||
vfun = vmap(fun, 0, 2);
|
||||
x = zeros({2, 3});
|
||||
CHECK_THROWS_AS(vfun(x), std::invalid_argument);
|
||||
|
||||
x = zeros({2, 3});
|
||||
vfun = vmap(fun, 1);
|
||||
CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item<bool>());
|
||||
|
||||
vfun = vmap(fun, 1, 1);
|
||||
CHECK(array_equal(vfun(x), zeros({4, 3, 2})).item<bool>());
|
||||
|
||||
vfun = vmap(fun, 1, 2);
|
||||
CHECK(array_equal(vfun(x), zeros({4, 2, 3})).item<bool>());
|
||||
}
|
||||
|
||||
// vmap transpose
|
||||
{
|
||||
auto fun = [](array input) { return transpose(input); };
|
||||
auto vfun = vmap(fun);
|
||||
auto x = array({0, 1, 2, 3, 4, 5}, {3, 2});
|
||||
CHECK(array_equal(vfun(x), x).item<bool>());
|
||||
|
||||
vfun = vmap(fun, 0, 1);
|
||||
CHECK(array_equal(vfun(x), transpose(x)).item<bool>());
|
||||
|
||||
x = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2});
|
||||
vfun = vmap(fun);
|
||||
CHECK(array_equal(vfun(x), transpose(x, {0, 2, 1})).item<bool>());
|
||||
|
||||
vfun = vmap(fun, 1, 1);
|
||||
CHECK(array_equal(vfun(x), transpose(x, {2, 1, 0})).item<bool>());
|
||||
|
||||
vfun = vmap(fun, 2, 2);
|
||||
CHECK(array_equal(vfun(x), transpose(x, {1, 0, 2})).item<bool>());
|
||||
|
||||
// vmap twice
|
||||
x = array(
|
||||
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, {2, 2, 2, 2});
|
||||
vfun = vmap(vmap(fun));
|
||||
CHECK(array_equal(vfun(x), transpose(x, {0, 1, 3, 2})).item<bool>());
|
||||
}
|
||||
|
||||
// vmap add
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto out = add(inputs[0], inputs[1]);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
|
||||
auto vfun = vmap(fun);
|
||||
array x({1.0, 2.0}, {2, 1});
|
||||
array y({2.0, 3.0}, {2, 1});
|
||||
auto out = vfun({x, y})[0];
|
||||
CHECK(array_equal(out, array({3.0, 5.0}, {2, 1})).item<bool>());
|
||||
|
||||
x = ones({2, 1, 3});
|
||||
y = ones({3, 2});
|
||||
vfun = vmap(fun, {2, 0});
|
||||
out = vfun({x, y})[0];
|
||||
CHECK(array_equal(out, full({3, 2, 2}, 2.0)).item<bool>());
|
||||
|
||||
x = array(1.);
|
||||
y = ones({3, 2});
|
||||
vfun = vmap(fun, {-1, 0});
|
||||
out = vfun({x, y})[0];
|
||||
CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());
|
||||
|
||||
x = ones({3, 2});
|
||||
y = array(1.);
|
||||
vfun = vmap(fun, {0, -1});
|
||||
out = vfun({x, y})[0];
|
||||
CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());
|
||||
|
||||
CHECK_THROWS_AS(vmap(fun, {-1, -1}, {0}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(vmap(fun, {-1, 0}, {-1}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(vmap(fun, {0, -1}, {-1}), std::invalid_argument);
|
||||
|
||||
x = array(1.);
|
||||
y = array(1.);
|
||||
vfun = vmap(fun, {-1, -1}, {-1});
|
||||
out = vfun({x, y})[0];
|
||||
CHECK(array_equal(out, array(2.)).item<bool>());
|
||||
|
||||
x = ones({3, 2, 1});
|
||||
y = ones({3, 2, 1});
|
||||
vfun = vmap(vmap(fun));
|
||||
out = vfun({x, y})[0];
|
||||
CHECK(array_equal(out, x + y).item<bool>());
|
||||
}
|
||||
|
||||
// vmap with capturing closure
|
||||
{
|
||||
auto x = add(add(ones({2}), zeros({2})), zeros({2}));
|
||||
auto fun = [x](const array& input) { return add(input, x); };
|
||||
|
||||
auto vfun = vmap(fun);
|
||||
auto y = ones({3, 2});
|
||||
CHECK(array_equal(vfun(y), full({3, 2}, 2.0f)).item<bool>());
|
||||
}
|
||||
{
|
||||
auto x = ones({4});
|
||||
auto z = x + x;
|
||||
auto vfun = vmap(
|
||||
[z](std::vector<array> inputs) {
|
||||
return std::vector<array>{add(z, inputs[1])};
|
||||
},
|
||||
{-1, 0});
|
||||
auto y = ones({3, 4});
|
||||
CHECK(array_equal(vfun({x, y})[0], full({3, 4}, 3.0)).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap with eval") {
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto x = inputs[0] + 1;
|
||||
auto y = inputs[1] + 2;
|
||||
eval(x);
|
||||
auto out = add(x, y);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
|
||||
auto vfun = vmap(fun);
|
||||
array x({1.0, 2.0}, {2, 1});
|
||||
array y({2.0, 3.0}, {2, 1});
|
||||
CHECK_THROWS(vfun({x, y}));
|
||||
|
||||
// Ok to eval functions of non-vmapped input
|
||||
x = array(1.0);
|
||||
vfun = vmap(fun, {-1, 0});
|
||||
CHECK(array_equal(vfun({x, y})[0], array({6.0f, 7.0f}, {2, 1})).item<bool>());
|
||||
|
||||
// Not ok to eval function of vmapped input even with retain graph
|
||||
auto fun2 = [](std::vector<array> inputs) {
|
||||
auto x = inputs[0] + 1;
|
||||
auto y = inputs[1] + 2;
|
||||
eval({x}, true);
|
||||
auto out = add(x, y);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
x = array({1.0, 2.0}, {2, 1});
|
||||
CHECK_THROWS(vmap(fun2)({x, y}));
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap comparison ops") {
|
||||
// vmap equal
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return std::vector<array>{equal(inputs[0], inputs[1])};
|
||||
};
|
||||
auto vfun = vmap(fun);
|
||||
auto x = zeros({2, 3}, float32);
|
||||
auto y = zeros({2, 3}, float32);
|
||||
auto out = vfun({x, y})[0];
|
||||
CHECK(all(out).item<bool>());
|
||||
|
||||
vfun = vmap(fun, {0, -1});
|
||||
x = zeros({2, 3}, float32);
|
||||
y = zeros({3}, float32);
|
||||
out = vfun({x, y})[0];
|
||||
CHECK(all(out).item<bool>());
|
||||
|
||||
vfun = vmap(fun, {0, -1});
|
||||
x = array({0, 0, 0, 1, 1, 1}, {2, 3});
|
||||
y = zeros({3}, float32);
|
||||
out = vfun({x, y})[0];
|
||||
auto expected = array({true, true, true, false, false, false}, {2, 3});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap creation ops") {
|
||||
// vmap astype
|
||||
{
|
||||
auto fun = [](array in) { return astype(in, int32); };
|
||||
auto x = zeros({2, 3}, float32);
|
||||
auto out = vmap(fun)(x);
|
||||
CHECK_EQ(out.dtype(), int32);
|
||||
CHECK(array_equal(out, zeros({2, 3}, int32)).item<bool>());
|
||||
}
|
||||
|
||||
// vmap full
|
||||
{
|
||||
auto fun = [](array in) { return full({2}, in); };
|
||||
auto x = array({1, 2, 3});
|
||||
auto out = vmap(fun)(x);
|
||||
auto expected = array({1, 1, 2, 2, 3, 3}, {3, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
|
||||
x = array({1, 2, 3}, {3, 1});
|
||||
out = vmap(fun)(x);
|
||||
expected = array({1, 1, 2, 2, 3, 3}, {3, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
|
||||
x = array({1, 2, 3}, {1, 3});
|
||||
CHECK_THROWS_AS(vmap(fun)(x), std::invalid_argument);
|
||||
out = vmap(fun, 1, 1)(x);
|
||||
expected = array({1, 2, 3, 1, 2, 3}, {2, 3});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user