mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 18:56:39 +08:00
Fix the softmax fix (#661)
This commit is contained in:
parent
b96be943dc
commit
7f3f8d8f8d
@ -276,7 +276,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto check_input = [](array x) {
|
auto check_input = [](array x) {
|
||||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||||
if (x.ndim() > 1) {
|
if (x.ndim() > 1) {
|
||||||
auto s = x.strides()[x.ndim() - 1];
|
auto s = x.strides()[x.ndim() - 2];
|
||||||
no_copy &= (s == 0 || s == x.shape().back());
|
no_copy &= (s == 0 || s == x.shape().back());
|
||||||
}
|
}
|
||||||
if (no_copy) {
|
if (no_copy) {
|
||||||
|
@ -55,7 +55,7 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
auto check_input = [](array x) {
|
auto check_input = [](array x) {
|
||||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||||
if (x.ndim() > 1) {
|
if (x.ndim() > 1) {
|
||||||
auto s = x.strides()[x.ndim() - 1];
|
auto s = x.strides()[x.ndim() - 2];
|
||||||
no_copy &= (s == 0 || s == x.shape().back());
|
no_copy &= (s == 0 || s == x.shape().back());
|
||||||
}
|
}
|
||||||
if (no_copy) {
|
if (no_copy) {
|
||||||
|
@ -24,7 +24,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto check_input = [&copies, &s](const array& x) {
|
auto check_input = [&copies, &s](const array& x) {
|
||||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||||
if (x.ndim() > 1) {
|
if (x.ndim() > 1) {
|
||||||
auto s = x.strides()[x.ndim() - 1];
|
auto s = x.strides()[x.ndim() - 2];
|
||||||
no_copy &= (s == 0 || s == x.shape().back());
|
no_copy &= (s == 0 || s == x.shape().back());
|
||||||
}
|
}
|
||||||
if (no_copy) {
|
if (no_copy) {
|
||||||
|
Loading…
Reference in New Issue
Block a user