Fix a couple of bugs (#1161)

* fix jit reduce for RMS norm

* make strides a single buffer

* better eval error message

* fix compiling with inf and bf16

* fix cpu compile with bf16
This commit is contained in:
Awni Hannun
2024-05-28 15:18:18 -07:00
committed by GitHub
parent a87ef5bfc1
commit e7a2a3dcd1
9 changed files with 59 additions and 27 deletions

View File

@@ -259,11 +259,14 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& op_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
@@ -273,7 +276,7 @@ MTL::ComputePipelineState* get_reduce_kernel(
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
op_type);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);