mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Close a couple edge case bugs: hadamard and addmm on empty inputs (#2177)
* handle hadamard and addmm on empty inputs * fix
This commit is contained in:
@@ -716,6 +716,23 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(
|
||||
"[matmul] Does not yet support non-floating point types.");
|
||||
}
|
||||
|
||||
// Return 0s if either input is empty
|
||||
if (out.size() == 0) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
return;
|
||||
}
|
||||
|
||||
// Copy c into out and return
|
||||
if (inputs[0].shape(-1) == 0) {
|
||||
copy_gpu(
|
||||
inputs[2],
|
||||
out,
|
||||
inputs[2].flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
stream());
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
Reference in New Issue
Block a user