mirror of https://github.com/ml-explore/mlx.git
Fix logsumexp edge case (#740)
* fix logsumexp
* fix inf constant
* also fix power grad
* fix ternary dispatch
parent ac02cf33bd
commit e6418781ab
@@ -7,6 +7,10 @@
namespace mlx::core::detail {

namespace {
constexpr float inf = std::numeric_limits<float>::infinity();
} // namespace

typedef union {
  int i;
  float f;
@@ -3,11 +3,13 @@ set(
  ${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
  ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
  ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
  ${CMAKE_CURRENT_SOURCE_DIR}/binary.h
  ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
  ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
  ${CMAKE_CURRENT_SOURCE_DIR}/erf.h
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
  ${CMAKE_CURRENT_SOURCE_DIR}/unary.h
  ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
)
@@ -2,16 +2,6 @@

#include "mlx/backend/metal/kernels/binary.h"

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_s2s(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[0], b[0]);
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_ss(
    device const T* a,
@@ -7,6 +7,16 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/ternary.h"

template <typename T, typename Op>
[[kernel]] void ternary_op_v(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    uint index [[thread_position_in_grid]]) {
  d[index] = Op()(a[index], b[index], c[index]);
}

template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd1(
    device const bool* a,
@@ -94,6 +104,15 @@ template <typename T, typename Op>
  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}

#define instantiate_ternary_v(name, type, op) \
  template [[host_name(name)]] \
  [[kernel]] void ternary_op_v<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
      device type* d, \
      uint index [[thread_position_in_grid]]); \

#define instantiate_ternary_g(name, type, op) \
  template [[host_name(name)]] \
  [[kernel]] void ternary_op_g<type, op>( \
@@ -160,14 +179,10 @@ template <typename T, typename Op>
  instantiate_ternary_g_dim(name, type, op, 5) \

#define instantiate_ternary_all(name, tname, type, op) \
  instantiate_ternary_v("v" #name #tname, type, op) \
  instantiate_ternary_g("g" #name #tname, type, op) \
  instantiate_ternary_g_nd("g" #name #tname, type, op) \

#define instantiate_ternary_float(name, op) \
  instantiate_ternary_all(name, float16, half, op) \
  instantiate_ternary_all(name, float32, float, op) \
  instantiate_ternary_all(name, bfloat16, bfloat16_t, op)

#define instantiate_ternary_types(name, op) \
  instantiate_ternary_all(name, bool_, bool, op) \
  instantiate_ternary_all(name, uint8, uint8_t, op) \
@@ -178,7 +193,9 @@ template <typename T, typename Op>
  instantiate_ternary_all(name, int16, int16_t, op) \
  instantiate_ternary_all(name, int32, int32_t, op) \
  instantiate_ternary_all(name, int64, int64_t, op) \
  instantiate_ternary_all(name, float16, half, op) \
  instantiate_ternary_all(name, float32, float, op) \
  instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \
  instantiate_ternary_all(name, complex64, complex64_t, op) \
  instantiate_ternary_float(name, op)

instantiate_ternary_types(select, Select)
@@ -9,6 +9,10 @@
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/utils.h"

namespace {
constant float inf = metal::numeric_limits<float>::infinity();
}

struct Abs {
  template <typename T>
  T operator()(T x) {
@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.

#include <algorithm>
#include <cassert>
#include <numeric>
@@ -240,11 +239,15 @@ void ternary_op(
  auto& strides_out = strides[3];

  std::ostringstream kname;
  kname << "g";
  kname << op << type_to_name(b);
  if (topt == TernaryOpType::General &&
      shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
    kname << "_" << shape.size();
  if (topt == TernaryOpType::General) {
    kname << "g";
    kname << op << type_to_name(b);
    if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
      kname << "_" << shape.size();
    }
  } else {
    kname << "v";
    kname << op << type_to_name(b);
  }

  auto& s = out.primitive().stream();
@@ -257,44 +260,46 @@ void ternary_op(
  set_array_buffer(compute_encoder, c, 2);
  set_array_buffer(compute_encoder, out, 3);

  auto ndim = shape.size();
  if (ndim > 3) {
    compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
    compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
    compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
    compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
  if (topt == TernaryOpType::General) {
    auto ndim = shape.size();
    if (ndim > 3) {
      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
      compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);

    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
      compute_encoder->setBytes(&ndim, sizeof(int), 8);
      if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
        compute_encoder->setBytes(&ndim, sizeof(int), 8);
      }
    } else {
      // The shape is implicit in the grid for <= 3D
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
      compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
    }
  } else if (ndim > 0) {
    // The shape is implicit in the grid for <= 3D
    compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
    compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
    compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
  } else {
    // For 0-dim we still need to bind something to these buffers since the
    // current ternary kernels always access the strides.
    size_t dummy_stride = 0;
    int dummy_shape = 0;
    compute_encoder->setBytes(&dummy_shape, sizeof(int), 4);
    compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 5);
    compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 6);
    compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 7);
    compute_encoder->setBytes(&ndim, sizeof(int), 8);
  }

  // Launch up to 3D grid of threads
  size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
  size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
  size_t rest = out.size() / (dim0 * dim1);
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (thread_group_size != 1024) {
    throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
    // Launch up to 3D grid of threads
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    size_t rest = out.size() / (dim0 * dim1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
    }
    MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  } else {
    // Launch a 1D grid of threads
    size_t nthreads = out.data_size();
    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size > nthreads) {
      thread_group_size = nthreads;
    }
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
  MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
  MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

void unary_op(
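Note: the Select kernels instantiated above back the Select primitive used by mx.where. A rough Python sketch of the kernel-name selection after this change (the helper name, the cutoff value of 5 for MAX_BINARY_SPECIALIZED_DIMS, and the example strings are assumptions for illustration; the logic mirrors the C++ above):

def ternary_kernel_name(op, dtype_name, is_general, ndim, max_specialized_dims=5):
    # Strided/broadcast (General) inputs use the "g" kernels, specialized by
    # rank when it is small enough; contiguous inputs now use the "v" kernel.
    if is_general:
        name = "g" + op + dtype_name
        if ndim <= max_specialized_dims:
            name += "_" + str(ndim)
        return name
    return "v" + op + dtype_name

# e.g. ternary_kernel_name("select", "float32", False, 1) -> "vselectfloat32",
# while a 2-D general case would request "gselectfloat32_2" (illustrative).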
@@ -1755,7 +1755,11 @@ array logsumexp(
    StreamOrDevice s /* = {}*/) {
  auto maxval = stop_gradient(max(a, axes, true, s));
  auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
  return add(out, reshape(maxval, out.shape(), s), s);
  out = add(out, reshape(maxval, out.shape(), s), s);
  if (!keepdims) {
    maxval = squeeze(maxval, axes, s);
  }
  return where(isinf(maxval, s), maxval, out, s);
}

array logsumexp(
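Note: the new where(isinf(maxval, s), maxval, out, s) guards the case where the running max is itself infinite, since the usual max-shift trick turns an all -inf input into (-inf) - (-inf) = nan inside exp. A minimal sketch of the edge case, assuming the Python mlx API (mx.logsumexp and friends):

import mlx.core as mx

x = mx.array([float("-inf"), float("-inf")])

# Naive max-shifted log-sum-exp: (-inf) - (-inf) is nan, so the result is nan.
naive = mx.log(mx.sum(mx.exp(x - mx.max(x)))) + mx.max(x)

# With this change, logsumexp short-circuits to the (infinite) max instead.
print(mx.logsumexp(x))  # expected -inf, matching the C++ test further below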
@@ -2043,7 +2043,10 @@ std::vector<array> Power::vjp(
        primals[1],
        stream()));
  } else {
    vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
    auto& exp = outputs[0];
    auto exp_vjp = multiply(log(primals[0], stream()), outputs[0], stream());
    // 0 * log 0 -> 0
    vjps.push_back(where(exp, exp_vjp, array(0.0f, exp.dtype()), stream()));
  }
  vjps.back() = multiply(cotangents[0], vjps.back(), stream());
}
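Note: the vjp of z = x**y with respect to the exponent is z * log(x). When the base is 0 the output z is 0 and log(0) is -inf, so the unmasked product is 0 * (-inf) = nan; the where() above maps that case to 0. A minimal sketch, assuming the Python mx.grad API (the expected value is illustrative):

import mlx.core as mx

# Gradient with respect to the exponent of 0 ** y at y = 2: the output is 0,
# so with this change the 0 * log(0) term is treated as 0 instead of nan.
g = mx.grad(lambda y: mx.array(0.0) ** y)(mx.array(2.0))
print(g)  # expected 0.0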
@@ -415,6 +415,14 @@ class TestAutograd(mlx_tests.MLXTestCase):
        _, vjps = mx.vjp(func, (arr,), (cotan,))
        self.assertEqual(vjps[0].item(), 8.0)

    def test_power_grad(self):
        def fun(x, y):
            res = x - y
            return res**x

        grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
        self.assertEqual(grad.item(), 1.0)


if __name__ == "__main__":
    unittest.main()
@@ -539,6 +539,15 @@ class TestCompile(mlx_tests.MLXTestCase):
        z = fun(mx.array(1), "two")
        self.assertEqual(z.item(), 3)

    def test_compile_inf(self):
        @mx.compile
        def fun(x):
            return mx.isinf(x + 2)

        out = fun(mx.array([0.0]))
        self.assertEqual(out.item(), False)


if __name__ == "__main__":
    unittest.main()
@@ -66,13 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase):
    def test_save_and_load_safetensors(self):
        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        test_file = os.path.join(self.test_dir, "test.safetensors")
        with self.assertRaises(Exception):
            mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0})
            mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})

        mx.save_safetensors(
            "test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
            test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
        )
        res = mx.load("test.safetensors", return_metadata=True)
        res = mx.load(test_file, return_metadata=True)
        self.assertEqual(len(res), 2)
        self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
@@ -791,13 +791,13 @@ TEST_CASE("test reduction ops") {
  constexpr float inf = std::numeric_limits<float>::infinity();

  x = array({-inf, -inf});
  WARN_EQ(logsumexp(x).item<float>(), -inf);
  CHECK_EQ(logsumexp(x).item<float>(), -inf);

  x = array({0.0f, -inf});
  CHECK_EQ(logsumexp(x).item<float>(), 0.0f);

  x = array({0.0f, inf});
  WARN_EQ(logsumexp(x).item<float>(), inf);
  CHECK_EQ(logsumexp(x).item<float>(), inf);

  x = reshape(arange(6, float32), {2, 3});
@@ -2819,4 +2819,4 @@ TEST_CASE("test atleast_3d") {
  out = atleast_3d(x);
  CHECK_EQ(out.ndim(), 3);
  CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
}
}