mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Compile stride bug (#812)
* fix compile stride bug * revert sdpa fix * fix cpu * fix bug with simplifying outputs
This commit is contained in:
@@ -703,3 +703,18 @@ TEST_CASE("test shapeless compile") {
|
||||
CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());
|
||||
}
|
||||
}
|
||||
|
||||
auto compile_broadcast_add(const std::vector<array>& inputs) {
|
||||
auto b = zeros({8, 8});
|
||||
return std::vector<array>{inputs[0] + b};
|
||||
}
|
||||
|
||||
TEST_CASE("test compile strides") {
|
||||
{
|
||||
auto cfun = compile(compile_broadcast_add);
|
||||
auto a = zeros({1, 8, 8});
|
||||
auto out = cfun({a})[0];
|
||||
eval(out);
|
||||
CHECK_EQ(out.strides().size(), 3);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user