More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -611,8 +611,8 @@ TEST_CASE("test categorical") {
CHECK_THROWS(categorical(logits, -3));
// Invalid requested shapes
CHECK_THROWS(categorical(logits, 1, std::vector<int>{1}));
CHECK_THROWS(categorical(logits, 1, std::vector<int>{11}));
CHECK_THROWS(categorical(logits, 1, Shape{1}));
CHECK_THROWS(categorical(logits, 1, Shape{11}));
CHECK_THROWS(categorical(logits, 1, {10, 1}));
CHECK_EQ(categorical(logits, -1).shape(), Shape{10});