Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
* Use async cuda malloc managed with cuda 13
* add pool threshold
* refactor for regular cuda malloc
* load eval gpu for cuda
* remove use of cuda pool, use cuda free async
* fix
* fix + comment
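For context, the allocator changes described in this commit message rely on CUDA's stream-ordered memory APIs (cudaMallocAsync / cudaFreeAsync) and a memory-pool release threshold. The following is a minimal sketch of those CUDA runtime calls, assuming a CUDA 11.2+ runtime; it illustrates the APIs involved and is not mlx's actual allocator code:

// Hypothetical sketch of stream-ordered allocation with a pool release
// threshold. Illustrates the CUDA APIs named in the commit message; this
// is NOT the mlx allocator implementation.
#include <cuda_runtime.h>

int main() {
  int device = 0;
  cudaSetDevice(device);

  // "add pool threshold": keep up to 1 GiB of freed memory cached in the
  // default memory pool instead of releasing it back to the OS at every
  // stream synchronization.
  cudaMemPool_t pool;
  cudaDeviceGetDefaultMemPool(&pool, device);
  unsigned long long threshold = 1ull << 30;
  cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // "use cuda free async": allocation and free are ordered on the stream,
  // so no device-wide synchronization is required around them.
  void* buf = nullptr;
  cudaMallocAsync(&buf, 1 << 20, stream);
  // ... launch kernels that use buf on `stream` ...
  cudaFreeAsync(buf, stream);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}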
78 lines
2.3 KiB
C++
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"

namespace mlx::core {

void CublasGemm::run_batched(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
    const Shape& batch_shape,
    const Strides& a_batch_strides,
    const Strides& b_batch_strides,
    float alpha) {
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);

  // Each execute() call covers the innermost batch dimension, so loop over
  // the product of the remaining (outer) batch dimensions.
  auto nbatch = out.size() / (M_ * N_ * batch_shape.back());

  // Step through the input byte offsets over all batch dims except the last,
  // honoring possibly non-contiguous batch strides.
  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);

  // Encode the independent per-batch GEMMs in a concurrent context.
  auto concurrent = encoder.concurrent_context();
  for (size_t i = 0; i < nbatch; ++i) {
    execute(
        encoder,
        gpu_ptr<int8_t>(out) +
            out.itemsize() * i * batch_shape.back() * M_ * N_,
        gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
        gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
        nullptr,
        alpha);
    a_it.step();
    b_it.step();
  }
}

void CublasGemm::run_batched(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
    const array& c,
    const Shape& batch_shape,
    const Strides& a_batch_strides,
    const Strides& b_batch_strides,
    const Strides& c_batch_strides,
    float alpha,
    float beta) {
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_input_array(c);
  encoder.set_output_array(out);

  // Same batching scheme as above, with an extra input c scaled by beta.
  auto nbatch = out.size() / (M_ * N_ * batch_shape.back());

  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
  ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);

  auto concurrent = encoder.concurrent_context();
  for (size_t i = 0; i < nbatch; ++i) {
    execute(
        encoder,
        gpu_ptr<int8_t>(out) +
            out.itemsize() * i * batch_shape.back() * M_ * N_,
        gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
        gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
        gpu_ptr<int8_t>(c) + c.itemsize() * c_it.loc,
        alpha,
        beta);
    a_it.step();
    b_it.step();
    c_it.step();
  }
}

} // namespace mlx::core
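A note on the batching scheme, as far as it can be read from this file: rather than issuing one cuBLAS call per batch element, each execute() call appears to cover the batch_shape.back() matrices of the innermost batch dimension (which is why nbatch divides out batch_shape.back()), while the ContiguousIterator walks the remaining outer batch dimensions whose strides need not be contiguous. The concurrent_context presumably lets the command encoder overlap these independent per-slice GEMMs.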