Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 09:21:16 +08:00
Fix a couple bugs (#1161)
* fix jit reduce for RMS norm
* make strides a single buffer
* better eval error message
* fix compiling with inf and bf16
* fix cpu compile with bf16
This commit is contained in:
parent a87ef5bfc1
commit e7a2a3dcd1
@@ -28,6 +28,7 @@ const char* get_kernel_preamble() {
   return R"preamble(
 $INCLUDES
 $CONTENT
+using namespace mlx::core;
 using namespace mlx::core::detail;
 )preamble";
 }
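For context, this preamble is a template: the $INCLUDES and $CONTENT placeholders are filled in with the actual headers and kernel implementations before the generated source reaches the compiler, so the added using-directive brings core types such as bfloat16_t into scope unqualified. A minimal sketch of that substitution step, using a hypothetical replace_all helper that is not part of MLX:

#include <string>

const char* get_kernel_preamble();  // the template shown above

// Hypothetical helper: replace every occurrence of key in s with value.
std::string replace_all(std::string s, const std::string& key, const std::string& value) {
  for (size_t pos = 0; (pos = s.find(key, pos)) != std::string::npos; pos += value.size()) {
    s.replace(pos, key.size(), value);
  }
  return s;
}

// Sketch: expand the placeholders before compiling the generated source.
std::string expand_preamble(const std::string& includes, const std::string& content) {
  std::string preamble = get_kernel_preamble();
  preamble = replace_all(preamble, "$INCLUDES", includes);
  return replace_all(preamble, "$CONTENT", content);
}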
@@ -56,12 +56,15 @@ inline void build_kernel(
     } else {
+      add_indices = true;
       os << "  device const " << get_type_string(x.dtype()) << "* " << xname
-         << " [[buffer(" << cnt++ << ")]]," << std::endl
-         << "  constant const size_t* " << xname << "_strides [[buffer("
-         << cnt++ << ")]]," << std::endl;
+         << " [[buffer(" << cnt++ << ")]]," << std::endl;
     }
   }
 
+  if (add_indices) {
+    os << "  constant const size_t* in_strides [[buffer(" << cnt++
+       << ")]],\n";
+  }
 
   // Add the output arguments
   for (auto& x : outputs) {
     os << "  device " << get_type_string(x.dtype()) << "* "
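In terms of the generated Metal source, the old code emitted one *_strides buffer argument per non-contiguous input, while the new code emits a single packed in_strides buffer after all the inputs. A rough sketch of the emitted parameter list for two hypothetical non-contiguous float inputs a and b (names and buffer indices are illustrative only):

// Before: one strides buffer per non-contiguous input
  device const float* a [[buffer(0)]],
  constant const size_t* a_strides [[buffer(1)]],
  device const float* b [[buffer(2)]],
  constant const size_t* b_strides [[buffer(3)]],

// After: one packed strides buffer shared by all non-contiguous inputs
  device const float* a [[buffer(0)]],
  device const float* b [[buffer(1)]],
  constant const size_t* in_strides [[buffer(2)]],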
@@ -110,13 +113,17 @@ inline void build_kernel(
   }
 
   // Read the inputs in tmps
-  for (auto& x : inputs) {
+  int nc_in_count = 0;
+  for (int i = 0; i < inputs.size(); ++i) {
+    auto& x = inputs[i];
     auto& xname = namer.get_name(x);
 
     if (is_constant(x)) {
-      os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
+      auto type_str = get_type_string(x.dtype());
+      os << "  auto tmp_" << xname << " = static_cast<"
+         << get_type_string(x.dtype()) << ">(";
       print_constant(os, x);
-      os << ";" << std::endl;
+      os << ");" << std::endl;
     } else if (is_scalar(x)) {
       os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
          << xname << "[0];" << std::endl;
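This is the inf/bf16 part of the fix: routing the printed constant through a static_cast to the input's dtype keeps the generated source well-typed when the constant is, say, infinity and the dtype is bfloat16. Roughly, for a hypothetical bf16 constant input x, the emitted line changes along these lines (the exact literal comes from print_constant):

// Before (generated source, sketch):
bfloat16_t tmp_x = /* literal written by print_constant */;

// After (generated source, sketch):
auto tmp_x = static_cast<bfloat16_t>(/* literal written by print_constant */);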
@@ -124,17 +131,20 @@ inline void build_kernel(
       os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
          << xname << "[index];" << std::endl;
     } else if (!dynamic_dims) {
+      int offset = nc_in_count * ndim;
       os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
          << xname << "[";
-      os << "index_0 * " << xname << "_strides[0]";
+      os << "index_0 * " << "in_strides[" << offset << "]";
       for (int i = 1; i < ndim; i++) {
-        os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
+        os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
       }
       os << "];" << std::endl;
+      nc_in_count++;
     } else {
       os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
-         << xname << "[elem_to_loc(index, output_shape, " << xname
-         << "_strides, ndim)];" << std::endl;
+         << xname << "[elem_to_loc(index, output_shape, in_strides + "
+         << nc_in_count * ndim << ", ndim)];" << std::endl;
+      nc_in_count++;
     }
   }
 
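Each non-contiguous input now indexes into the shared in_strides buffer at its own offset of nc_in_count * ndim. As an illustration, for the second non-contiguous input (nc_in_count == 1) of a hypothetical 3-d kernel, the generated load would look roughly like:

// offset = 1 * 3 = 3, so this input's strides occupy in_strides[3..5]
float tmp_b = b[index_0 * in_strides[3] + index_1 * in_strides[4] + index_2 * in_strides[5]];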
@@ -296,6 +306,7 @@ void Compiled::eval_gpu(
   // Put the inputs in
   int cnt = 0;
   int stride_idx = 1; // idx 0 is the output strides
+  std::vector<size_t> in_strides;
   for (int i = 0; i < inputs.size(); i++) {
     if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
       continue;
@@ -303,13 +314,17 @@ void Compiled::eval_gpu(
     auto& x = inputs[i];
     compute_encoder.set_input_array(x, cnt++);
     if (!contiguous && !is_scalar(x)) {
-      compute_encoder->setBytes(
-          strides[stride_idx].data(),
-          strides[stride_idx].size() * sizeof(size_t),
-          cnt++);
+      in_strides.insert(
+          in_strides.end(),
+          strides[stride_idx].begin(),
+          strides[stride_idx].end());
       stride_idx++;
     }
   }
+  if (!in_strides.empty()) {
+    compute_encoder->setBytes(
+        in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
+  }
 
   compiled_allocate_outputs(
       inputs, outputs, inputs_, constant_ids_, contiguous, true);
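On the host side, the per-input stride vectors are concatenated back to back into one std::vector<size_t> and uploaded with a single setBytes call, so input k's strides begin at element k * ndim, matching the offsets used by build_kernel. A self-contained sketch of that packing, assuming every non-contiguous input contributes ndim strides:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Strides of two hypothetical non-contiguous 3-d inputs.
  std::vector<std::vector<size_t>> strides = {{12, 4, 1}, {1, 3, 9}};

  // Pack them back to back, like the in_strides vector above.
  std::vector<size_t> in_strides;
  for (auto& s : strides) {
    in_strides.insert(in_strides.end(), s.begin(), s.end());
  }

  // Input k finds its strides starting at offset k * ndim.
  const size_t ndim = 3;
  for (size_t k = 0; k < strides.size(); ++k) {
    std::cout << "input " << k << " first stride = " << in_strides[k * ndim] << "\n";
  }
  return 0;
}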
@@ -259,11 +259,14 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
 MTL::ComputePipelineState* get_reduce_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& op_name,
     const array& in,
     const array& out) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
+    std::string op_type = op_name;
+    op_type[0] = std::toupper(op_name[0]);
     bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
@@ -273,7 +276,7 @@ MTL::ComputePipelineState* get_reduce_kernel(
         lib_name,
         get_type_string(in.dtype()),
         get_type_string(out.dtype()),
-        op_name(out));
+        op_type);
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
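The reduce kernel templates are instantiated with a capitalized operation type (Sum, Prod, Min, Max), which the JIT path now derives from the lowercase op_name passed in by the dispatcher rather than inferring it from the output array. A minimal sketch of that derivation:

#include <cctype>
#include <iostream>
#include <string>

// Capitalize the first character, mirroring the op_type logic above: "sum" -> "Sum".
std::string to_op_type(const std::string& op_name) {
  std::string op_type = op_name;
  op_type[0] = std::toupper(op_name[0]);
  return op_type;
}

int main() {
  std::cout << to_op_type("sum") << " " << to_op_type("min") << std::endl;  // prints: Sum Min
}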
@@ -76,6 +76,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
 MTL::ComputePipelineState* get_reduce_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& op_name,
     const array& in,
     const array& out);
 
@@ -38,8 +38,8 @@
 #define instantiate_reduce_ops(inst_f, type_f) \
   type_f(inst_f, sum, Sum) \
   type_f(inst_f, prod, Prod) \
-  type_f(inst_f, min_, Min) \
-  type_f(inst_f, max_, Max)
+  type_f(inst_f, min, Min) \
+  type_f(inst_f, max, Max)
 
 // Special case for bool reductions
 #define instantiate_reduce_from_types_helper( \
@@ -98,6 +98,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
 MTL::ComputePipelineState* get_reduce_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string&,
     const array&,
     const array&) {
   return d.get_kernel(kernel_name);
@@ -46,7 +46,7 @@ void all_reduce_dispatch(
     kernel_name += "NoAtomics";
   }
   kernel_name += "_reduce_" + op_name + type_to_name(in);
-  auto kernel = get_reduce_kernel(d, kernel_name, in, out);
+  auto kernel = get_reduce_kernel(d, kernel_name, op_name, in, out);
 
   compute_encoder->setComputePipelineState(kernel);
 
@@ -175,7 +175,7 @@ void row_reduce_general_dispatch(
   kname << "rowGeneral" << small_desc << "_reduce_" << op_name
         << type_to_name(in);
 
-  auto kernel = get_reduce_kernel(d, kname.str(), in, out);
+  auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
   compute_encoder->setComputePipelineState(kernel);
 
   // Get dispatch grid dims
@@ -342,7 +342,7 @@ void strided_reduce_general_dispatch(
   if (reduction_size * non_col_reductions < 16) {
     // Select kernel
     auto kernel = get_reduce_kernel(
-        d, "colSmall_reduce_" + op_name + type_to_name(in), in, out);
+        d, "colSmall_reduce_" + op_name + type_to_name(in), op_name, in, out);
     compute_encoder->setComputePipelineState(kernel);
 
     // Select block dims
@@ -384,7 +384,7 @@ void strided_reduce_general_dispatch(
     kernel_name += "NoAtomics";
   }
   kernel_name += "_reduce_" + op_name + type_to_name(in);
-  auto kernel = get_reduce_kernel(d, kernel_name, in, out);
+  auto kernel = get_reduce_kernel(d, kernel_name, op_name, in, out);
 
   compute_encoder->setComputePipelineState(kernel);
 
@@ -501,7 +501,7 @@ void strided_reduce_general_dispatch(
     std::string kernel_name =
         "rowGeneralNoAtomics_reduce_" + op_name + type_to_name(intermediate);
     auto row_reduce_kernel =
-        get_reduce_kernel(d, kernel_name, intermediate, out);
+        get_reduce_kernel(d, kernel_name, op_name, intermediate, out);
 
     compute_encoder->setComputePipelineState(row_reduce_kernel);
     compute_encoder.set_input_array(intermediate, 0);
@@ -573,10 +573,10 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
         op_name = out.dtype() == bool_ ? "and" : "prod";
         break;
       case Reduce::Min:
-        op_name = out.dtype() == bool_ ? "and" : "min_";
+        op_name = out.dtype() == bool_ ? "and" : "min";
         break;
       case Reduce::Max:
-        op_name = out.dtype() == bool_ ? "or" : "max_";
+        op_name = out.dtype() == bool_ ? "or" : "max";
         break;
     }
 
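Dropping the trailing underscore matters because op_name is concatenated straight into the kernel name, which has to match the names instantiated in the Metal source (see the instantiate_reduce_ops change above). A sketch of how such a kernel name is assembled; the base prefix and type suffix are illustrative stand-ins for what the dispatcher and type_to_name(in) produce:

#include <iostream>
#include <string>

int main() {
  std::string op_name = "min";          // was "min_" before this fix
  std::string base = "colSmall";        // illustrative dispatch-specific prefix
  std::string type_suffix = "float32";  // illustrative stand-in for type_to_name(in)

  // Mirrors kernel_name = base + "_reduce_" + op_name + type_to_name(in);
  std::string kernel_name = base + "_reduce_" + op_name + type_suffix;
  std::cout << kernel_name << std::endl;  // prints: colSmall_reduce_minfloat32
}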
@@ -88,9 +88,13 @@ array eval_impl(std::vector<array> outputs, bool async) {
               " transformations like compile or vmap is not allowed.");
         }
         throw std::runtime_error(
-            "[eval] Attempting to eval an array without a primitive. "
-            "This may be a bug, please file an issue here: "
-            " https://github.com/ml-explore/mlx/issues.");
+            "[eval] Attempting to eval an array without a primitive.\n"
+            "If you are compiling a function, make sure all the inputs "
+            "and outputs are captured:\n"
+            "https://ml-explore.github.io/mlx/build/html/usage/compile.html#pure-functions.\n"
+            "If you are not using compile, this may be a bug. "
+            "Please file an issue here:\n"
+            "https://github.com/ml-explore/mlx/issues.");
       }
       if (a.primitive().stream() != in.primitive().stream()) {
         needs_signal.insert(in.id());
@@ -704,6 +704,13 @@ class TestCompile(mlx_tests.MLXTestCase):
         self.assertEqual(y1.item(), y2.item())
         self.assertEqual(y1.item(), 6)
 
+    def test_inf_constant(self):
+        def fn(x):
+            return mx.where(mx.isinf(x), 0, 1)
+
+        x = mx.array([0, float("inf"), 1], dtype=mx.bfloat16)
+        self.assertTrue(mx.array_equal(mx.compile(fn)(x), fn(x)))
+
 
 if __name__ == "__main__":
     unittest.main()