mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Faster cpu ops (#1434)
* faster binary and cleaner copy * use recursive template for other ops * more cleanup * fix from cleanup * more clean * fix binary * use contiguous iterator * add 3d * nits * fix * fix? * fix * fix rebase
This commit is contained in:
@@ -26,11 +26,7 @@ void ternary_op_gpu_inplace(
|
||||
// Try to collapse contiguous dims
|
||||
auto maybe_collapse = [topt, &a, &b, &c, &out]() {
|
||||
if (topt == TernaryOpType::General) {
|
||||
// The size cap here should ideally be `UINT32_MAX` but we are
|
||||
// limitied by the shape being an int.
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
{a, b, c, out},
|
||||
/* size_cap = */ INT32_MAX);
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
|
||||
return std::make_tuple(
|
||||
shape, strides[0], strides[1], strides[2], strides[3]);
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user