Fix the softmax fix (#661)

This commit is contained in:
Awni Hannun
2024-02-09 17:02:13 -08:00
committed by GitHub
parent b96be943dc
commit 7f3f8d8f8d
3 changed files with 3 additions and 3 deletions

View File

@@ -276,7 +276,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 1];
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {

View File

@@ -55,7 +55,7 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 1];
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {

View File

@@ -24,7 +24,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
auto check_input = [&copies, &s](const array& x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 1];
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {