Compile stride bug (#812)

* fix compile stride bug

* revert sdpa fix

* fix cpu

* fix bug with simplifying outputs
This commit is contained in:
Awni Hannun
2024-03-11 06:31:31 -07:00
committed by GitHub
parent a4d290adb9
commit 7c441600fe
9 changed files with 58 additions and 12 deletions

View File

@@ -703,3 +703,18 @@ TEST_CASE("test shapeless compile") {
CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());
}
}
auto compile_broadcast_add(const std::vector<array>& inputs) {
auto b = zeros({8, 8});
return std::vector<array>{inputs[0] + b};
}
TEST_CASE("test compile strides") {
{
auto cfun = compile(compile_broadcast_add);
auto a = zeros({1, 8, 8});
auto out = cfun({a})[0];
eval(out);
CHECK_EQ(out.strides().size(), 3);
}
}