mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
Fix eye for larger matrices (#463)
* fix eye * fix scatter for <32bit (non native atomic) types * fix int overflow
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
#include <cmath>
|
||||
#include <iostream> // TODO
|
||||
#include <numeric>
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
@@ -509,13 +510,14 @@ TEST_CASE("test is inf") {
|
||||
array x(1.0f);
|
||||
CHECK_FALSE(isinf(x).item<bool>());
|
||||
|
||||
array y(std::numeric_limits<double>::infinity());
|
||||
auto inf = std::numeric_limits<float>::infinity();
|
||||
array y(inf);
|
||||
CHECK(isinf(y).item<bool>());
|
||||
|
||||
array z = identity(7);
|
||||
CHECK_FALSE(any(isinf(z)).item<bool>());
|
||||
|
||||
array w = array({1.0f, std::numeric_limits<double>::infinity(), 2.0f});
|
||||
array w = array({1.0f, inf, 2.0f});
|
||||
CHECK(array_equal({false, true, false}, isinf(w)).item<bool>());
|
||||
|
||||
array a(1.0f, bfloat16);
|
||||
@@ -524,10 +526,10 @@ TEST_CASE("test is inf") {
|
||||
array b(1.0f, float16);
|
||||
CHECK_FALSE(isinf(b).item<bool>());
|
||||
|
||||
array c(std::numeric_limits<double>::infinity(), bfloat16);
|
||||
array c(inf, bfloat16);
|
||||
CHECK(isinf(c).item<bool>());
|
||||
|
||||
array d(std::numeric_limits<double>::infinity(), float16);
|
||||
array d(inf, float16);
|
||||
CHECK(isinf(d).item<bool>());
|
||||
}
|
||||
|
||||
@@ -1878,6 +1880,28 @@ TEST_CASE("test scatter") {
|
||||
CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test scatter types") {
|
||||
for (auto t : {bool_, uint8, uint16, int8, int16}) {
|
||||
auto in = zeros({4, 4}, t);
|
||||
auto inds = {arange(4), arange(4)};
|
||||
auto updates = ones({4, 1, 1}, t);
|
||||
auto out = scatter(in, inds, updates, {0, 1});
|
||||
auto expected =
|
||||
array({1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}, {4, 4}, t);
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
for (auto t : {float16, bfloat16}) {
|
||||
auto in = zeros({4, 4}, t);
|
||||
auto inds = {arange(4), arange(4)};
|
||||
auto updates = ones({4, 1, 1}, t);
|
||||
auto out = scatter(in, inds, updates, {0, 1});
|
||||
auto expected =
|
||||
array({1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}, {4, 4}, t);
|
||||
CHECK(allclose(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test complex ops") {
|
||||
// Creation ops
|
||||
{
|
||||
|
Reference in New Issue
Block a user