mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Faster cpu ops (#1434)
* faster binary and cleaner copy
* use recursive template for other ops
* more cleanup
* fix from cleanup
* more clean
* fix binary
* use contiguous iterator
* add 3d
* nits
* fix
* fix?
* fix
* fix rebase
This commit is contained in:
@@ -73,11 +73,7 @@ void binary_op_gpu_inplace(
|
||||
// Try to collapse contiguous dims
|
||||
auto maybe_collapse = [bopt, &a, &b, &out]() {
|
||||
if (bopt == BinaryOpType::General) {
|
||||
// The size cap here should ideally be `UINT32_MAX` but we are
|
||||
// limited by the shape being an int.
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
{a, b, out},
|
||||
/* size_cap = */ INT32_MAX);
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
return std::make_tuple(shape, strides[0], strides[1], strides[2]);
|
||||
} else {
|
||||
std::vector<size_t> e;
|
||||
|
@@ -26,11 +26,7 @@ void ternary_op_gpu_inplace(
|
||||
// Try to collapse contiguous dims
|
||||
auto maybe_collapse = [topt, &a, &b, &c, &out]() {
|
||||
if (topt == TernaryOpType::General) {
|
||||
// The size cap here should ideally be `UINT32_MAX` but we are
|
||||
// limited by the shape being an int.
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
{a, b, c, out},
|
||||
/* size_cap = */ INT32_MAX);
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
|
||||
return std::make_tuple(
|
||||
shape, strides[0], strides[1], strides[2], strides[3]);
|
||||
} else {
|
||||
|
@@ -28,10 +28,7 @@ void unary_op_gpu_inplace(
|
||||
|
||||
auto maybe_collapse = [contig, &in, &out]() {
|
||||
if (!contig) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
{in, out},
|
||||
/* size_cap = */ INT32_MAX);
|
||||
return std::make_pair(shape, strides[0]);
|
||||
return collapse_contiguous_dims(in);
|
||||
} else {
|
||||
return std::make_pair(std::vector<int>{}, std::vector<size_t>{});
|
||||
}
|
||||
|
Reference in New Issue
Block a user