mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
add zero vjps for bitwise ops and gather w.r.t. index (#1256)
This commit is contained in:
parent
20bb301195
commit
bdb36c9a63
@ -523,7 +523,7 @@ std::vector<array> AsType::vjp(
|
||||
const std::vector<array>&) {
|
||||
if (cotangents[0].dtype() != dtype_) {
|
||||
throw std::invalid_argument(
|
||||
"[astype] Type of cotangentsgent does not much primal output type.");
|
||||
"[astype] Type of cotangents does not match primal output type.");
|
||||
}
|
||||
return {astype(cotangents[0], primals[0].dtype(), stream())};
|
||||
}
|
||||
@ -629,6 +629,26 @@ std::pair<std::vector<array>, std::vector<int>> BitwiseBinary::vmap(
|
||||
{to_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> BitwiseBinary::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 2);
|
||||
std::vector<array> vjps = {zeros_like(tangents[0], stream())};
|
||||
if (argnums.size() > 1) {
|
||||
vjps.push_back(vjps.back());
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
std::vector<array> BitwiseBinary::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
std::vector<array> Broadcast::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
@ -1598,13 +1618,19 @@ std::vector<array> Gather::vjp(
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
if (argnums.size() > 1 || argnums[0] != 0) {
|
||||
throw std::invalid_argument(
|
||||
"[gather] Cannot calculate VJP with respect to indices.");
|
||||
std::vector<array> vjps;
|
||||
for (int argnum : argnums) {
|
||||
if (argnum > 0) {
|
||||
// Grads w.r.t. indices are zero
|
||||
vjps.push_back(
|
||||
zeros(primals[argnum].shape(), primals[argnum].dtype(), stream()));
|
||||
} else {
|
||||
auto src = zeros_like(primals[0], stream());
|
||||
std::vector<array> inds(primals.begin() + 1, primals.end());
|
||||
vjps.push_back(scatter_add(src, inds, cotangents[0], axes_, stream()));
|
||||
}
|
||||
}
|
||||
auto src = zeros_like(primals[0], stream());
|
||||
std::vector<array> inds(primals.begin() + 1, primals.end());
|
||||
return {scatter_add(src, inds, cotangents[0], axes_, stream())};
|
||||
return vjps;
|
||||
}
|
||||
|
||||
std::vector<array> Gather::jvp(
|
||||
|
@ -471,6 +471,7 @@ class BitwiseBinary : public UnaryPrimitive {
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
void print(std::ostream& os) override;
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
@ -1043,6 +1043,19 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
a_mlx = mx.array(a_np)
|
||||
self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0])))
|
||||
|
||||
def test_indexing_grad(self):
    # Take gradients of an indexing expression with respect to both the
    # source array and the (float-typed) index array.
    source = mx.array([[1, 2], [3, 4]]).astype(mx.float32)
    positions = mx.array([0, 1, 0]).astype(mx.float32)

    def gather_sum(values, where):
        # The float indices are cast to int32 before being used to index.
        picked = values[where.astype(mx.int32)]
        return picked.sum()

    grads = mx.grad(gather_sum, argnums=(0, 1))(source, positions)
    grad_source, grad_positions = grads

    # Row 0 is selected twice and row 1 once.
    want_source = mx.array([[2, 2], [1, 1]])
    want_positions = mx.zeros(positions.shape)

    self.assertTrue(mx.array_equal(grad_source, want_source))
    self.assertTrue(mx.array_equal(grad_positions, want_positions))
|
||||
|
||||
def test_setitem(self):
|
||||
a = mx.array(0)
|
||||
a[None] = 1
|
||||
|
@ -2373,6 +2373,21 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(c.shape, (3, 2, 5))
|
||||
self.assertTrue(mx.array_equal(c, mx.ones((3, 2, 5), dtype=mx.bool_)))
|
||||
|
||||
def test_bitwise_grad(self):
    # Random integer operands and cotangent, each of shape (4, 3), drawn in
    # the order: first operand, second operand, cotangent.
    dims = (4, 3)
    lhs = mx.array(np.random.randint(0, 10, size=dims))
    rhs = mx.array(np.random.randint(0, 10, size=dims))
    seed_grad = mx.array(np.random.randint(0, 10, size=dims))

    def bitwise_and(p, q):
        return p.astype(mx.int32) & q.astype(mx.int32)

    _, grads = mx.vjp(bitwise_and, [lhs, rhs], [seed_grad])

    # Every returned vjp must be all zeros.
    for grad in grads:
        self.assertFalse(np.any(np.array(grad)))
|
||||
|
||||
def test_conjugate(self):
|
||||
shape = (3, 5, 7)
|
||||
a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
|
||||
|
Loading…
Reference in New Issue
Block a user