mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Custom logsumexp (#2028)
* initial custom logsumexp * more tests * comments + fix
This commit is contained in:
@@ -23,12 +23,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user