mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
546 lines
15 KiB
C++
546 lines
15 KiB
C++
![]() |
#include <numeric>
|
||
|
|
||
|
#include "doctest/doctest.h"
|
||
|
|
||
|
#include "mlx/mlx.h"
|
||
|
|
||
|
using namespace mlx::core;
|
||
|
|
||
|
TEST_CASE("test random key") {
  // The expected keys below put the upper 32 bits of the seed in the first
  // element and the lower 32 bits in the second element.
  const auto zero_key = random::key(0);
  CHECK(array_equal(zero_key, array({0, 0})).item<bool>());

  const auto one_key = random::key(1);
  CHECK(array_equal(one_key, array({0, 1})).item<bool>());

  // Smallest seed that does not fit in 32 bits.
  const int64_t wide_seed = int64_t{1} << 32;

  const auto wide_key = random::key(wide_seed);
  CHECK(array_equal(wide_key, array({1, 0})).item<bool>());

  const auto wide_plus_one_key = random::key(wide_seed + 1);
  CHECK(array_equal(wide_plus_one_key, array({1, 1})).item<bool>());
}
|
||
|
|
||
|
TEST_CASE("test global rng") {
  // Re-seeding the global generator with the same seed must replay the same
  // sequence of draws. The statement order matters here because every
  // key-less call to random::bits consumes global state.
  random::seed(4);
  auto first_run_a = random::bits({});
  auto first_run_b = random::bits({});

  random::seed(4);
  auto second_run_a = random::bits({});
  auto second_run_b = random::bits({});

  // Compare draw-for-draw across the two runs.
  CHECK_EQ(first_run_a.item<uint32_t>(), second_run_a.item<uint32_t>());
  CHECK_EQ(first_run_b.item<uint32_t>(), second_run_b.item<uint32_t>());
}
|
||
|
|
||
|
TEST_CASE("test random split") {
  // Default split: a pair of keys derived from key(0).
  auto [left_key, right_key] = random::split(random::key(0));
  CHECK(
      array_equal(left_key, array({4146024105u, 967050713u})).item<bool>());
  CHECK(
      array_equal(right_key, array({2718843009u, 1272950319u})).item<bool>());

  // Splitting into three keys yields a single (3, 2) array.
  auto three_keys = random::split(random::key(0), 3);
  auto want = array(
      {2467461003u,
       428148500u,
       3186719485u,
       3840466878u,
       2562233961u,
       1946702221u},
      {3, 2});
  CHECK(array_equal(three_keys, want).item<bool>());
}
|
||
|
|
||
|
TEST_CASE("test random bits") {
  // Test shapes, types, and sizes
  {
    auto key = random::key(0);
    auto x = random::bits({}, key);
    CHECK_EQ(x.size(), 1);
    CHECK_EQ(x.dtype(), uint32);

    x = random::bits({0}, key);
    CHECK(array_equal(x, array({})).item<bool>());

    // Check wrong key type or shape. These go through random::bits itself
    // (the function under test) rather than random::uniform, so the key
    // validation of bits is what is actually being exercised.
    key = array({0, 0});
    CHECK_THROWS_AS(random::bits({}, key), std::invalid_argument);
    key = array({0, 0}, {1, 2});
    CHECK_THROWS_AS(random::bits({}, key), std::invalid_argument);
    key = array({0u, 0u, 0u}, {3, 1});
    CHECK_THROWS_AS(random::bits({}, key), std::invalid_argument);
    key = array({0u, 0u}, {2, 1});
    CHECK_THROWS_AS(random::bits({}, key), std::invalid_argument);
  }

  // Expected bits in the following tests were generated from
  // Jax's Threefry 2x32 implementation using the following in
  // python:
  //
  // ```
  // import jax
  // import jax.prng
  // shape = (SET THIS)
  // seed = (SET THIS)
  // width = (SET THIS)
  // key = jax.random.PRNGKey(seed)
  // print(jax.prng.threefry_prng_impl.random_bits(key, width, shape))
  // ```

  // Scalar draws from key(0): the same key must reproduce the same bits,
  // and the width argument (2 or 1) selects uint16 / uint8 output.
  {
    auto key = random::key(0);
    auto x = random::bits({}, key);
    auto y = random::bits({}, key);
    CHECK_EQ(x.item<uint32_t>(), 1797259609u);
    CHECK_EQ(x.item<uint32_t>(), y.item<uint32_t>());

    x = random::bits({}, 2, key);
    CHECK_EQ(x.item<uint16_t>(), 345);

    x = random::bits({}, 1, key);
    CHECK_EQ(x.item<uint8_t>(), 89);
  }

  // Scalar draws from key(1), plus rejected widths.
  {
    auto key = random::key(1);
    auto x = random::bits({}, key);
    CHECK_EQ(x.item<uint32_t>(), 507451445u);

    x = random::bits({}, 2, key);
    CHECK_EQ(x.item<uint16_t>(), 6197);

    x = random::bits({}, 1, key);
    CHECK_EQ(x.item<uint8_t>(), 53);

    // Widths of 0, 5 and -1 must be rejected.
    CHECK_THROWS(random::bits({}, 0, key));
    CHECK_THROWS(random::bits({}, 5, key));
    CHECK_THROWS(random::bits({}, -1, key));
  }

  // Multi-element draws, including 16-bit outputs of odd and even lengths.
  {
    auto key = random::key(0);
    auto x = random::bits({3, 1}, key);
    auto expected = array({4146024105u, 1351547692u, 2718843009u}, {3, 1});
    CHECK(array_equal(x, expected).item<bool>());

    x = random::bits({5}, 2, key);
    expected = array({20137, 63263, 64300, 20622, 16513}, uint16);
    CHECK(array_equal(x, expected).item<bool>());

    expected = array({20137, 63263, 64300, 20622, 16513, 41486}, uint16);
    x = random::bits({6}, 2, key);
    CHECK(array_equal(x, expected).item<bool>());

    expected = array({20137, 63263, 1497, 14756, 16513, 41486, 44591}, uint16);
    x = random::bits({7}, 2, key);
    CHECK(array_equal(x, expected).item<bool>());

    x = random::bits({8}, 2, key);
    expected =
        array({20137, 63263, 1497, 14756, 16513, 41486, 44591, 19423}, uint16);
    CHECK(array_equal(x, expected).item<bool>());
  }

  // Vmap over a batch of keys.
  {
    auto key = array({0u, 0u, 1u, 1u}, {2, 2});
    // `shape` is captured by reference so it can be changed below before the
    // double-vmap case.
    auto shape = std::vector<int>{3};
    auto fn = [&shape](array k) { return random::bits(shape, k); };

    auto expected = array(
        {4146024105u,
         1351547692u,
         2718843009u,
         3725146706u,
         1802982961u,
         1349634643u},
        {2, 3});
    CHECK(array_equal(vmap(fn)(key), expected).item<bool>());

    // Mapping over axis 1 instead: both expected rows are identical.
    expected = array(
        {2441914641u,
         1110694964u,
         3819641963u,
         2441914641u,
         1110694964u,
         3819641963u},
        {2, 3});
    CHECK(array_equal(vmap(fn, 1)(key), expected).item<bool>());

    // Vmap twice
    key = array(
        {0u,
         0u,
         1u,
         1u,
         2u,
         2u,

         3u,
         3u,
         4u,
         4u,
         5u,
         5u},
        {3, 2, 2});
    shape = {2};
    auto out = vmap(vmap(fn))(key);
    expected = array(
        {928981903u,
         3453687069u,
         3606183818u,
         460005496u,

         2799733733u,
         856293553u,
         4081856343u,
         3445925136u,

         2775548010u,
         1430281703u,
         305173070u,
         2615843348u},
        {3, 2, 2});
    CHECK(array_equal(out, expected).item<bool>());

    out = vmap(vmap(fn, 1), 0)(key);
    expected = array(
        {1948878966u,
         4237131848u,
         1948878966u,
         4237131848u,

         2531170506u,
         1858648356u,
         2531170506u,
         1858648356u,

         740561898u,
         4234094099u,
         740561898u,
         4234094099u},
        {3, 2, 2});
    CHECK(array_equal(out, expected).item<bool>());
  }

  // Vmap smaller type: each row of the vmapped 16-bit output must match the
  // non-vmapped draw made with the corresponding key.
  {
    auto key = array({0u, 0u, 1u, 1u}, {2, 2});
    auto fn = [](array k) { return random::bits({5}, 2, k); };

    auto out = vmap(fn)(key);
    auto x1 = random::bits({5}, 2, take(key, array(0), 0));
    auto x2 = random::bits({5}, 2, take(key, array(1), 0));

    CHECK(array_equal(take(out, array(0), 0), x1).item<bool>());
    CHECK(array_equal(take(out, array(1), 0), x2).item<bool>());
  }
}
|
||
|
|
||
|
TEST_CASE("test random uniform") {
  // Output shape / dtype, argument validation, and broadcasting.
  {
    auto sample = random::uniform({});
    CHECK_EQ(sample.size(), 1);
    CHECK_EQ(sample.dtype(), float32);

    if (is_available(float16)) {
      sample = random::uniform({}, float16);
      CHECK_EQ(sample.size(), 1);
      CHECK_EQ(sample.dtype(), float16);
    }

    // A zero-length shape gives an empty result.
    sample = random::uniform({0});
    CHECK(array_equal(sample, array({})).item<bool>());

    // Integer output types are rejected.
    CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument);

    // Bounds broadcast against each other and the requested shape.
    sample = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
    CHECK_EQ(sample.shape(), std::vector<int>{3, 3});
    CHECK_THROWS_AS(
        random::uniform(zeros({3, 3}), 1.0, {1, 3}), std::invalid_argument);
    CHECK_THROWS_AS(
        random::uniform(zeros({3, 3}), 1.0, {2, 3}), std::invalid_argument);
    CHECK_THROWS_AS(
        random::uniform(zeros({3, 1}), ones({1, 3}), {1, 3}),
        std::invalid_argument);

    // Keys with the wrong dtype or the wrong shape are rejected.
    auto bad_key = array({0, 0});
    CHECK_THROWS_AS(random::uniform({}, bad_key), std::invalid_argument);
    bad_key = array({0, 0}, {1, 2});
    CHECK_THROWS_AS(random::uniform({}, bad_key), std::invalid_argument);
    bad_key = array({0u, 0u, 0u}, {3, 1});
    CHECK_THROWS_AS(random::uniform({}, bad_key), std::invalid_argument);
    bad_key = array({0u, 0u}, {2, 1});
    CHECK_THROWS_AS(random::uniform({}, bad_key), std::invalid_argument);
  }

  // Expected bits in the following tests were generated from
  // Jax's Threefry 2x32 implementation using the following in
  // python:
  //
  // ```
  // import jax
  // import jax.prng
  // shape = (SET THIS)
  // seed = (SET THIS)
  // key = jax.random.PRNGKey(seed)
  // print(jax.prng.threefry_prng_impl.random_bits(key, 32, shape))
  // ```

  // Maps raw 32-bit output onto the float the tests expect from uniform.
  constexpr auto bits_to_float = [](uint32_t raw) {
    return static_cast<float>(raw) / UINT32_MAX;
  };

  // Scalar draw from key(0); a repeated draw with the same key must agree.
  {
    auto seed_key = random::key(0);
    auto first = random::uniform({}, seed_key);
    auto second = random::uniform({}, seed_key);
    auto want = bits_to_float(1797259609);
    CHECK_EQ(first.item<float>(), want);
    CHECK_EQ(first.item<float>(), second.item<float>());
  }

  // Scalar draw from key(1).
  {
    auto seed_key = random::key(1);
    auto drawn = random::uniform({}, seed_key);
    auto want = bits_to_float(507451445);
    CHECK_EQ(drawn.item<float>(), want);
  }

  // (3, 1) draw from key(0).
  {
    auto seed_key = random::key(0);
    auto drawn = random::uniform({3, 1}, seed_key);
    auto want = array(
        {bits_to_float(4146024105),
         bits_to_float(1351547692),
         bits_to_float(2718843009)},
        {3, 1});
    CHECK(array_equal(drawn, want).item<bool>());
  }

  // Check vmap
  {
    auto seed_key = random::key(0);
    auto draw_above = [](array k, array low) {
      return random::uniform(low, 1, {3}, float32, k);
    };
    auto mapped = vmap(draw_above, -1)(seed_key, zeros({2, 3}));
    CHECK_EQ(mapped.shape(), std::vector<int>{2, 3});

    // Same function, this time with a batch of two keys.
    seed_key = zeros({2, 2}, uint32);
    mapped = vmap(draw_above)(seed_key, zeros({2, 3}));
    CHECK_EQ(mapped.shape(), std::vector<int>{2, 3});
  }

  // Samples must fall inside [low, high).
  {
    auto seed_key = random::key(128291);
    auto samples =
        random::uniform(array(-1.0f), array(1.0f), {100}, float32, seed_key);
    CHECK(all(less(samples, array(1.0f))).item<bool>());
    CHECK(all(greater_equal(samples, array(-1.0f))).item<bool>());
  }
}
|
||
|
|
||
|
TEST_CASE("test random normal") {
  // Test shapes, types, and sizes
  {
    auto x = random::normal({});
    CHECK_EQ(x.size(), 1);
    CHECK_EQ(x.dtype(), float32);

    // Empty shape: exercise random::normal itself here (this previously
    // called random::uniform, leaving normal untested for empty output).
    x = random::normal({0});
    CHECK(array_equal(x, array({})).item<bool>());

    // Non float type throws
    CHECK_THROWS_AS(random::normal({}, int32), std::invalid_argument);

    // Check wrong key type or shape
    auto key = array({0, 0});
    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
    key = array({0, 0}, {1, 2});
    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
    key = array({0u, 0u, 0u}, {3, 1});
    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
    key = array({0u, 0u}, {2, 1});
    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
  }

  // Every sample drawn with a fixed key must be finite in magnitude.
  {
    constexpr float inf = std::numeric_limits<float>::infinity();
    auto key = random::key(128291);
    auto out = random::normal({100}, key);
    CHECK(all(less(abs(out), array(inf))).item<bool>());
  }
}
|
||
|
|
||
|
TEST_CASE("test random randint") {
  // A floating point output type is rejected.
  CHECK_THROWS_AS(
      random::randint(array(3), array(5), {1}, float32), std::invalid_argument);

  // Scalar draws keep the requested dtype.
  auto x = random::randint(0, 10, {}, uint32);
  CHECK_EQ(x.size(), 1);
  CHECK_EQ(x.dtype(), uint32);

  x = random::randint(0, 2, {}, bool_);
  CHECK_EQ(x.size(), 1);
  CHECK_EQ(x.dtype(), bool_);

  x = random::randint(0, 2, {}, int32);
  CHECK_EQ(x.size(), 1);
  CHECK_EQ(x.dtype(), int32);

  x = random::randint(0, 2, {}, int64);
  CHECK_EQ(x.size(), 1);
  CHECK_EQ(x.dtype(), int64);

  // Check all in bounds
  auto low = -10.0;
  auto high = 20.0;
  x = random::randint(low, high, {1000, 1000});
  CHECK((all(low <= x).item<bool>() && all(x < high).item<bool>()));

  // Check high < low => all equals to low
  low = 20.0;
  high = -10.0;
  x = random::randint(low, high, {3, 3});
  CHECK(all(equal(x, array(low))).item<bool>());

  // Check wrong key type or shape. The dtype must be a valid integer type
  // here: with float32 the call already throws std::invalid_argument because
  // of the dtype (see the first check above), so the key would never be
  // what is being tested.
  auto key = array({0, 0}, {1, 2});
  CHECK_THROWS_AS(
      random::randint(low, high, {}, int32, key), std::invalid_argument);
}
|
||
|
|
||
|
TEST_CASE("test random bernoulli") {
  // With no arguments a single boolean is produced.
  auto draw = random::bernoulli();
  CHECK_EQ(draw.size(), 1);
  CHECK_EQ(draw.dtype(), bool_);

  // The probability may be given in half precision when it is available.
  if (is_available(float16)) {
    draw = random::bernoulli(array(0.5, float16));
    CHECK_EQ(draw.size(), 1);
    CHECK_EQ(draw.dtype(), bool_);
  }

  // An integer-typed probability is rejected.
  CHECK_THROWS(random::bernoulli(array(1, int32)));

  // Negative numbers allowed in Jax
  draw = random::bernoulli(array(-1.0));
  CHECK_FALSE(draw.item<bool>());

  // A probability above one always yields true.
  draw = random::bernoulli(array(5.0));
  CHECK(draw.item<bool>());

  // The requested output shape is honoured.
  draw = random::bernoulli(0.5, {3, 3});
  CHECK_EQ(draw.shape(), std::vector<int>({3, 3}));

  // An empty probability array gives an empty result.
  draw = random::bernoulli(array({}));
  CHECK_EQ(draw.size(), 0);

  // A (1, 3) probability row broadcasts to a (4, 3) output.
  auto probs = reshape(array({0.1, 0.2, 0.3}), {1, 3});
  draw = random::bernoulli(probs, {4, 3});
  CHECK_EQ(draw.shape(), std::vector<int>({4, 3}));

  // An empty probability array cannot fill a (3, 3) output.
  CHECK_THROWS_AS(random::bernoulli(array({}), {3, 3}), std::invalid_argument);

  // Three probabilities cannot fill a length-2 output.
  probs = array({0.1, 0.2, 0.3});
  CHECK_THROWS_AS(random::bernoulli(probs, {2}), std::invalid_argument);

  // A key with the wrong dtype / shape is rejected.
  auto bad_key = array({0, 0}, {1, 2});
  CHECK_THROWS_AS(
      random::bernoulli(array(0.5), bad_key), std::invalid_argument);
}
|
||
|
|
||
|
TEST_CASE("Test truncated normal") {
  // Scalar bounds give one float32 sample by default.
  auto sample = random::truncated_normal(array(-2.0), array(2.0));
  CHECK_EQ(sample.size(), 1);
  CHECK_EQ(sample.dtype(), float32);

  // Half precision output, when the backend offers it.
  if (is_available(float16)) {
    sample = random::truncated_normal(array(-2.0), array(2.0), {}, float16);
    CHECK_EQ(sample.size(), 1);
    CHECK_EQ(sample.dtype(), float16);
  }

  // An explicit output shape is honoured.
  sample = random::truncated_normal(array(-2.0), array(2.0), {3, 4});
  CHECK_EQ(sample.shape(), std::vector<int>({3, 4}));

  // Empty bounds give an empty result.
  sample = random::truncated_normal(array({}), array({}));
  CHECK_EQ(sample.size(), 0);

  // A (1, 2) lower bound and a (3, 1) upper bound broadcast to (3, 2).
  auto lo_bounds = reshape(array({-2.0, -3.0}), {1, 2});
  auto hi_bounds = reshape(array({0.0, 3.0, 1.5}), {3, 1});
  sample = random::truncated_normal(lo_bounds, hi_bounds);
  CHECK_EQ(sample.shape(), std::vector<int>({3, 2}));

  // Every element must lie within its own pair of bounds.
  CHECK(
      (all(sample <= hi_bounds).item<bool>() &&
       all(lo_bounds <= sample).item<bool>()));

  // With the bounds swapped (upper below lower) the lower bound is returned.
  sample = random::truncated_normal(array(2.0), array(-2.0));
  CHECK(all(sample == array(2.0)).item<bool>());

  // Bounds that cannot broadcast to the requested shape are rejected.
  CHECK_THROWS_AS(
      random::truncated_normal(lo_bounds, hi_bounds, {4, 2}),
      std::invalid_argument);

  // A key with the wrong dtype / shape is rejected.
  auto bad_key = array({0, 0}, {1, 2});
  CHECK_THROWS_AS(
      random::truncated_normal(
          array(-2.0), array(2.0), {1, 1}, float32, bad_key),
      std::invalid_argument);
}
|
||
|
|
||
|
TEST_CASE("test categorical") {
  using random::categorical;

  const auto flat_logits = zeros({10, 20});

  // Axes outside the rank of the logits are rejected.
  CHECK_THROWS(categorical(flat_logits, 2));
  CHECK_THROWS(categorical(flat_logits, -3));

  // Requested shapes that do not fit the remaining dimension are rejected.
  CHECK_THROWS(categorical(flat_logits, 1, std::vector<int>{1}));
  CHECK_THROWS(categorical(flat_logits, 1, std::vector<int>{11}));
  CHECK_THROWS(categorical(flat_logits, 1, {10, 1}));

  // The sampled axis is dropped from the output shape.
  CHECK_EQ(categorical(flat_logits, -1).shape(), std::vector<int>{10});
  CHECK_EQ(categorical(flat_logits, 0).shape(), std::vector<int>{20});
  CHECK_EQ(categorical(flat_logits, 1).shape(), std::vector<int>{10});

  // Default axis: uint32 indices, each below the size of the sampled axis.
  const auto default_draw = categorical(flat_logits);
  CHECK_EQ(default_draw.shape(), std::vector<int>{10});
  CHECK_EQ(default_draw.dtype(), uint32);
  CHECK(max(default_draw).item<uint32_t>() < 20);

  // Sampling along axis 0 with an explicit (5, 20) output shape.
  const auto shaped_draw = categorical(flat_logits, 0, {5, 20});
  CHECK_EQ(shaped_draw.shape(), std::vector<int>{5, 20});
  CHECK(max(shaped_draw).item<uint32_t>() < 10);

  const float inf = std::numeric_limits<float>::infinity();

  // A single +inf logit is always the one selected.
  const auto one_infinite = array({1.0f, -2.0f, inf, 4.0f, 3.0f});
  CHECK_EQ(categorical(one_infinite).item<uint32_t>(), 2);

  // The only finite logit among -inf entries is always the one selected.
  const auto one_finite = array({-inf, -2.0f, -inf, -inf});
  CHECK_EQ(categorical(one_finite).item<uint32_t>(), 1);

  // With a sample count, the count is appended after the remaining dims.
  const auto cube_logits = zeros({5, 4, 3});
  CHECK_EQ(
      categorical(cube_logits, -1, 7).shape(), std::vector<int>{5, 4, 7});
  CHECK_EQ(
      categorical(cube_logits, -2, 7).shape(), std::vector<int>{5, 3, 7});
  CHECK_EQ(
      categorical(cube_logits, -3, 7).shape(), std::vector<int>{4, 3, 7});
}
|