mlx/tests/creations_tests.cpp
Josh Soref 44c1ce5e6a
Spelling (#342)
* spelling: accumulates

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: across

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: additional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: against

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: among

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: array

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: at least

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: available

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: axes

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: basically

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bfloat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bounds

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: broadcast

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: buffer

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: class

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: coefficients

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: collision

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: combinations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: committing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: computation

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: consider

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: constructing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: conversions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: correctly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: corresponding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: declaration

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: default

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dependency

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destination

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destructor

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dimensions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: divided

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: element-wise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: elements

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: endianness

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: equivalent

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: explicitly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: github

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: indices

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: irregularly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: memory

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: metallib

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: negative

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: notable

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: optional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: otherwise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: overridden

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partially

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partition

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perform

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perturbations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: positively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: primitive

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeats

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respect

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respectively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: result

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: rounding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: separate

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: skipping

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: structure

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: the

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: transpose

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unnecessary

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unneeded

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unsupported

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

---------

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
2024-01-01 21:08:17 -08:00

242 lines
5.8 KiB
C++

// Copyright © 2023 Apple Inc.
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
TEST_CASE("test arange") {
  // Resulting dtype: follows the argument types unless a dtype is passed
  {
    // Stop only
    auto seq = arange(10);
    CHECK_EQ(seq.dtype(), int32);

    seq = arange(10.0);
    CHECK_EQ(seq.dtype(), float32);

    seq = arange(10, float32);
    CHECK_EQ(seq.dtype(), float32);

    seq = arange(10, float16);
    CHECK_EQ(seq.dtype(), float16);

    seq = arange(10, bfloat16);
    CHECK_EQ(seq.dtype(), bfloat16);

    seq = arange(10.0, int32);
    CHECK_EQ(seq.dtype(), int32);

    // Start and stop
    seq = arange(0, 10);
    CHECK_EQ(seq.dtype(), int32);

    seq = arange(0.0, 10.0, int32);
    CHECK_EQ(seq.dtype(), int32);

    seq = arange(0.0, 10.0);
    CHECK_EQ(seq.dtype(), float32);

    seq = arange(0, 10, float32);
    CHECK_EQ(seq.dtype(), float32);

    // Start, stop and step
    seq = arange(0, 10, 0.1, float32);
    CHECK_EQ(seq.dtype(), float32);

    seq = arange(0.0, 10.0, 0.5, int32);
    CHECK_EQ(seq.dtype(), int32);

    // Unsigned output requested from floating-point arguments
    seq = arange(10.0, uint32);
    CHECK_EQ(seq.dtype(), uint32);

    seq = arange(0.0, 10.0, uint32);
    CHECK_EQ(seq.dtype(), uint32);

    seq = arange(0.0, 10.0, 0.5, uint32);
    CHECK_EQ(seq.dtype(), uint32);

    // A boolean range is rejected
    CHECK_THROWS_AS(arange(10, bool_), std::invalid_argument);
  }

  // Number of generated elements
  {
    auto seq = arange(10);
    CHECK_EQ(seq.size(), 10);

    // Fractional steps
    seq = arange(0.0, 10.0, 0.5);
    CHECK_EQ(seq.size(), 20);

    seq = arange(0.0, 10.0, 0.45);
    CHECK_EQ(seq.size(), 23);

    // Steps equal to, just below, and far beyond the interval length
    seq = arange(0, 10, 10);
    CHECK_EQ(seq.size(), 1);

    seq = arange(0, 10, 9);
    CHECK_EQ(seq.size(), 2);

    seq = arange(0, 10, 100);
    CHECK_EQ(seq.size(), 1);

    // Descending intervals: a positive step yields nothing
    seq = arange(0, -10, 1);
    CHECK_EQ(seq.size(), 0);

    seq = arange(0, -10, -1);
    CHECK_EQ(seq.size(), 10);

    seq = arange(0, -10, -10);
    CHECK_EQ(seq.size(), 1);
  }

  // Generated contents
  {
    auto seq = arange(0, 3);
    CHECK(array_equal(seq, array({0, 1, 2})).item<bool>());

    seq = arange(0, 3, 2);
    CHECK(array_equal(seq, array({0, 2})).item<bool>());

    seq = arange(0, 3, 3);
    CHECK(array_equal(seq, array({0})).item<bool>());

    // Step pointing away from stop gives an empty array
    seq = arange(0, -3, 1);
    CHECK(array_equal(seq, array({})).item<bool>());

    seq = arange(0, 3, -1);
    CHECK(array_equal(seq, array({})).item<bool>());

    seq = arange(0, -3, -1);
    CHECK(array_equal(seq, array({0, -1, -2})).item<bool>());

    // Integer dtype with a fractional step: the expected values below
    // advance by the truncated step while the element count stays that of
    // the floating-point range
    seq = arange(0.0, 5.0, 0.5, int32);
    CHECK(array_equal(seq, zeros({10})).item<bool>());

    seq = arange(0.0, 5.0, 1.5, int32);
    CHECK(array_equal(seq, array({0, 1, 2, 3})).item<bool>());

    // Half-precision outputs
    seq = arange(0.0, 5.0, 1.0, float16);
    CHECK(array_equal(seq, array({0, 1, 2, 3, 4}, float16)).item<bool>());

    seq = arange(0.0, 5.0, 1.0, bfloat16);
    CHECK(array_equal(seq, array({0, 1, 2, 3, 4}, bfloat16)).item<bool>());

    seq = arange(0.0, 5.0, 1.5, bfloat16);
    CHECK(array_equal(seq, array({0., 1.5, 3., 4.5}, bfloat16)).item<bool>());
  }
}
TEST_CASE("test astype") {
  // Scalar conversions between integer and floating-point dtypes
  {
    // Integer source
    auto int_one = array(1);

    auto one_as_f32 = astype(int_one, float32);
    CHECK_EQ(one_as_f32.dtype(), float32);
    CHECK_EQ(one_as_f32.item<float>(), 1.0f);

    auto one_as_i32 = astype(int_one, int32);
    CHECK_EQ(one_as_i32.dtype(), int32);
    CHECK_EQ(one_as_i32.item<int>(), 1);

    // Negative floating-point source
    auto neg_three = array(-3.0f);

    auto neg_as_i32 = astype(neg_three, int32);
    CHECK_EQ(neg_as_i32.dtype(), int32);
    CHECK_EQ(neg_as_i32.item<int>(), -3);

    auto neg_as_u32 = astype(neg_three, uint32);
    CHECK_EQ(neg_as_u32.dtype(), uint32);

    // The negative-float -> unsigned result is platform dependent, so the
    // reference value is produced by the host itself: std::copy assigns the
    // single float element to a uint32_t, converting it on the way.
    uint32_t host_converted;
    std::copy(
        neg_three.data<float>(),
        neg_three.data<float>() + 1,
        &host_converted);
    CHECK_EQ(neg_as_u32.item<uint32_t>(), host_converted);
  }
}
TEST_CASE("test full") {
  // Fill values and explicit dtypes of several kinds
  {
    // Scalar (empty-shape) outputs
    auto filled = full({}, 0);
    CHECK_EQ(filled.dtype(), int32);
    CHECK_EQ(filled.item<int>(), 0);

    filled = full({}, 0.0);
    CHECK_EQ(filled.dtype(), float32);
    CHECK_EQ(filled.item<float>(), 0);

    filled = full({}, false);
    CHECK_EQ(filled.item<bool>(), false);

    filled = full({}, 0, int32);
    CHECK_EQ(filled.item<int>(), 0);

    filled = full({}, 0, float32);
    CHECK_EQ(filled.item<float>(), 0);

    // Two-dimensional outputs
    filled = full({1, 2}, 2, float32);
    CHECK(array_equal(filled, array({2.0, 2.0}, {1, 2})).item<bool>());

    filled = full({2, 1}, 2, float32);
    CHECK(array_equal(filled, array({2.0, 2.0}, {2, 1})).item<bool>());

    // Boolean output, both deduced and requested
    filled = full({2}, false);
    CHECK_EQ(filled.dtype(), bool_);
    CHECK(array_equal(filled, array({false, false})).item<bool>());

    filled = full({2}, 1.0, bool_);
    CHECK_EQ(filled.dtype(), bool_);
    CHECK(array_equal(filled, array({true, true})).item<bool>());

    // Floating-point value with an unsigned dtype
    filled = full({2}, 1.0, uint32);
    CHECK_EQ(filled.dtype(), uint32);
    CHECK(array_equal(filled, array({1, 1})).item<bool>());

    // An empty value array cannot fill a non-empty shape
    CHECK_THROWS_AS(full({2}, array({})), std::invalid_argument);
  }

  // Value arrays are broadcast to the requested shape
  {
    // Column vector repeated across columns
    auto filled = full({2, 2}, array({3, 4}, {2, 1}));
    CHECK(array_equal(filled, array({3, 3, 4, 4}, {2, 2})).item<bool>());

    // Row vector repeated across rows
    filled = full({2, 2}, array({3, 4}, {1, 2}));
    CHECK(array_equal(filled, array({3, 4, 3, 4}, {2, 2})).item<bool>());
  }

  // zeros / ones and their *_like variants
  {
    auto lhs = zeros({2, 2}, float32);
    CHECK_EQ(lhs.shape(), std::vector<int>{2, 2});
    CHECK_EQ(lhs.ndim(), 2);
    CHECK_EQ(lhs.dtype(), float32);
    auto rhs = array({0.0, 0.0, 0.0, 0.0}, {2, 2});
    CHECK(array_equal(lhs, rhs).item<bool>());

    lhs = ones({2, 2}, float32);
    CHECK_EQ(lhs.shape(), std::vector<int>{2, 2});
    CHECK_EQ(lhs.ndim(), 2);
    CHECK_EQ(lhs.dtype(), float32);
    rhs = array({1.0, 1.0, 1.0, 1.0}, {2, 2});
    CHECK(array_equal(lhs, rhs).item<bool>());

    // The *_like variants are checked for dtype and contents
    lhs = zeros({2, 2}, int32);
    rhs = zeros_like(lhs);
    CHECK_EQ(rhs.dtype(), int32);
    CHECK(array_equal(lhs, rhs).item<bool>());

    lhs = ones({2, 2}, int32);
    rhs = ones_like(lhs);
    CHECK_EQ(rhs.dtype(), int32);
    CHECK(array_equal(lhs, rhs).item<bool>());
  }

  // Empty shape (scalar) and zero-sized outputs
  {
    auto scalar_one = ones({}, int32);
    CHECK_EQ(scalar_one.shape(), std::vector<int>{});
    CHECK_EQ(scalar_one.item<int>(), 1);

    // An empty value array can fill a zero-sized shape ...
    auto zero_sized = full({0}, array({}));
    CHECK_EQ(zero_sized.shape(), std::vector<int>{0});
    CHECK_EQ(zero_sized.size(), 0);

    // ... but not a scalar
    CHECK_THROWS_AS(full({}, array({})), std::invalid_argument);
  }
}