Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 01:17:26 +08:00
Metal validation (#432)
* tests clear metal validation
* add cpp test with metal validation to circleci
* nit
This commit is contained in:
parent 975e265f74
commit c9934fe8a4
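The changes below make the Metal backend robust to empty arrays and zero-sized buffers so that the C++ tests run cleanly with Metal API validation enabled. The recurring guard added across the GPU kernels is a sketch of this shape (the same lines appear in several of the hunks below):

  // Allocate the output, then skip the kernel dispatch entirely when the
  // array is empty, so no zero-length buffers or dispatches reach Metal.
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  if (out.size() == 0) {
    return;
  }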
@@ -80,6 +80,13 @@ jobs:
            DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
      - store_test_results:
          path: test-results
      - run:
          name: Build CPP only
          command: |
            mkdir -p build && cd build && cmake .. && make -j
      - run:
          name: Run CPP tests
          command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests

  build_release:
    machine: true
@@ -153,6 +153,11 @@ MetalAllocator::MetalAllocator()
      gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}

Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
  // Metal doesn't like empty buffers
  if (size == 0) {
    return Buffer{nullptr};
  }

  // Align up memory
  if (size > vm_page_size) {
    size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
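The rounding above aligns any request larger than a page up to a whole number of VM pages. A standalone sketch of the same arithmetic, assuming an illustrative 16 KB page size (typical for macOS on Apple silicon):

  #include <cstddef>
  #include <cstdio>

  int main() {
    const std::size_t vm_page_size = 16384; // illustrative value, not queried from the OS
    std::size_t size = 20000;
    // Round up to the next multiple of the page size, as in the allocator above.
    size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
    std::printf("%zu\n", size); // prints 32768
    return 0;
  }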
@@ -20,6 +20,9 @@ void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }
  if (out.size() == 0) {
    return;
  }
  if (ctype == CopyType::GeneralGeneral) {
    ctype = CopyType::General;
  }
@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.

#include <algorithm>
#include <cassert>
#include <numeric>
@@ -33,6 +32,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  }

  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  if (out.size() == 0) {
    return;
  }

  auto& s = stream();
  auto& d = metal::device(s.device);
@@ -110,14 +112,18 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  for (int i = 0; i < nidx; ++i) {
    set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
  }
  if (idx_ndim > 0) {
    arg_enc->setBuffer(
        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
    compute_encoder->useResource(
        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
        MTL::ResourceUsageRead);
    arg_enc->setBuffer(
        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
    compute_encoder->useResource(
        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
        MTL::ResourceUsageRead);
  }
  *static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;

  // Set all the buffers
@@ -163,6 +169,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
      inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
  copy_gpu(inputs[0], out, copy_type);

  // Empty update
  if (inputs.back().size() == 0) {
    return;
  }

  // Get stream
  auto& s = stream();
  auto& d = metal::device(s.device);
@@ -254,14 +265,18 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  for (int i = 0; i < nidx; ++i) {
    set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
  }
  if (idx_ndim > 0) {
    arg_enc->setBuffer(
        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
    compute_encoder->useResource(
        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
        MTL::ResourceUsageRead);
    arg_enc->setBuffer(
        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
    compute_encoder->useResource(
        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
        MTL::ResourceUsageRead);
  }
  *static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;

  compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
@@ -272,14 +287,32 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  }
  set_array_buffer(compute_encoder, upd, 1);
  set_array_buffer(compute_encoder, out, 2);
  if (upd_ndim == 0) {
    // Need placeholders so Metal doesn't compalain
    int shape_ = 0;
    size_t stride_ = 0;
    compute_encoder->setBytes(&shape_, sizeof(int), 3);
    compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
  } else {
    compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
    compute_encoder->setBytes(upd.strides().data(), upd_ndim * sizeof(size_t), 4);
    compute_encoder->setBytes(
        upd.strides().data(), upd_ndim * sizeof(size_t), 4);
  }
  compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
  compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);

  size_t out_ndim = out.ndim();
  if (out_ndim == 0) {
    // Need placeholders so Metal doesn't compalain
    int shape_ = 0;
    size_t stride_ = 0;
    compute_encoder->setBytes(&shape_, sizeof(int), 7);
    compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
  } else {
    compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
    compute_encoder->setBytes(out.strides().data(), out_ndim * sizeof(size_t), 8);
    compute_encoder->setBytes(
        out.strides().data(), out_ndim * sizeof(size_t), 8);
  }
  compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
  compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
@@ -31,6 +31,9 @@ void binary_op(
  set_binary_op_output_data(a, b, outputs[1], bopt);

  auto& out = outputs[0];
  if (out.size() == 0) {
    return;
  }

  // Try to collapse contiguous dims
  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
@@ -120,6 +123,9 @@ void binary_op(
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
  if (out.size() == 0) {
    return;
  }

  // Try to collapse contiguous dims
  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
@@ -214,6 +220,9 @@ void unary_op(
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }
  if (in.size() == 0) {
    return;
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);
@@ -263,6 +272,9 @@ void arange_set_scalars(T start, T next, MTL::ComputeCommandEncoder* enc) {
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 0);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  if (out.size() == 0) {
    return;
  }
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto kernel = d.get_kernel("arange" + type_to_name(out));
@@ -390,9 +402,18 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, out, 1);
  if (ndim == 0) {
    // Pass place holders so metal doesn't complain
    int shape_ = 0;
    size_t stride_ = 0;
    compute_encoder->setBytes(&shape_, sizeof(int), 2);
    compute_encoder->setBytes(&stride_, sizeof(size_t), 3);
    compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
  } else {
    compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
    compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);
    compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
  }
  compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
  compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
  compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
@@ -629,6 +650,9 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
  size_t elems_per_key = out.size() / num_keys;
  size_t bytes_per_key = out.itemsize() * elems_per_key;
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  if (out.size() == 0) {
    return;
  }

  size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
  size_t half_size = out_per_key / 2;
@@ -2,7 +2,6 @@

#include <algorithm>
#include <cassert>
#include <iostream>
#include <sstream>

#include "mlx/backend/common/reduce.h"
@@ -21,10 +20,14 @@ namespace mlx::core {

namespace {

inline auto safe_divup(size_t n, size_t m) {
inline auto safe_div(size_t n, size_t m) {
  return m == 0 ? 0 : (n + m - 1) / m;
}

inline auto safe_divup(size_t n, size_t m) {
  return safe_div(n, m) * m;
}

// All Reduce
void all_reduce_dispatch(
    const array& in,
@@ -56,7 +59,7 @@ void all_reduce_dispatch(
      mod_in_size > thread_group_size ? thread_group_size : mod_in_size;

  // If the number of thread groups needed exceeds 1024, we reuse threads groups
  uint n_thread_groups = safe_divup(mod_in_size, thread_group_size);
  uint n_thread_groups = safe_div(mod_in_size, thread_group_size);
  n_thread_groups = std::min(n_thread_groups, 1024u);
  uint nthreads = n_thread_groups * thread_group_size;
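In the helpers above, the old safe_divup was a zero-guarded ceiling division; this change renames it to safe_div and adds a new safe_divup that rounds n up to a multiple of m (used further down for the threadgroup memory length). The call site here wants the ceiling division, i.e. the number of thread groups. A standalone sketch with illustrative values:

  #include <cstddef>
  #include <cstdio>

  // Same definitions as in the hunk above.
  inline std::size_t safe_div(std::size_t n, std::size_t m) {
    return m == 0 ? 0 : (n + m - 1) / m;
  }
  inline std::size_t safe_divup(std::size_t n, std::size_t m) {
    return safe_div(n, m) * m;
  }

  int main() {
    // e.g. 1000 work items with 256 threads per group:
    std::printf("%zu\n", safe_div(1000, 256));   // 4 (thread groups needed)
    std::printf("%zu\n", safe_divup(1000, 256)); // 1024 (rounded up to a multiple of 256)
    return 0;
  }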
@@ -204,7 +207,8 @@ void strided_reduce_general_dispatch(
  // if we ever come to doubles. In that case, we should also cut
  // down the number of threads we launch in a threadgroup
  compute_encoder->setThreadgroupMemoryLength(
      threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 0);
      safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
      0);

  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
@@ -231,7 +235,10 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(!axes_.empty());

  // Continue with reduction operation
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  // Minimum of 4 bytes since we use size 4 structs for all reduce
  // and metal will complain o/w
  size_t min_bytes = std::max(out.nbytes(), 4ul);
  out.set_data(allocator::malloc_or_wait(min_bytes));
  std::string op_name;
  switch (reduce_type_) {
    case Reduce::And:
@@ -273,7 +280,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  }

  // Reduce
  {
    if (in.size() > 0) {
    std::vector<array> copies;
    ReductionPlan plan = get_reduction_plan(in, axes_);
@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.

#include <cmath>
#include <numeric>
#include <set>
@@ -38,6 +38,4 @@ TEST_CASE("test large allocations") {
    auto buffer = allocator::malloc(size);
    allocator::free(buffer);
  }
  // Shouldn't be able to allocate an exabyte anytime soon.
  CHECK_THROWS_AS(allocator::malloc(1ull << 60), std::runtime_error);
}
@@ -1,7 +1,5 @@
// Copyright © 2023 Apple Inc.

#include <iostream>

#include "doctest/doctest.h"

#include "mlx/mlx.h"
@@ -438,3 +438,36 @@ TEST_CASE("test metal matmul") {
    CHECK(array_equal(out, full({3, 3, 2, 2}, 2.0f), Device::cpu).item<bool>());
  }
}

TEST_CASE("test metal validation") {
  // Run this test with Metal validation enabled
  // METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./tests/tests \
  //   -tc="test metal validation" \

  auto x = array({});
  eval(exp(x));

  auto y = array({});
  eval(add(x, y));

  eval(sum(x));

  x = array({1, 2, 3});
  y = array(0);
  eval(gather(x, y, 0, {0}));
  eval(gather(x, y, 0, {2}));

  eval(gather(x, y, 0, {0}));
  eval(gather(x, y, 0, {2}));

  eval(scatter(x, y, array({2}), 0));

  x = arange(0, -3, 1);
  eval(x);
  array_equal(x, array({})).item<bool>();

  x = array({1.0, 0.0});
  eval(argmax(x));

  eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
}