minor fixes (#631)

* minor fixes

* var with ddof >= nelements
This commit is contained in:
Awni Hannun
2024-02-05 13:27:49 -08:00
committed by GitHub
parent d75ae52ecd
commit d40a04f8dc
5 changed files with 40 additions and 5 deletions

View File

@@ -46,6 +46,9 @@ inline void matmul_cblas_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
@@ -94,6 +97,9 @@ inline void matmul_bnns_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;

View File

@@ -132,7 +132,9 @@ inline void matmul_common_general(
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;

View File

@@ -1323,8 +1323,8 @@ array mean(
for (int axis : axes) {
if (axis < -ndim || axis >= ndim) {
std::ostringstream msg;
msg << "[mean] axis " << axis + " is out of bounds for array with "
<< ndim + " dimensions.";
msg << "[mean] axis " << axis << " is out of bounds for array with "
<< ndim << " dimensions.";
throw std::invalid_argument(msg.str());
}
}
@@ -1364,7 +1364,7 @@ array var(
if (ddof != 0) {
auto nelements = compute_number_of_elements(a, axes);
float factor = nelements / (nelements - ddof);
auto factor = nelements / static_cast<float>(std::max(nelements - ddof, 0));
v = multiply(v, array(factor, dtype), s);
}