Fix the softmax fix (#661)

This commit is contained in:
Awni Hannun 2024-02-09 17:02:13 -08:00 committed by GitHub
parent b96be943dc
commit 7f3f8d8f8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 3 additions and 3 deletions

View File

@ -276,7 +276,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
auto check_input = [](array x) { auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1; bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) { if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 1]; auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back()); no_copy &= (s == 0 || s == x.shape().back());
} }
if (no_copy) { if (no_copy) {

View File

@ -55,7 +55,7 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
auto check_input = [](array x) { auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1; bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) { if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 1]; auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back()); no_copy &= (s == 0 || s == x.shape().back());
} }
if (no_copy) { if (no_copy) {

View File

@ -24,7 +24,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
auto check_input = [&copies, &s](const array& x) { auto check_input = [&copies, &s](const array& x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1; bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) { if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 1]; auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back()); no_copy &= (s == 0 || s == x.shape().back());
} }
if (no_copy) { if (no_copy) {