mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 15:11:14 +08:00
format
This commit is contained in:
parent
1a0e884036
commit
6bb0b254fd
@ -162,14 +162,14 @@ class MatMul {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void *workspace_ptr = nullptr;
|
void* workspace_ptr = nullptr;
|
||||||
if (heuristic_.workspaceSize > 0) {
|
if (heuristic_.workspaceSize > 0) {
|
||||||
array workspace(
|
array workspace(
|
||||||
allocator::malloc(heuristic_.workspaceSize),
|
allocator::malloc(heuristic_.workspaceSize),
|
||||||
{static_cast<int>(heuristic_.workspaceSize)},
|
{static_cast<int>(heuristic_.workspaceSize)},
|
||||||
int8);
|
int8);
|
||||||
encoder.add_temporary(workspace);
|
encoder.add_temporary(workspace);
|
||||||
workspace_ptr = workspace.data<void>();
|
workspace_ptr = workspace.data<void>();
|
||||||
}
|
}
|
||||||
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
@ -464,7 +464,14 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
auto nbatch = batch_count / batch_shape.back();
|
auto nbatch = batch_count / batch_shape.back();
|
||||||
if (nbatch == 1) {
|
if (nbatch == 1) {
|
||||||
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>(), c.data<int8_t>(), alpha_, beta_);
|
matmul.run(
|
||||||
|
encoder,
|
||||||
|
out.data<int8_t>(),
|
||||||
|
a.data<int8_t>(),
|
||||||
|
b.data<int8_t>(),
|
||||||
|
c.data<int8_t>(),
|
||||||
|
alpha_,
|
||||||
|
beta_);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user