mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
2 Commits
eff0e31f00
...
193cdcd81a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
193cdcd81a | ||
|
|
d8ceae7b77 |
@@ -332,9 +332,17 @@ bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
|
||||
for (const auto& node : nodes) {
|
||||
cudaGraphNodeType type;
|
||||
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
||||
if (type != cudaGraphNodeTypeKernel) {
|
||||
if (type == cudaGraphNodeTypeGraph) {
|
||||
// Try to be updatable for a structure like graph -> graph -> kernel
|
||||
if (num_nodes > 1) {
|
||||
return false;
|
||||
}
|
||||
cudaGraph_t child;
|
||||
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
||||
return is_graph_updatable(child, cluster_dim_x);
|
||||
} else if (type != cudaGraphNodeTypeKernel) {
|
||||
return false;
|
||||
} else {
|
||||
cudaLaunchAttributeValue cluster_dim;
|
||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||
@@ -348,6 +356,7 @@ bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
|
||||
}
|
||||
cluster_dim_x = cluster_dim.clusterDim.x;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -362,7 +371,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
}
|
||||
cudaGraphNode_t node;
|
||||
int cluster_dim_x = 0;
|
||||
is_graph_updatable_ = is_graph_updatable(child, cluster_dim_x);
|
||||
is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
|
||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||
insert_graph_dependencies(
|
||||
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
|
||||
|
||||
@@ -3846,6 +3846,62 @@ std::vector<array> Reduce::vjp(
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> Reduce::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
auto in = primals[0];
|
||||
auto s = stream();
|
||||
|
||||
auto grad_op = [&s, reduce_type = reduce_type_](
|
||||
const array& x, const array& tan, int axis) {
|
||||
if (reduce_type == Reduce::Min) {
|
||||
auto idx = argmin(x, axis, true, s);
|
||||
return take_along_axis(tan, idx, axis, s);
|
||||
} else if (reduce_type == Reduce::Max) {
|
||||
auto idx = argmax(x, axis, true, s);
|
||||
return take_along_axis(tan, idx, axis, s);
|
||||
} else {
|
||||
auto p1 = cumprod(x, axis, /*reverse=*/false, /*inclusive=*/false, s);
|
||||
auto p2 = cumprod(x, axis, /*reverse=*/true, /*inclusive=*/false, s);
|
||||
auto out = multiply(multiply(p1, p2, s), tan, s);
|
||||
return sum(out, axis, true, s);
|
||||
}
|
||||
};
|
||||
|
||||
auto tan = tangents[0];
|
||||
if (reduce_type_ == Reduce::Sum) {
|
||||
return {sum(tan, axes_, true, s)};
|
||||
} else {
|
||||
if (axes_.size() > 1) {
|
||||
std::vector<int> transpose_to;
|
||||
{
|
||||
// Find the transpose needed to move axes_ to the back.
|
||||
int j = 0;
|
||||
for (int i = 0; i < in.ndim(); i++) {
|
||||
if (j < axes_.size() && axes_[j] == i) {
|
||||
j++;
|
||||
} else {
|
||||
transpose_to.push_back(i);
|
||||
}
|
||||
}
|
||||
for (auto ax : axes_) {
|
||||
transpose_to.push_back(ax);
|
||||
}
|
||||
}
|
||||
|
||||
int start_ax = in.ndim() - axes_.size();
|
||||
in = flatten(transpose(in, transpose_to, s), start_ax, -1, s);
|
||||
tan = flatten(transpose(tan, transpose_to, s), start_ax, -1, s);
|
||||
|
||||
auto grad = squeeze(grad_op(in, tan, -1), -1, s);
|
||||
return {expand_dims(grad, axes_, s)};
|
||||
} else {
|
||||
return {grad_op(in, tan, axes_[0])};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
|
||||
@@ -1751,12 +1751,7 @@ class Reduce : public UnaryPrimitive {
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
DEFINE_GRADS();
|
||||
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
|
||||
@@ -798,6 +798,22 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
grad_fn(model)
|
||||
self.assertEqual(model[1].item(), 2.0)
|
||||
|
||||
def test_reduce_jvp(self):
|
||||
a = mx.arange(4)
|
||||
b = mx.array([3, 2, 1, 0])
|
||||
|
||||
out, jout = mx.jvp(mx.sum, primals=(a,), tangents=(b,))
|
||||
self.assertEqual(jout[0].item(), 6)
|
||||
|
||||
out, jout = mx.jvp(mx.prod, primals=(a,), tangents=(b,))
|
||||
self.assertEqual(jout[0].item(), 18)
|
||||
|
||||
out, jout = mx.jvp(mx.min, primals=(a,), tangents=(b,))
|
||||
self.assertEqual(jout[0].item(), 3)
|
||||
|
||||
out, jout = mx.jvp(mx.max, primals=(a,), tangents=(b,))
|
||||
self.assertEqual(jout[0].item(), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
Reference in New Issue
Block a user