mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix adding inputs arrays in matmul / srot
This commit is contained in:
parent
72e21b7d51
commit
1a0e884036
@ -1,6 +1,5 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/utils.h"
|
|
||||||
#include "mlx/backend/cuda/allocator.h"
|
#include "mlx/backend/cuda/allocator.h"
|
||||||
#include "mlx/backend/cuda/utils.h"
|
#include "mlx/backend/cuda/utils.h"
|
||||||
#include "mlx/backend/cuda/worker.h"
|
#include "mlx/backend/cuda/worker.h"
|
||||||
|
@ -362,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
a_batch_strides.back(),
|
a_batch_strides.back(),
|
||||||
b_batch_strides.back());
|
b_batch_strides.back());
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
auto nbatch = batch_count / batch_shape.back();
|
||||||
|
if (nbatch == 1) {
|
||||||
|
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
for (size_t i = 0; i < nbatch; ++i) {
|
||||||
matmul.run(
|
matmul.run(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||||
@ -448,10 +457,21 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
b_batch_strides.back(),
|
b_batch_strides.back(),
|
||||||
c_batch_strides.back());
|
c_batch_strides.back());
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_input_array(c);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
auto nbatch = batch_count / batch_shape.back();
|
||||||
|
if (nbatch == 1) {
|
||||||
|
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>(), c.data<int8_t>(), alpha_, beta_);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
for (size_t i = 0; i < nbatch; ++i) {
|
||||||
matmul.run(
|
matmul.run(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||||
|
@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
|
|||||||
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||||
array out = out_;
|
array out = out_;
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
encoder.set_input_array(in);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
|
|
||||||
if (axis < 0) {
|
if (axis < 0) {
|
||||||
axis += in.ndim();
|
axis += in.ndim();
|
||||||
}
|
}
|
||||||
@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
in.flags());
|
in.flags());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||||
|
Loading…
Reference in New Issue
Block a user