// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"

namespace mlx::core {
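// Batched GEMM, no addend: out = alpha * (a @ b) over the leading batch
// dimensions. The innermost batch dimension is divided out of `nbatch` and
// folded into each execute() call (presumably via cuBLAS's strided-batch
// support configured elsewhere in CublasGemm), so the loop below only walks
// the leading batch dimensions.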
void CublasGemm::run_batched(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
    const Shape& batch_shape,
    const Strides& a_batch_strides,
    const Strides& b_batch_strides,
    float alpha) {
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);

  // Number of GEMM dispatches: the last batch dimension is handled within a
  // single execute() call, so divide it out along with the output tile size.
  auto nbatch = out.size() / (M_ * N_ * batch_shape.back());

  // Walk the leading batch dimensions of a and b in memory order; the
  // iterators skip the last batch dimension (hence ndim - 1).
  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);

  // Run the per-batch GEMMs inside a concurrent region so they may overlap.
  auto concurrent = encoder.concurrent_context();
  for (size_t i = 0; i < nbatch; ++i) {
    execute(
        encoder,
        gpu_ptr<int8_t>(out) +
            out.itemsize() * i * batch_shape.back() * M_ * N_,
        gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
        gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
        nullptr, // no addend matrix in this overload
        alpha);
    a_it.step();
    b_it.step();
  }
}
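
// Batched GEMM with addend: out = alpha * (a @ b) + beta * c. Identical
// batching scheme to the overload above, with a third iterator stepping the
// addend c in lockstep with a and b.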
void CublasGemm::run_batched(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
    const array& c,
    const Shape& batch_shape,
    const Strides& a_batch_strides,
    const Strides& b_batch_strides,
    const Strides& c_batch_strides,
    float alpha,
    float beta) {
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_input_array(c);
  encoder.set_output_array(out);

  // As above: the last batch dimension is dispatched inside execute().
  auto nbatch = out.size() / (M_ * N_ * batch_shape.back());
  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
  ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);

  auto concurrent = encoder.concurrent_context();
  for (size_t i = 0; i < nbatch; ++i) {
    execute(
        encoder,
        gpu_ptr<int8_t>(out) +
            out.itemsize() * i * batch_shape.back() * M_ * N_,
        gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
        gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
        gpu_ptr<int8_t>(c) + c.itemsize() * c_it.loc,
        alpha,
        beta);
    a_it.step();
    b_it.step();
    c_it.step();
  }
}
} // namespace mlx::core