mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Reduce a little overhead (#871)
* some small overhead improvements
* use result_type in rms_norm
* remove release force
* fix + use non-vector version
* revert compile change
* fix ops
* a little more overhead
* a little more cleanup and overhead
This commit is contained in:
@@ -195,7 +195,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
||||
auto out = array::make_arrays(
|
||||
{a.shape(), a.shape()},
|
||||
{a.dtype(), a.dtype()},
|
||||
std::make_unique<QRF>(to_stream(s)),
|
||||
std::make_shared<QRF>(to_stream(s)),
|
||||
{astype(a, a.dtype(), s)});
|
||||
return std::make_pair(out[0], out[1]);
|
||||
}
|
||||
@@ -234,7 +234,7 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return array::make_arrays(
|
||||
{u_shape, s_shape, vt_shape},
|
||||
{a.dtype(), a.dtype(), a.dtype()},
|
||||
std::make_unique<SVD>(to_stream(s)),
|
||||
std::make_shared<SVD>(to_stream(s)),
|
||||
{a});
|
||||
}
|
||||
|
||||
@@ -258,7 +258,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {
|
||||
}
|
||||
|
||||
return array(
|
||||
a.shape(), a.dtype(), std::make_unique<Inverse>(to_stream(s)), {a});
|
||||
a.shape(), a.dtype(), std::make_shared<Inverse>(to_stream(s)), {a});
|
||||
}
|
||||
|
||||
} // namespace mlx::core::linalg
|
||||
|
||||
Reference in New Issue
Block a user