2023-12-01 03:12:53 +08:00
|
|
|
// Copyright © 2023 Apple Inc.
|
|
|
|
|
2023-11-30 02:52:08 +08:00
|
|
|
#include <climits>
|
|
|
|
|
|
|
|
#include "doctest/doctest.h"
|
|
|
|
|
|
|
|
#include "mlx/mlx.h"
|
|
|
|
|
|
|
|
using namespace mlx::core;
|
|
|
|
|
|
|
|
TEST_CASE("test array basics") {
  // Rank-0 array built from a double literal; the checks below expect it to
  // be held as a single float32 element with empty shape and strides.
  auto arr = array(1.0);
  CHECK_EQ(arr.size(), 1);
  CHECK_EQ(arr.ndim(), 0);
  CHECK_EQ(arr.shape(), Shape{});
  CHECK_THROWS_AS(arr.shape(0), std::out_of_range);
  CHECK_THROWS_AS(arr.shape(-1), std::out_of_range);
  CHECK_EQ(arr.strides(), Strides{});
  CHECK_EQ(arr.itemsize(), sizeof(float));
  CHECK_EQ(arr.nbytes(), sizeof(float));
  CHECK_EQ(arr.dtype(), float32);
  CHECK_EQ(arr.item<float>(), 1.0);

  // Scalar with an explicitly requested dtype (int literal -> float32).
  arr = array(1, float32);
  CHECK_EQ(arr.dtype(), float32);
  CHECK_EQ(arr.item<float>(), 1.0);

  // Scalar with an explicitly requested dtype (int literal -> bool_).
  arr = array(1, bool_);
  CHECK_EQ(arr.dtype(), bool_);
  CHECK_EQ(arr.itemsize(), sizeof(bool));
  CHECK_EQ(arr.nbytes(), sizeof(bool));
  CHECK_EQ(arr.item<bool>(), true);

  // One-element, rank-1 array: indexing shape() in range works (including
  // negative indices), out of range throws.
  arr = array({1.0});
  CHECK_EQ(arr.dtype(), float32);
  CHECK_EQ(arr.size(), 1);
  CHECK_EQ(arr.ndim(), 1);
  CHECK_EQ(arr.shape(), Shape{1});
  CHECK_EQ(arr.shape(0), 1);
  CHECK_EQ(arr.shape(-1), 1);
  CHECK_THROWS_AS(arr.shape(1), std::out_of_range);
  CHECK_THROWS_AS(arr.shape(-2), std::out_of_range);
  CHECK_EQ(arr.strides(), Strides{1});
  CHECK_EQ(arr.item<float>(), 1.0);

  // Empty array: zero elements, float32 by default, and item() must throw.
  arr = array({});
  CHECK_EQ(arr.size(), 0);
  CHECK_EQ(arr.dtype(), float32);
  CHECK_EQ(arr.itemsize(), sizeof(float));
  CHECK_EQ(arr.nbytes(), 0);
  CHECK_THROWS_AS(arr.item<float>(), std::invalid_argument);

  // Two-element array: nbytes is itemsize * size.
  arr = array({1.0, 1.0});
  CHECK_EQ(arr.size(), 2);
  CHECK_EQ(arr.shape(), Shape{2});
  CHECK_EQ(arr.itemsize(), sizeof(float));
  CHECK_EQ(arr.nbytes(), arr.itemsize() * arr.size());

  // item() on an array with more than one element throws.
  CHECK_THROWS_AS(arr.item<float>(), std::invalid_argument);

  // Explicit 2-D shape.
  arr = array({1.0, 1.0, 1.0}, {1, 3});
  CHECK_EQ(arr.size(), 3);
  CHECK_EQ(arr.shape(), Shape{1, 3});
  CHECK_EQ(arr.strides(), Strides{3, 1});

  // A shape whose element count disagrees with the data must be rejected.
  CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {4}), std::invalid_argument);
  CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 4}), std::invalid_argument);
  CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 2}), std::invalid_argument);

  // Ids: a copy shares the id of its source, an independently constructed
  // array does not, and assignment adopts the id of the right-hand side.
  arr = array(1.0);
  auto alias = arr;
  CHECK_EQ(alias.id(), arr.id());
  array other(2.0);
  CHECK_NE(other.id(), arr.id());
  other = arr;
  CHECK_EQ(other.id(), arr.id());

  // Construction from a raw pointer plus shape.
  float raw_values[] = {0.0, 1.0, 2.0, 3.0};
  arr = array(raw_values, {4});
  CHECK_EQ(arr.dtype(), float32);
  CHECK(array_equal(arr, array({0.0, 1.0, 2.0, 3.0})).item<bool>());

  // Construction from std::vector iterators; dtype follows the element type.
  {
    std::vector<int> int_values = {0, 1, 2, 3};
    arr = array(int_values.begin(), {4});
    CHECK_EQ(arr.dtype(), int32);
    CHECK(array_equal(arr, array({0, 1, 2, 3})).item<bool>());
  }

  {
    std::vector<bool> bool_values = {false, true, false, true};
    arr = array(bool_values.begin(), {4});
    CHECK_EQ(arr.dtype(), bool_);
    CHECK(array_equal(arr, array({false, true, false, true})).item<bool>());
  }
}
|
|
|
|
|
|
|
|
TEST_CASE("test array types") {
// Builds a scalar and then a two-element array from a value of C++ type T and
// checks that both report `mlx_type` as their dtype. The macro expands to
// several statements that declare `val` and `x`, so every use must be wrapped
// in its own braces.
#define basic_dtype_test(T, mlx_type) \
  T val = 42;                         \
  array x(val);                       \
  CHECK_EQ(x.dtype(), mlx_type);      \
  CHECK_EQ(x.item<T>(), val);         \
  x = array({val, val});              \
  CHECK_EQ(x.dtype(), mlx_type);

  // bool_
  {
    array x(true);
    CHECK_EQ(x.dtype(), bool_);
    CHECK_EQ(x.item<bool>(), true);

    x = array({true, false});
    CHECK_EQ(x.dtype(), bool_);

    // An explicit dtype overrides the element type of the initializer list.
    x = array({true, false}, float32);
    CHECK_EQ(x.dtype(), float32);
    CHECK(array_equal(x, array({1.0f, 0.0f})).item<bool>());
  }

  // uint8
  {
    basic_dtype_test(uint8_t, uint8);
  }

  // uint16
  {
    basic_dtype_test(uint16_t, uint16);
  }

  // uint32
  {
    basic_dtype_test(uint32_t, uint32);
  }

  // uint64
  {
    basic_dtype_test(uint64_t, uint64);
  }

  // int8
  {
    basic_dtype_test(int8_t, int8);
  }

  // int16
  {
    basic_dtype_test(int16_t, int16);
  }

  // int32
  {
    basic_dtype_test(int32_t, int32);
  }

  // int64
  {
    basic_dtype_test(int64_t, int64);
  }

  // float16
  {
    basic_dtype_test(float16_t, float16);
  }

  // float32
  {
    basic_dtype_test(float, float32);
  }

  // bfloat16
  {
    basic_dtype_test(bfloat16_t, bfloat16);
  }

#undef basic_dtype_test

  // uint32: the largest unsigned value must round-trip unchanged.
  {
    uint32_t val = UINT_MAX;
    array x(val);
    CHECK_EQ(x.dtype(), uint32);
    CHECK_EQ(x.item<uint32_t>(), val);

    x = array({1u, 2u});
    CHECK_EQ(x.dtype(), uint32);
  }

  // int32: negative values, plus conversion of int data to bool_.
  {
    array x(-1);
    CHECK_EQ(x.dtype(), int32);
    CHECK_EQ(x.item<int>(), -1);

    x = array({-1, 2});
    CHECK_EQ(x.dtype(), int32);

    // Non-zero ints are expected to become true, zero to become false.
    std::vector<int> data{0, 1, 2};
    x = array(data.data(), {static_cast<int>(data.size())}, bool_);
    CHECK_EQ(x.dtype(), bool_);
    CHECK(array_equal(x, array({false, true, true})).item<bool>());
  }

  // int64: a value just below INT_MIN cannot be represented in 32 bits.
  {
    int64_t val = static_cast<int64_t>(INT_MIN) - 1;
    array x(val);
    CHECK_EQ(x.dtype(), int64);
    CHECK_EQ(x.item<int64_t>(), val);

    x = array({val, val});
    CHECK_EQ(x.dtype(), int64);
  }

  // float32: both float and double inputs are expected to yield float32.
  {
    array x(3.14f);
    CHECK_EQ(x.dtype(), float32);
    CHECK_EQ(x.item<float>(), 3.14f);

    x = array(1.25);
    CHECK_EQ(x.dtype(), float32);
    CHECK_EQ(x.item<float>(), 1.25f);

    x = array({1.0f, 2.0f});
    CHECK_EQ(x.dtype(), float32);

    x = array({1.0, 2.0});
    CHECK_EQ(x.dtype(), float32);

    std::vector<double> data{1.0, 2.0, 4.0};
    x = array(data.data(), {static_cast<int>(data.size())});
    CHECK_EQ(x.dtype(), float32);
    CHECK(array_equal(x, array({1.0f, 2.0f, 4.0f})).item<bool>());
  }

  // complex64
  {
    CHECK_EQ(sizeof(complex64_t), sizeof(std::complex<float>));

    complex64_t v = {1.0f, 1.0f};
    array x(v);
    CHECK_EQ(x.dtype(), complex64);
    CHECK_EQ(x.item<complex64_t>(), v);

    // Construction from std::complex<float> must give the same dtype and
    // value. These checks previously re-tested `x`, leaving `y` unverified.
    array y(std::complex<float>{1.0f, 1.0f});
    CHECK_EQ(y.dtype(), complex64);
    CHECK_EQ(y.item<complex64_t>(), v);
  }
}
|
|
|
|
|
|
|
|
// Checks data_size() and the contiguous / row_contiguous / col_contiguous
// flags for arrays produced by direct construction, broadcast_to, transpose,
// reshape, slice and split.
TEST_CASE("test array metadata") {
  // ---- Directly constructed arrays ----
  array x(1.0f);
  CHECK_EQ(x.data_size(), 1);
  CHECK_EQ(x.flags().contiguous, true);
  CHECK_EQ(x.flags().row_contiguous, true);
  CHECK_EQ(x.flags().col_contiguous, true);

  x = array({1.0f}, {1, 1, 1});
  CHECK_EQ(x.data_size(), 1);
  CHECK_EQ(x.flags().contiguous, true);
  CHECK_EQ(x.flags().row_contiguous, true);
  CHECK_EQ(x.flags().col_contiguous, true);

  x = array({1.0f, 1.0f}, {1, 2});
  CHECK_EQ(x.data_size(), 2);
  CHECK_EQ(x.flags().contiguous, true);
  CHECK_EQ(x.flags().row_contiguous, true);
  CHECK_EQ(x.flags().col_contiguous, true);

  x = zeros({1, 1, 4});
  eval(x);
  CHECK_EQ(x.data_size(), 4);
  CHECK_EQ(x.flags().contiguous, true);
  CHECK_EQ(x.flags().row_contiguous, true);
  CHECK_EQ(x.flags().col_contiguous, true);

  x = zeros({2, 4});
  eval(x);
  CHECK_EQ(x.data_size(), 8);
  CHECK_EQ(x.flags().contiguous, true);
  CHECK_EQ(x.flags().row_contiguous, true);
  CHECK_EQ(x.flags().col_contiguous, false);

  // ---- broadcast_to ----
  x = array(1.0f);
  auto y = broadcast_to(x, {1, 1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Broadcasting a scalar to a larger shape is expected to keep a single
  // stored element while losing row/col contiguity.
  y = broadcast_to(x, {2, 8, 10});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, false);

  // Broadcasting to shapes with a zero-sized dimension.
  y = broadcast_to(x, {1, 0});
  eval(y);
  CHECK_EQ(y.data_size(), 0);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  y = broadcast_to(zeros({4, 2, 1}), {4, 2, 0});
  eval(y);
  CHECK_EQ(y.data_size(), 0);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // ---- transpose ----
  x = array(1.0f);
  y = transpose(x);
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({1, 1, 1});
  y = transpose(x);
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({1, 1, 1});
  y = transpose(x, {0, 1, 2});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({1, 1, 1});
  y = transpose(x, {1, 2, 0});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({4, 1});
  y = transpose(x);
  eval(y);
  CHECK_EQ(y.data_size(), 4);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({2, 3, 4});
  y = transpose(x);
  eval(y);
  CHECK_EQ(y.data_size(), 24);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, true);

  y = transpose(x, {0, 2, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 24);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, false);

  // Applying the same axis swap twice restores the original layout.
  y = transpose(transpose(x, {0, 2, 1}), {0, 2, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 24);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, false);

  // ---- reshape ----
  x = array(1.0f);
  y = reshape(x, {1, 1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({2, 4});
  y = reshape(x, {8});
  eval(y);
  CHECK_EQ(y.data_size(), 8);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  y = reshape(x, {8, 1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 8);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  y = reshape(x, {1, 8, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 8);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({12});
  y = reshape(x, {2, 3, 2});
  eval(y);
  CHECK_EQ(y.data_size(), 12);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, false);

  // ---- slice ----
  x = array(1.0f);
  y = slice(x, {}, {});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Start, stop and stride all lie outside the single-element extent.
  x = array({1.0f});
  y = slice(x, {-10}, {10}, {10});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Slice covering the whole array.
  x = array({1.0f, 2.0f, 3.0f}, {1, 3});
  y = slice(x, {0, 0}, {1, 3}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 3);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Same full-array slice again (this case is repeated verbatim).
  x = array({1.0f, 2.0f, 3.0f}, {1, 3});
  y = slice(x, {0, 0}, {1, 3}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 3);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Empty slice along the first axis.
  x = array({1.0f, 2.0f, 3.0f}, {1, 3});
  y = slice(x, {0, 0}, {0, 3}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 0);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Prefix slice with unit strides.
  x = array({1.0f, 2.0f, 3.0f}, {1, 3});
  y = slice(x, {0, 0}, {1, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 2);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Strides larger than the sliced extent select a single element.
  x = array({1.0f, 2.0f, 3.0f}, {1, 3});
  y = slice(x, {0, 0}, {1, 2}, {2, 3});
  eval(y);
  CHECK_EQ(y.shape(), Shape{1, 1});
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Taking every other element leaves gaps, so no contiguity flag is set.
  x = array({0.0f, 1.0f, 2.0f, 3.0f}, {1, 4});
  y = slice(x, {0, 0}, {1, 4}, {1, 2});
  eval(y);
  CHECK_EQ(y.shape(), Shape{1, 2});
  CHECK_EQ(y.flags().contiguous, false);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, false);

  // Strided slice of a broadcast scalar: still one stored element.
  x = broadcast_to(array(1.0f), {4, 10});
  y = slice(x, {0, 0}, {4, 10}, {2, 2});
  eval(y);
  CHECK_EQ(y.shape(), Shape{2, 5});
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, false);

  // Single-row slices of a broadcast row vector.
  x = broadcast_to(array({1.0f, 2.0f}), {4, 2});
  y = slice(x, {0, 0}, {1, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 2);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  y = slice(x, {1, 0}, {2, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 2);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Full slice of a 2x2 array and of its transpose.
  x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});
  y = slice(x, {0, 0}, {2, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 4);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, false);

  y = slice(transpose(x), {0, 0}, {2, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 4);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, true);

  // ---- split ----
  // The pieces are only inspected through const accessors, so bind them by
  // const reference: this avoids copying each array and avoids shadowing the
  // outer `y`.
  x = ones({2, 4});
  auto out = split(x, 2);
  eval(out);
  for (const auto& part : out) {
    CHECK_EQ(part.data_size(), 4);
    CHECK_EQ(part.flags().contiguous, true);
    CHECK_EQ(part.flags().row_contiguous, true);
    CHECK_EQ(part.flags().col_contiguous, true);
  }
  out = split(x, 4, 1);
  eval(out);
  for (const auto& part : out) {
    CHECK_EQ(part.flags().contiguous, false);
    CHECK_EQ(part.flags().row_contiguous, false);
    CHECK_EQ(part.flags().col_contiguous, false);
  }
}
|
|
|
|
|
|
|
|
TEST_CASE("test array iteration") {
  // A rank-0 array has nothing to iterate over, so begin() must throw.
  auto values = array(1);
  CHECK_THROWS(values.begin());

  // Dereferencing an iterator must yield a const-qualified type. decltype is
  // an unevaluated context, so begin() is not actually called here.
  CHECK(std::is_const_v<decltype(*values.begin())>);

  // Range-for over a 1-D array visits 1..5 in order. The element is taken by
  // value on purpose, exactly as in a plain `for (auto e : arr)` loop.
  values = array({1, 2, 3, 4, 5});
  int visited = 0;
  for (auto element : values) {
    ++visited;
    CHECK_EQ(element.item<int>(), visited);
  }
  CHECK_EQ(visited, 5);

  // Iterating a 2-D array walks its leading axis: each step yields one row,
  // and begin() + <row count> compares equal to end().
  values = array({1, 2, 3, 4}, {2, 2});
  CHECK(array_equal(*values.begin(), array({1, 2})).item<bool>());
  CHECK(array_equal(*(values.begin() + 1), array({3, 4})).item<bool>());
  CHECK_EQ(values.begin() + 2, values.end());
}
|
|
|
|
|
|
|
|
TEST_CASE("test array shared buffer") {
  Shape dims = {2, 2};
  auto element_count = dims[0] * dims[1];

  // Allocate the storage by hand and view it as floats.
  allocator::Buffer storage =
      allocator::malloc(element_count * sizeof(float));
  void* storage_ptr = storage.raw_ptr();
  auto* floats = static_cast<float*>(storage_ptr);

  // Fill every element with 2.
  for (int idx = 0; idx != element_count; ++idx) {
    floats[idx] = 2.;
  }

  CHECK_EQ(floats[0], static_cast<float*>(storage_ptr)[0]);

  // Custom deleter handed to the array below. It verifies that the buffer it
  // receives is the one allocated above, with its contents intact, and then
  // frees it. NOTE(review): presumably invoked when the array releases its
  // storage; confirm against the array constructor's contract.
  auto release = [floats](allocator::Buffer buf) {
    CHECK_EQ(floats, static_cast<float*>(buf.raw_ptr()));
    CHECK_EQ(floats[0], static_cast<float*>(buf.raw_ptr())[0]);
    allocator::free(buf);
  };

  array lhs = ones(dims, float32);
  array rhs = array(storage, dims, float32, release);

  // Use the externally backed array in a computation.
  eval(lhs + rhs);
}
|
2024-02-14 15:34:17 +08:00
|
|
|
|
|
|
|
TEST_CASE("test make empty array") {
  // With no dtype given, an empty initializer list yields an empty float32
  // array.
  {
    auto untyped = array({});
    CHECK_EQ(untyped.size(), 0);
    CHECK_EQ(untyped.dtype(), float32);
  }

  // With an explicit dtype, the array is still empty and carries that dtype.
  {
    auto as_int32 = array({}, int32);
    CHECK_EQ(as_int32.size(), 0);
    CHECK_EQ(as_int32.dtype(), int32);
  }

  {
    auto as_float32 = array({}, float32);
    CHECK_EQ(as_float32.size(), 0);
    CHECK_EQ(as_float32.dtype(), float32);
  }

  {
    auto as_bool = array({}, bool_);
    CHECK_EQ(as_bool.size(), 0);
    CHECK_EQ(as_bool.dtype(), bool_);
  }
}
|