Source code for qics.cones.entropy.quantcondentr

# Copyright (c) 2024, Kerry He, James Saunderson, and Hamza Fawzi

# This Python package QICS is licensed under the MIT license; see LICENSE.md
# file in the root directory or at https://github.com/kerry-he/qics

import numpy as np
import scipy as sp

from qics._utils.gradient import D1_log, D2_log, scnd_frechet
from qics._utils.linalg import (
    cho_fact,
    cho_solve,
    congr_multi,
    dense_dot_x,
    inp,
    x_dot_dense,
)
from qics.cones.base import Cone
from qics.quantum import i_kr, i_kr_multi, p_tr, p_tr_multi
from qics.vectorize import get_full_to_compact_op, vec_to_mat


[docs] class QuantCondEntr(Cone): r"""A class representing a quantum conditional entropy cone defined on :math:`k` subsystems, where the :math:`i`-th subsystem has dimension :math:`n_i`, and the :math:`i`-th subsystem is being traced out, i.e., .. math:: \mathcal{QCE}_{\{n_i\}, j} = \text{cl}\{ (t, X) \in \mathbb{R} \times \mathbb{H}^{\Pi_in_i}_{++} : t \geq -S(X) + S(\text{tr}_i(X)) \}, where .. math:: S(X) = -\text{tr}[X \log(X)], is the quantum (von Neumann) entropy, and :math:`\text{tr}_i` is the partial trace on the :math:`i`-th subsystem. Parameters ---------- dims : :obj:`tuple` of :obj:`int` or :obj:`list` of :obj:`int` List of dimensions :math:`\{ n_i \}_{i=0}^{k-1}` of the :math:`k` subsystems. sys : :obj:`int` or :obj:`tuple` of :obj:`int` or :obj:`list` of :obj:`int` Which systems are being traced out by the partial trace. Can define multiple subsystems to trace out. iscomplex : :obj:`bool` Whether the matrix :math:`X` is defined over :math:`\mathbb{H}^n` (``True``), or restricted to :math:`\mathbb{S}^n` (``False``). The default is ``False``. See also -------- QuantRelEntr : Quantum relative entropy cone Notes ----- The quantum conditional entropy can also be modelled by the quantum relative entropy by noting the identity .. math:: S(X \| \mathbb{I} \otimes \text{tr}_1(X)) = -S(X) + S(\text{tr}_1(X)). However, the cone oracles for the quantum conditional entropy cone are much more efficient than those for the quantum relative entropy cone, so it is recommended to use the quantum conditional entropy cone where possible. """ def __init__(self, dims, sys, iscomplex=False): if isinstance(sys, int): sys = [sys] if isinstance(sys, tuple) or isinstance(sys, set): sys = list(sys) assert all( [sys_k < len(dims) for sys_k in sys] ), "Invalid subsystems specified, exceeds total number of dimensions provided." self.dims = dims self.sys = sys self.iscomplex = iscomplex self.N = np.prod(dims) # Total dim. of multipartite system self.m = np.prod([dims[k] for k in sys]) # Dim. of system traced out self.n = self.N // self.m # Dim. of system after partial trace self.nu = 1 + self.N # Barrier parameter if iscomplex: self.vn = self.n * self.n self.dim = [1, 2 * self.N * self.N] self.type = ["r", "h"] self.dtype = np.complex128 else: self.vn = self.n * (self.n + 1) // 2 self.dim = [1, self.N * self.N] self.type = ["r", "s"] self.dtype = np.float64 # Update flags self.feas_updated = False self.grad_updated = False self.hess_aux_updated = False self.invhess_aux_updated = False self.invhess_aux_aux_updated = False self.congr_aux_updated = False self.dder3_aux_updated = False return def get_iscomplex(self): return self.iscomplex def get_init_point(self, out): a = self.N * np.log(self.m) ** 2 + 1 b = -(2.0 + (1 + self.N) * self.N * np.log(self.m) ** 2) t0 = np.sqrt((-b - np.sqrt(b * b - 4 * a)) / (2 * a)) x0 = np.sqrt((1 + self.N - t0 * t0) / self.N) point = [np.array([[t0]]), np.eye(self.N, dtype=self.dtype) * x0] self.set_point(point, point) out[0][:] = point[0] out[1][:] = point[1] return out def get_feas(self): if self.feas_updated: return self.feas self.feas_updated = True (self.t, self.X) = self.primal self.Y = p_tr(self.X, self.dims, self.sys) # Check that X and Y are positive definite # Note that Y = pTr(X) should be positive definite if X is, # but check just to be sure that we can safely take logarithms of Y. self.Dx, self.Ux = np.linalg.eigh(self.X) self.Dy, self.Uy = np.linalg.eigh(self.Y) if any(self.Dx <= 0) or any(self.Dy <= 0): self.feas = False return self.feas # Check that t > -S(X) + S(pTr(X)) self.log_Dx = np.log(self.Dx) self.log_Dy = np.log(self.Dy) log_X = (self.Ux * self.log_Dx) @ self.Ux.conj().T log_X = (log_X + log_X.conj().T) * 0.5 log_Y = (self.Uy * self.log_Dy) @ self.Uy.conj().T log_Y = (log_Y + log_Y.conj().T) * 0.5 self.log_XY = log_X - i_kr(log_Y, self.dims, self.sys) self.z = self.t[0, 0] - inp(self.X, self.log_XY) self.feas = self.z > 0 return self.feas def get_val(self): return -np.log(self.z) - np.sum(self.log_Dx) def update_grad(self): assert self.feas_updated assert not self.grad_updated # Compute gradients of quantum conditional entropy # DPhi(X) = log(X) - I ⊗ pTr(X) self.DPhi = self.log_XY # Compute X^-1 inv_Dx = np.reciprocal(self.Dx) inv_X_rt2 = self.Ux * np.sqrt(inv_Dx) self.inv_X = inv_X_rt2 @ inv_X_rt2.conj().T # Compute gradient of barrier function self.zi = np.reciprocal(self.z) self.grad = [-self.zi, self.zi * self.DPhi - self.inv_X] self.grad_updated = True def hess_prod_ip(self, out, H): assert self.grad_updated if not self.hess_aux_updated: self.update_hessprod_aux() (Ht, Hx) = H Hy = p_tr(Hx, self.dims, self.sys) UxHxUx = self.Ux.conj().T @ Hx @ self.Ux UyHyUy = self.Uy.conj().T @ Hy @ self.Uy # Hessian products for quantum conditional entropy # D2Phi(X)[Hx] = Ux [log^[1](Dx) .* (Ux' Hx Ux)] Ux' # - I ⊗ [Uy [log^[1](Dy) .* (Uy' pTr(Hx) Uy)] Uy'] D2PhiH = self.Ux @ (self.D1x_log * UxHxUx) @ self.Ux.conj().T work = self.Uy @ (self.D1y_log * UyHyUy) @ self.Uy.conj().T D2PhiH -= i_kr(work, self.dims, self.sys) # ====================================================================== # Hessian products with respect to t # ====================================================================== # D2_t F(t, X)[Ht, Hx] = (Ht - D_X Phi(X)[Hx]) / z^2 out[0][:] = (Ht - inp(self.DPhi, Hx)) * self.zi2 # ====================================================================== # Hessian products with respect to X # ====================================================================== # D2_X F(t, X)[Ht, Hx] = -D2_t F(t, X)[Ht, Hx] * DPhi(X) # + D2Phi(X)[Hx] / z + X^-1 Hx X^-1 out_X = -out[0] * self.DPhi out_X += self.zi * D2PhiH out_X += self.inv_X @ Hx @ self.inv_X out[1][:] = (out_X + out_X.conj().T) * 0.5 return out def hess_congr(self, A): assert self.grad_updated if not self.hess_aux_updated: self.update_hessprod_aux() if not self.congr_aux_updated: self.congr_aux(A) p = A.shape[0] lhs = np.zeros((p, sum(self.dim))) Work0, Work1 = self.Work0, self.Work1 Work2, Work3 = self.Work2, self.Work3 # ====================================================================== # Hessian products with respect to t # ====================================================================== # D2_t F(t, X)[Ht, Hx] = (Ht - D_X Phi(X)[Hx]) / z^2 DPhi_vec = self.DPhi.view(np.float64).reshape((-1, 1)) out_t = self.At - (self.Ax_vec @ DPhi_vec).ravel() out_t *= self.zi2 lhs[:, 0] = out_t # ====================================================================== # Hessian products with respect to X # ====================================================================== # Hessian products for quantum conditional entropy # D2Phi(X)[Hx] = Ux [log^[1](Dx) .* (Ux' Hx Ux)] Ux' # - I ⊗ [Uy [log^[1](Dy) .* (Uy' pTr(Hx) Uy)] Uy'] # Compute first term, i.e., Ux [log^[1](Dx) .* (Ux' Hx Ux)] Ux' congr_multi(Work2, self.Ux.conj().T, self.Ax, Work3) Work2 *= self.D1x_comb congr_multi(Work1, self.Ux, Work2, Work3) # Compute second term, i.e., I ⊗ [Uy [log^[1](Dy) .* (Uy' Hy Uy)] Uy'] p_tr_multi(self.work1, self.Ax, self.dims, self.sys) # Compute pTr(Ax) congr_multi(self.work2, self.Uy.conj().T, self.work1, self.work3) self.work2 *= self.D1y_log * self.zi congr_multi(self.work1, self.Uy, self.work2, self.work3) i_kr_multi(Work0, self.work1, self.dims, self.sys) # Subtract the two terms to obtain D2Phi(X)[Hx] Work1 -= Work0 # Hessian product of barrier function # D2_X F(t, X)[Ht, Hx] = -D2_t F(t, X)[Ht, Hx] * DPhi(X) # + D2Phi(X)[Hx] / z + X^-1 Hx X^-1 np.outer(out_t, self.DPhi, out=Work0.reshape(p, -1)) Work1 -= Work0 out_X = Work1.reshape((p, -1)).view(np.float64) lhs[:, 1:] = out_X # Multiply A (H A') return dense_dot_x(lhs, A.T) def invhess_prod_ip(self, out, H): assert self.grad_updated if not self.hess_aux_updated: self.update_hessprod_aux() if not self.invhess_aux_updated: self.update_invhessprod_aux() # Computes inverse Hessian product of the QCE barrier with a single vector (Ht, Hx) # See invhess_congr() for additional comments (Ht, Hx) = H Wx = Hx + Ht * self.DPhi # ====================================================================== # Inverse Hessian products with respect to X # ====================================================================== # X = [D2S(X)^-1 - D2S(X)^-1 pTr' N^-1 pTr D2S(X)^-1] Wx # Apply D2S(X)^-1 = (Ux ⊗ Ux) log^[1](Dx) (Ux' ⊗ Ux') UxWxUx = self.Ux.conj().T @ Wx @ self.Ux Hinv_x = self.Ux @ (self.D1x_comb_inv * UxWxUx) @ self.Ux.conj().T work = -p_tr(Hinv_x, self.dims, self.sys) # Solve linear system N \ ( ... ) # Convert matrices to compact real vectors temp_vec = work.view(np.float64).reshape((-1, 1)) temp_vec = self.F2C_op @ temp_vec # Solve system temp_vec = cho_solve(self.hess_schur_fact, temp_vec) # Expand compact real vectors back into full matrices temp_vec = self.F2C_op.T @ temp_vec work = temp_vec.T.view(self.dtype).reshape((self.n, self.n)) # Apply pTr' = IKr temp = i_kr(work, self.dims, self.sys) # Apply D2S(X)^-1 = (Ux ⊗ Ux) log^[1](Dx) (Ux' ⊗ Ux') temp = self.Ux.conj().T @ temp @ self.Ux out_X = Hinv_x - self.Ux @ (self.D1x_comb_inv * temp) @ self.Ux.conj().T out[1][:] = (out_X + out_X.conj().T) * 0.5 # ====================================================================== # Inverse Hessian products with respect to t # ====================================================================== # z^2 Ht + <DPhi(X), X> out[0][:] = Ht * self.z2 + inp(out_X, self.DPhi) return out def invhess_congr(self, A): assert self.grad_updated if not self.hess_aux_updated: self.update_hessprod_aux() if not self.invhess_aux_updated: self.update_invhessprod_aux() if not self.congr_aux_updated: self.congr_aux(A) # The inverse Hessian product applied on (Ht, Hx) for the QCE barrier is # X = M \ Wx # t = z^2 Ht + <DPhi(X), X> # where Wx = Hx + Ht DPhi(X) # M = 1/z D2S(X) - 1/z pTr' D2S(pTr(X)) pTr + X^-1 ⊗ X^-1 # = (Ux ⊗ Ux) (1/z log + inv)^[1](Dx) (Ux' ⊗ Ux') # - 1/z pTr' (Uy ⊗ Uy) log^[1](Dy) (Uy' ⊗ Uy') pTr # # Treating [pTr' D2S(pTr(X)) pTr] as a low-rank perturbation of D2S(X), # we can solve linear systems with M by using the matrix inversion lemma # X = [D2S(X)^-1 - D2S(X)^-1 pTr' N^-1 pTr D2S(X)^-1] Wx # where # N = 1/z (Uy ⊗ Uy) [log^[1](Dy)]^-1 (Uy' ⊗ Uy') # - pTr (Ux ⊗ Ux) [(1/z log + inv)^[1](Dx)]^-1 (Ux' ⊗ Ux') pTr' p = A.shape[0] lhs = np.zeros((p, sum(self.dim))) Work0, Work1, work1 = self.Work0, self.Work1, self.work1 Work2, Work3 = self.Work2, self.Work3 # ====================================================================== # Inverse Hessian products with respect to X # ====================================================================== # X = [D2S(X)^-1 - D2S(X)^-1 pTr' N^-1 pTr D2S(X)^-1] Wx # Compute Wx = Hx + Ht DPhi(X) np.outer(self.At, self.DPhi, out=Work2.reshape((p, -1))) np.add(self.Ax, Work2, out=Work0) # Apply D2S(X)^-1 = (Ux ⊗ Ux) log^[1](Dx) (Ux' ⊗ Ux') congr_multi(Work2, self.Ux.conj().T, Work0, Work3) Work2 *= self.D1x_comb_inv congr_multi(Work0, self.Ux, Work2, Work3) # Apply pTr p_tr_multi(work1, Work0, self.dims, self.sys) work1 *= -1 # Solve linear system N \ ( ... ) # Convert matrices to compact real vectors work = work1.view(np.float64).reshape((p, -1)).T work = x_dot_dense(self.F2C_op, work) # Solve system work = cho_solve(self.hess_schur_fact, work) # Expand compact real vectors back into full matrices work = x_dot_dense(self.F2C_op.T, work) work1.view(np.float64).reshape((p, -1))[:] = work.T # Apply pTr' = iKr i_kr_multi(Work1, work1, self.dims, self.sys) # Apply D2S(X)^-1 = (Ux ⊗ Ux) log^[1](Dx) (Ux' ⊗ Ux') congr_multi(Work2, self.Ux.conj().T, Work1, Work3) Work2 *= self.D1x_comb_inv congr_multi(Work1, self.Ux, Work2, Work3) # Subtract previous expression from D2S(X)^-1 Wx to get X Work0 -= Work1 out_X = Work0.reshape((p, -1)).view(np.float64) lhs[:, 1:] = out_X # ====================================================================== # Inverse Hessian products with respect to t # ====================================================================== # z^2 Ht + <DPhi(X), X> DPhi_vec = self.DPhi.view(np.float64).reshape((-1, 1)) out_t = self.z2 * self.At out_t += (out_X @ DPhi_vec).ravel() lhs[:, 0] = out_t # Multiply A (H A') return dense_dot_x(lhs, A.T) def third_dir_deriv_axpy(self, out, H, a=True): assert self.grad_updated if not self.hess_aux_updated: self.update_hessprod_aux() if not self.dder3_aux_updated: self.update_dder3_aux() (Ht, Hx) = H Hy = p_tr(Hx, self.dims, self.sys) UxHxUx = self.Ux.conj().T @ Hx @ self.Ux UyHyUy = self.Uy.conj().T @ Hy @ self.Uy # Quantum conditional entropy Hessians D2PhiH = self.Ux @ (self.D1x_log * UxHxUx) @ self.Ux.conj().T work = self.Uy @ (self.D1y_log * UyHyUy) @ self.Uy.conj().T D2PhiH -= i_kr(work, self.dims, self.sys) # Quantum conditional entropy third order derivatives D3PhiHH = scnd_frechet(self.D2x_log, UxHxUx, UxHxUx, self.Ux) work = scnd_frechet(self.D2y_log, UyHyUy, UyHyUy, self.Uy) D3PhiHH -= i_kr(work, self.dims, self.sys) # Third derivative of barrier DPhiH = inp(self.DPhi, Hx) D2PhiHH = inp(D2PhiH, Hx) chi = Ht - DPhiH dder3_t = -2 * (self.zi**3) * (chi**2) - (self.zi**2) * D2PhiHH dder3_X = -dder3_t * self.DPhi dder3_X -= 2 * (self.zi**2) * chi * D2PhiH dder3_X += self.zi * D3PhiHH dder3_X -= 2 * self.inv_X @ Hx @ self.inv_X @ Hx @ self.inv_X dder3_X = (dder3_X + dder3_X.conj().T) * 0.5 out[0][:] += dder3_t * a out[1][:] += dder3_X * a return out # ========================================================================== # Auxilliary functions # ========================================================================== def congr_aux(self, A): assert not self.congr_aux_updated iscomplex = self.iscomplex # Get slices and views of A matrix to be used in congruence computations if sp.sparse.issparse(A): A = A.tocsr() self.Ax_vec = A[:, 1:] if sp.sparse.issparse(A): A = A.toarray() Ax_dense = np.ascontiguousarray(A[:, 1:]) self.At = A[:, 0] self.Ax = np.array([vec_to_mat(Ax_k, iscomplex) for Ax_k in Ax_dense]) # Preallocate matrices we will need when performing these congruences self.Work0 = np.empty_like(self.Ax) self.Work1 = np.empty_like(self.Ax) self.Work2 = np.empty_like(self.Ax) self.Work3 = np.empty_like(self.Ax) p = A.shape[0] self.work1 = np.empty((p, self.n, self.n), dtype=self.dtype) self.work2 = np.empty((p, self.n, self.n), dtype=self.dtype) self.work3 = np.empty((p, self.n, self.n), dtype=self.dtype) self.congr_aux_updated = True def update_hessprod_aux(self): assert not self.hess_aux_updated assert self.grad_updated D1x_inv = np.reciprocal(np.outer(self.Dx, self.Dx)) self.D1x_log = D1_log(self.Dx, self.log_Dx) self.D1x_comb = self.zi * self.D1x_log + D1x_inv self.D1y_log = D1_log(self.Dy, self.log_Dy) # Preparing other required variables self.zi2 = self.zi * self.zi self.hess_aux_updated = True def update_dder3_aux(self): assert not self.dder3_aux_updated assert self.hess_aux_updated self.D2x_log = D2_log(self.Dx, self.D1x_log) self.D2y_log = D2_log(self.Dy, self.D1y_log) self.dder3_aux_updated = True def update_invhessprod_aux(self): assert not self.invhess_aux_updated assert self.grad_updated assert self.hess_aux_updated if not self.invhess_aux_aux_updated: self.update_invhessprod_aux_aux() # Precompute and factorize the matrix # N = z (Uy ⊗ Uy) [log^[1](Dy)]^-1 (Uy' ⊗ Uy') # - pTr (Ux ⊗ Ux) [(1/z log + inv)^[1](Dx)]^-1 (Ux' ⊗ Ux') pTr' # which we will need to solve linear systems with the Hessian of our # barrier function self.z2 = self.z * self.z self.D1x_comb_inv = np.reciprocal(self.D1x_comb) work6, work7, work8 = self.work6, self.work7, self.work8 Work6, Work7 = self.Work6, self.Work7 Work8, Work9 = self.Work8, self.Work9 # ====================================================================== # Get first term, i.e., [z (Uy ⊗ Uy) [log^[1](Dy)]^-1 (Uy' ⊗ Uy')] # ====================================================================== # Begin with (Uy' ⊗ Uy') congr_multi(work8, self.Uy.conj().T, self.E, work=work7) # Apply z [log^[1](Dy)]^-1 work8 *= self.z * np.reciprocal(self.D1y_log) # Apply (Uy ⊗ Uy) congr_multi(work6, self.Uy, work8, work=work7) # ====================================================================== # Get second term, i.e., [pTr (Ux ⊗ Ux) ...] # ====================================================================== # Begin with [(Ux' ⊗ Ux') pTr'] # Permute columns of Ux' so subsystems we are tracing out are in front temp = self.Ux.T.reshape(self.N, *self.dims) temp = np.transpose(temp, self.reordered_dims) # Obtain groups of columns of Ux' corresponding to # Ux' (I ⊗ Eij) Ux = (Ux' [I ⊗ ei]) ([I ⊗ ej'] Ux) temp = temp.reshape(self.N, self.m, self.n) lhs = np.copy(temp.conj().transpose(2, 0, 1)) rhs = np.copy(temp.transpose(2, 1, 0)) # Use these groups of columns of Ux' to compute [(Ux' ⊗ Ux') pTr'] # Compute the entries corresponding to Eii first np.matmul(lhs, rhs, out=Work9) Work8[self.diag_idxs] = Work9 # Now compute entries corresponding to Eij for i =/= j rhs *= np.sqrt(0.5) t = 0 for j in range(self.n): np.matmul(lhs[j], rhs[:j], out=Work9[:j]) Work9_T = Work9[:j].conj().transpose(0, 2, 1) if self.iscomplex: # Get symmetric and skew-symmetric matrices corresponding to # real and complex basis vectors np.add(Work9[:j], Work9_T, out=Work8[t : t + 2 * j : 2]) np.subtract(Work9[:j], Work9_T, out=Work8[t + 1 : t + 2 * j + 1 : 2]) Work8[t + 1 : t + 2 * j + 1 : 2] *= -1j t += 2 * j + 1 else: # Symmeterize matrix np.add(Work9[:j], Work9_T, out=Work8[t : t + j]) t += j + 1 # Apply [(1/z log + inv)^[1](Dx)]^-1/2 Work8 *= self.D1x_comb_inv # Apply pTr (Ux ⊗ Ux) congr_multi(Work6, self.Ux, Work8, work=Work7) p_tr_multi(work7, Work6, self.dims, self.sys) # Subtract the two terms to obtain Schur complement then Cholesky factor work6 -= work7 work = work6.view(np.float64).reshape((self.vn, -1)) hess_schur = x_dot_dense(self.F2C_op, work.T) self.hess_schur_fact = cho_fact(hess_schur) self.invhess_aux_updated = True def update_invhessprod_aux_aux(self): assert not self.invhess_aux_aux_updated self.precompute_computational_basis() if self.iscomplex: self.diag_idxs = np.cumsum([i for i in range(3, 2 * self.n + 1, 2)]) self.diag_idxs = np.append(0, self.diag_idxs) else: self.diag_idxs = np.cumsum([i for i in range(2, self.n + 1, 1)]) self.diag_idxs = np.append(0, self.diag_idxs) not_sys = list(set(range(len(self.dims))) - set(self.sys)) reordered_dims = self.sys + not_sys self.reordered_dims = [0] + [k + 1 for k in reordered_dims] self.F2C_op = get_full_to_compact_op(self.n, self.iscomplex) self.work6 = np.empty((self.vn, self.n, self.n), dtype=self.dtype) self.work7 = np.empty((self.vn, self.n, self.n), dtype=self.dtype) self.work8 = np.empty((self.vn, self.n, self.n), dtype=self.dtype) self.Work6 = np.empty((self.vn, self.N, self.N), dtype=self.dtype) self.Work7 = np.empty((self.vn, self.N, self.N), dtype=self.dtype) self.Work8 = np.empty((self.vn, self.N, self.N), dtype=self.dtype) self.Work9 = np.empty((self.n, self.N, self.N), dtype=self.dtype) self.invhess_aux_aux_updated = True