mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 23:21:16 +08:00
format
This commit is contained in:
parent
1a0e884036
commit
6bb0b254fd
@ -162,7 +162,7 @@ class MatMul {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void *workspace_ptr = nullptr;
|
void* workspace_ptr = nullptr;
|
||||||
if (heuristic_.workspaceSize > 0) {
|
if (heuristic_.workspaceSize > 0) {
|
||||||
array workspace(
|
array workspace(
|
||||||
allocator::malloc(heuristic_.workspaceSize),
|
allocator::malloc(heuristic_.workspaceSize),
|
||||||
@ -464,7 +464,14 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
auto nbatch = batch_count / batch_shape.back();
|
auto nbatch = batch_count / batch_shape.back();
|
||||||
if (nbatch == 1) {
|
if (nbatch == 1) {
|
||||||
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>(), c.data<int8_t>(), alpha_, beta_);
|
matmul.run(
|
||||||
|
encoder,
|
||||||
|
out.data<int8_t>(),
|
||||||
|
a.data<int8_t>(),
|
||||||
|
b.data<int8_t>(),
|
||||||
|
c.data<int8_t>(),
|
||||||
|
alpha_,
|
||||||
|
beta_);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user