Minimum xcode / sdk (#800)

* minimum xcode /sdk

* try multiple xcode versions in CI

* update python

* metal validation for python tests
This commit is contained in:
Awni Hannun
2024-03-07 08:19:43 -08:00
committed by GitHub
parent afd5274049
commit f512b905c7
4 changed files with 29 additions and 16 deletions

View File

@@ -428,12 +428,21 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Return 0s if either input is empty
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype());
copy_gpu(zero, out, CopyType::Scalar, s);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@@ -573,7 +582,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
/////////////////////////////////////////////////////////////////////////////
// Gemm specialization