Double for lapack (#1904)

* double for lapack ops

* add double support for lapack ops
This commit is contained in:
Awni Hannun
2025-02-25 11:39:36 -08:00
committed by GitHub
parent 28b8079e30
commit 7d042f17fe
11 changed files with 338 additions and 225 deletions

View File

@@ -18,6 +18,14 @@ void check_cpu_stream(const StreamOrDevice& s, const std::string& prefix) {
"Explicitly pass a CPU stream to run it.");
}
}
void check_float(Dtype dtype, const std::string& prefix) {
if (dtype != float32 && dtype != float64) {
std::ostringstream msg;
msg << prefix << " Arrays must have type float32 or float64. "
<< "Received array with type " << dtype << ".";
throw std::invalid_argument(msg.str());
}
}
Dtype at_least_float(const Dtype& d) {
return issubdtype(d, inexact) ? d : promote_types(d, float32);
@@ -184,12 +192,8 @@ array norm(
std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::qr]");
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::qr] Arrays must type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}
check_float(a.dtype(), "[linalg::qr]");
if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::qr] Arrays must have >= 2 dimensions. Received array "
@@ -212,12 +216,8 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::svd]");
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::svd] Input array must have type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}
check_float(a.dtype(), "[linalg::svd]");
if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::svd] Input array must have >= 2 dimensions. Received array "
@@ -251,12 +251,8 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) {
check_cpu_stream(s, "[linalg::inv]");
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::inv] Arrays must type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}
check_float(a.dtype(), "[linalg::inv]");
if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::inv] Arrays must have >= 2 dimensions. Received array "
@@ -292,13 +288,7 @@ array cholesky(
bool upper /* = false */,
StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::cholesky]");
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::cholesky] Arrays must type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}
check_float(a.dtype(), "[linalg::cholesky]");
if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received array "
@@ -321,12 +311,8 @@ array cholesky(
array pinv(const array& a, StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::pinv]");
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::pinv] Arrays must type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}
check_float(a.dtype(), "[linalg::pinv]");
if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::pinv] Arrays must have >= 2 dimensions. Received array "
@@ -368,12 +354,7 @@ array cholesky_inv(
bool upper /* = false */,
StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::cholesky_inv]");
if (L.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::cholesky_inv] Arrays must type float32. Received array "
<< "with type " << L.dtype() << ".";
throw std::invalid_argument(msg.str());
}
check_float(L.dtype(), "[linalg::cholesky_inv]");
if (L.ndim() < 2) {
std::ostringstream msg;
@@ -474,12 +455,7 @@ void validate_eigh(
const StreamOrDevice& stream,
const std::string fname) {
check_cpu_stream(stream, fname);
if (a.dtype() != float32) {
std::ostringstream msg;
msg << fname << " Arrays must have type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}
check_float(a.dtype(), fname);
if (a.ndim() < 2) {
std::ostringstream msg;
@@ -524,12 +500,7 @@ void validate_lu(
const StreamOrDevice& stream,
const std::string& fname) {
check_cpu_stream(stream, fname);
if (a.dtype() != float32) {
std::ostringstream msg;
msg << fname << " Arrays must type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}
check_float(a.dtype(), fname);
if (a.ndim() < 2) {
std::ostringstream msg;
@@ -627,10 +598,12 @@ void validate_solve(
}
auto out_type = promote_types(a.dtype(), b.dtype());
if (out_type != float32) {
if (out_type != float32 && out_type != float64) {
std::ostringstream msg;
msg << fname << " Input arrays must promote to float32. Received arrays "
<< "with type " << a.dtype() << " and " << b.dtype() << ".";
msg << fname
<< " Input arrays must promote to float32 or float64. "
" Received arrays with type "
<< a.dtype() << " and " << b.dtype() << ".";
throw std::invalid_argument(msg.str());
}
}