Reduce a little overhead (#871)

* some small overhead improvements

* use result_type in rms_norm

* remove release force

* fix + use non-vector version

* revert compile change

* fix ops

* a little more overhead

* a little more cleanup and overhead
This commit is contained in:
Awni Hannun
2024-03-22 17:29:36 -07:00
committed by GitHub
parent 6ee1112f30
commit be98f4ab6b
13 changed files with 239 additions and 240 deletions

View File

@@ -195,7 +195,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
auto out = array::make_arrays(
{a.shape(), a.shape()},
{a.dtype(), a.dtype()},
std::make_unique<QRF>(to_stream(s)),
std::make_shared<QRF>(to_stream(s)),
{astype(a, a.dtype(), s)});
return std::make_pair(out[0], out[1]);
}
@@ -234,7 +234,7 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
return array::make_arrays(
{u_shape, s_shape, vt_shape},
{a.dtype(), a.dtype(), a.dtype()},
std::make_unique<SVD>(to_stream(s)),
std::make_shared<SVD>(to_stream(s)),
{a});
}
@@ -258,7 +258,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {
}
return array(
a.shape(), a.dtype(), std::make_unique<Inverse>(to_stream(s)), {a});
a.shape(), a.dtype(), std::make_shared<Inverse>(to_stream(s)), {a});
}
} // namespace mlx::core::linalg