-
Notifications
You must be signed in to change notification settings - Fork 149
Determinant of factorized matrices #1785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,7 @@ | |
| node_rewriter, | ||
| ) | ||
| from pytensor.graph.rewriting.unify import OpPattern | ||
| from pytensor.scalar.basic import Abs, Log, Mul, Sign | ||
| from pytensor.scalar.basic import Abs, Exp, Log, Mul, Sign, Sqr | ||
| from pytensor.tensor.basic import ( | ||
| AllocDiag, | ||
| ExtractDiag, | ||
|
|
@@ -23,6 +23,7 @@ | |
| concatenate, | ||
| diag, | ||
| diagonal, | ||
| ones, | ||
| ) | ||
| from pytensor.tensor.blockwise import Blockwise | ||
| from pytensor.tensor.elemwise import DimShuffle, Elemwise | ||
|
|
@@ -46,9 +47,12 @@ | |
| ) | ||
| from pytensor.tensor.rewriting.blockwise import blockwise_of | ||
| from pytensor.tensor.slinalg import ( | ||
| LU, | ||
| QR, | ||
| BlockDiagonal, | ||
| Cholesky, | ||
| CholeskySolve, | ||
| LUFactor, | ||
| Solve, | ||
| SolveBase, | ||
| SolveTriangular, | ||
|
|
@@ -65,6 +69,10 @@ | |
| MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv) | ||
|
|
||
|
|
||
| def matrix_diagonal_product(x): | ||
| return pt.prod(diagonal(x, axis1=-2, axis2=-1), axis=-1) | ||
|
|
||
|
|
||
| def is_matrix_transpose(x: TensorVariable) -> bool: | ||
| """Check if a variable corresponds to a transpose of the last two axes""" | ||
| node = x.owner | ||
|
|
@@ -281,41 +289,39 @@ def cholesky_ldotlt(fgraph, node): | |
|
|
||
| @register_stabilize | ||
| @register_specialize | ||
| @node_rewriter([det]) | ||
| def local_det_chol(fgraph, node): | ||
| """ | ||
| If we have det(X) and there is already an L=cholesky(X) | ||
| floating around, then we can use prod(diag(L)) to get the determinant. | ||
| @node_rewriter([log]) | ||
| def local_log_prod_to_sum_log(fgraph, node): | ||
| """Rewrite log(prod(x)) as sum(log(x)), when x is known to be positive.""" | ||
| [p] = node.inputs | ||
| p_node = p.owner | ||
|
|
||
| """ | ||
| (x,) = node.inputs | ||
| for cl, xpos in fgraph.clients[x]: | ||
| if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky): | ||
| L = cl.outputs[0] | ||
| return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)] | ||
| if p_node is None: | ||
| return None | ||
|
|
||
| p_op = p_node.op | ||
|
|
||
| @register_canonicalize | ||
| @register_stabilize | ||
| @register_specialize | ||
| @node_rewriter([log]) | ||
| def local_log_prod_sqr(fgraph, node): | ||
| """ | ||
| This utilizes a boolean `positive` tag on matrices. | ||
| """ | ||
| (x,) = node.inputs | ||
| if x.owner and isinstance(x.owner.op, Prod): | ||
| # we cannot always make this substitution because | ||
| # the prod might include negative terms | ||
| p = x.owner.inputs[0] | ||
| if isinstance(p_op, Prod): | ||
| x = p_node.inputs[0] | ||
|
|
||
| # p is the matrix we're reducing with prod | ||
| if getattr(p.tag, "positive", None) is True: | ||
| return [log(p).sum(axis=x.owner.op.axis)] | ||
| # TODO: The product of diagonals of a Cholesky(A) are also strictly positive | ||
| if ( | ||
| x.owner is not None | ||
| and isinstance(x.owner.op, Elemwise) | ||
| and isinstance(x.owner.op.scalar_op, Abs | Sqr | Exp) | ||
| ) or getattr(x.tag, "positive", False): | ||
| return [log(x).sum(axis=p_node.op.axis)] | ||
|
|
||
| # TODO: have a reduction like prod and sum that simply | ||
| # returns the sign of the prod multiplication. | ||
|
|
||
| # Special case for log(abs(prod(x))) -> sum(log(abs(x))) that shows up in slogdet | ||
| elif isinstance(p_op, Elemwise) and isinstance(p_op.scalar_op, Abs): | ||
| [p] = p_node.inputs | ||
| p_node = p.owner | ||
| if p_node is not None and isinstance(p_node.op, Prod): | ||
| [x] = p.owner.inputs | ||
| return [log(abs(x)).sum(axis=p_node.op.axis)] | ||
|
|
||
|
|
||
| @register_specialize | ||
| @node_rewriter([blockwise_of(MatrixInverse | Cholesky | MatrixPinv)]) | ||
|
|
@@ -442,6 +448,127 @@ def _find_diag_from_eye_mul(potential_mul_input): | |
| return eye_input, non_eye_inputs | ||
|
|
||
|
|
||
| @register_stabilize | ||
| @register_specialize | ||
| @node_rewriter([det]) | ||
| def det_of_matrix_factorized_elsewhere(fgraph, node): | ||
| """ | ||
| If we have det(X) or abs(det(X)) and there is already a nice decomposition(X) floating around, | ||
| use it to compute it more cheaply | ||
|
|
||
| """ | ||
| [det] = node.outputs | ||
| [x] = node.inputs | ||
|
|
||
| only_used_by_abs = all( | ||
| isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs) | ||
| for client, _ in fgraph.clients[det] | ||
| ) | ||
|
|
||
| new_det = None | ||
| for client, _ in fgraph.clients[x]: | ||
| core_op = client.op.core_op if isinstance(client.op, Blockwise) else client.op | ||
| match core_op: | ||
| case Cholesky(): | ||
| L = client.outputs[0] | ||
| new_det = matrix_diagonal_product(L) ** 2 | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO: Add the positive tag here. Possibly also rewrite for log(x ** 2) -> log(x) * 2, when we know x is positive |
||
| case LU(): | ||
| U = client.outputs[-1] | ||
| new_det = matrix_diagonal_product(U) | ||
| case LUFactor(): | ||
| LU_packed = client.outputs[0] | ||
| new_det = matrix_diagonal_product(LU_packed) | ||
| case _: | ||
| if not only_used_by_abs: | ||
| continue | ||
| match core_op: | ||
| case SVD(): | ||
| lmbda = ( | ||
| client.outputs[1] | ||
| if core_op.compute_uv | ||
| else client.outputs[0] | ||
| ) | ||
| new_det = prod(lmbda, axis=-1) | ||
| case QR(): | ||
| R = client.outputs[-1] | ||
| # if mode == "economic", R may not be square and this rewrite could hide a shape error | ||
| # That's why it's tagged as `shape_unsafe` | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This rewrite isn't tagged shape_unsafe |
||
| new_det = matrix_diagonal_product(R) | ||
|
|
||
| if new_det is not None: | ||
| # found a match | ||
| break | ||
| else: # no-break (i.e., no-match) | ||
| return None | ||
|
|
||
| [det] = node.outputs | ||
| copy_stack_trace(det, new_det) | ||
| return [new_det] | ||
|
|
||
|
|
||
| @register_stabilize("shape_unsafe") | ||
| @register_specialize("shape_unsafe") | ||
| @node_rewriter(tracks=[det]) | ||
| def det_of_factorized_matrix(fgraph, node): | ||
| """Introduce special forms for det(decomposition(X)). | ||
|
|
||
| Some cases are only known up to a sign change such as det(QR(X)), | ||
| and are only introduced if the determinant is only ever used inside an abs | ||
| """ | ||
| [det] = node.outputs | ||
| [x] = node.inputs | ||
|
|
||
| only_used_by_abs = all( | ||
| isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs) | ||
| for client, _ in fgraph.clients[det] | ||
| ) | ||
|
|
||
| x_node = x.owner | ||
| if x_node is None: | ||
| return None | ||
|
|
||
| x_op = x_node.op | ||
| core_op = x_op.core_op if isinstance(x_op, Blockwise) else x_op | ||
|
|
||
| new_det = None | ||
| match core_op: | ||
| case Cholesky(): | ||
| new_det = matrix_diagonal_product(x) | ||
| case LU(): | ||
| if x is x_node.outputs[-2]: | ||
| # x is L | ||
| new_det = ones(x.shape[:-2], dtype=det.dtype) | ||
| elif x is x_node.outputs[-1]: | ||
| # x is U | ||
| new_det = matrix_diagonal_product(x) | ||
| case SVD(): | ||
| if not core_op.compute_uv or x is x_node.outputs[1]: | ||
| # x is lambda | ||
| new_det = prod(x, axis=-1) | ||
| elif only_used_by_abs: | ||
| # x is either U or Vt and only ever used inside an abs | ||
| new_det = ones(x.shape[:-2], dtype=det.dtype) | ||
| case QR(): | ||
| # if mode == "economic", Q/R may not be square and this rewrite could hide a shape error | ||
| # That's why it's tagged as `shape_unsafe` | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it worth handling this case in a separate rewrite so as to not tag the others as shape_unsafe (since they aren't)
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've thought about that as well. That's almost always the case, only a subset of the matching cases is actually unsafe in a rewrite. OTOH the tag is mostly a debug thing, if you're getting an odd result or shape error you may want to exclude to see if it goes away or the error is more obvious. You never really want to exclude them at the end of the day |
||
| if x is x_node.outputs[-1]: | ||
| # x is R | ||
| new_det = matrix_diagonal_product(x) | ||
| elif ( | ||
| only_used_by_abs | ||
| and core_op.mode in ("economic", "full") | ||
| and x is x_node.outputs[0] | ||
| ): | ||
| # x is Q and it's only ever used inside an abs | ||
| new_det = ones(x.shape[:-2], dtype=det.dtype) | ||
|
|
||
| if new_det is None: | ||
| return None | ||
|
|
||
| copy_stack_trace(det, new_det) | ||
| return [new_det] | ||
|
|
||
|
|
||
| @register_canonicalize("shape_unsafe") | ||
| @register_stabilize("shape_unsafe") | ||
| @node_rewriter([det]) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any Op that that maps (-1, 1) to the same value is actually fine, At the very least should include square as well