# Source code for qics.cones.entropy.classentr

# Copyright (c) 2024, Kerry He, James Saunderson, and Hamza Fawzi

# This Python package QICS is licensed under the MIT license; see LICENSE.md
# file in the root directory or at https://github.com/kerry-he/qics

import numpy as np
import scipy as sp

from qics._utils.linalg import dense_dot_x
from qics.cones.base import Cone, get_central_ray_entr


class ClassEntr(Cone):
    r"""A class representing a (homogenized) classical entropy cone

    .. math::

        \mathcal{CE}_{n} = \text{cl}\{ (t, u, x) \in \mathbb{R} \times
        \mathbb{R}_{++} \times \mathbb{R}^n_{++} : t \geq -u H(u^{-1}x) \},

    where

    .. math::

        H(x) = -\sum_{i=1}^n x_i \log(x_i),

    is the classical (Shannon) entropy function.

    Parameters
    ----------
    n : :obj:`int`
        Dimension of the vector :math:`x`, i.e., how many terms are in the
        classical entropy function.

    See also
    --------
    ClassRelEntr : Classical relative entropy cone
    QuantEntr : (Homogenized) quantum entropy cone

    Notes
    -----
    The epigraph of the classical entropy can be obtained by enforcing the
    linear constraint :math:`u=1`.

    Additionally, the exponential cone

    .. math::

        \mathcal{E}=\{ (x,y,z)\in\mathbb{R}_+\times\mathbb{R}_+
        \times\mathbb{R} : y \geq x \exp(z/x) \},

    can be modelled by realizing that if :math:`(x,y,z)\in\mathcal{E}`, then
    :math:`(-z, y, x)\in\mathcal{CE}_1`.
    """

    # Implementation notes
    # --------------------
    # Throughout, write
    #     Phi(u, x) = -u H(x/u) = Σ_i xi log(xi) - (Σ_i xi) log(u),
    #     z         = t - Phi(u, x),
    # and use the barrier (see get_val)
    #     F(t, u, x) = -log(z) - log(u) - Σ_i log(xi),
    # whose barrier parameter is nu = n + 2.

    def __init__(self, n):
        self.n = n

        self.nu = 2 + self.n  # Barrier parameter

        # Block sizes of (t, u, x) and their types
        self.dim = [1, 1, n]
        self.type = ["r", "r", "r"]

        # Update flags: guard the cached quantities computed by the oracles
        # below (presumably reset by the base class whenever the point
        # changes -- not visible here)
        self.feas_updated = False
        self.grad_updated = False
        self.hess_aux_updated = False
        self.invhess_aux_updated = False
        self.dder3_aux_updated = False
        self.congr_aux_updated = False

    def get_init_point(self, out):
        """Write the initial point given by ``get_central_ray_entr`` into out.

        The point (t0, u0, x0 * ones(n)) is also passed as both arguments of
        ``set_point`` (presumably the primal and dual iterates).

        Parameters
        ----------
        out : list of ndarray
            Preallocated ``[t, u, x]`` buffers, overwritten in place.

        Returns
        -------
        list of ndarray
            ``out``.
        """
        (t0, u0, x0) = get_central_ray_entr(self.n)

        point = [
            np.array([[t0]]),
            np.array([[u0]]),
            np.ones((self.n, 1)) * x0,
        ]

        self.set_point(point, point)

        out[0][:] = point[0]
        out[1][:] = point[1]
        out[2][:] = point[2]

        return out

    def get_feas(self):
        """Return whether the current primal point is strictly feasible.

        On success, caches ``sum_x``, ``log_x``, ``log_u`` and the slack
        ``z = t - Phi(u, x)``, which the derivative oracles reuse. The result
        is cached until ``feas_updated`` is cleared.
        """
        if self.feas_updated:
            return self.feas

        self.feas_updated = True
        (self.t, self.u, self.x) = self.primal

        # Check that u and x are strictly positive (np.any is vectorized,
        # unlike the builtin any(), which loops over rows in Python)
        if np.any(self.u <= 0) or np.any(self.x <= 0):
            self.feas = False
            return self.feas

        # Check that t > -u H(x/u) = Σ_i xi log(xi) - (Σ_i xi) log(u)
        self.sum_x = np.sum(self.x)
        self.log_x = np.log(self.x)
        self.log_u = np.log(self.u[0, 0])

        entr_x = self.x.T @ self.log_x
        entr_xu = self.sum_x * self.log_u
        self.z = (self.t - (entr_x - entr_xu))[0, 0]

        self.feas = self.z > 0
        return self.feas

    def get_val(self):
        """Return the barrier value F = -log(z) - log(u) - Σ_i log(xi).

        Relies on the quantities cached by a successful ``get_feas`` call.
        """
        return -np.log(self.z) - np.sum(self.log_u) - np.sum(self.log_x)

    def update_grad(self):
        """Compute and cache the barrier gradient ``self.grad = [Dt, Du, Dx]``.

        Also caches D_u Phi, D_x Phi and the reciprocals 1/u, 1/x, 1/z.
        """
        assert self.feas_updated
        assert not self.grad_updated

        # Compute gradients of Phi(u, x) = -u H(x/u)
        # D_u Phi(u, x) = -Σ_i xi / u
        self.ui = np.reciprocal(self.u)
        self.DPhiu = -self.sum_x * self.ui
        # D_x Phi(u, x) = log(x) + (1 - log(u))
        self.DPhiX = self.log_x + (1.0 - self.log_u)

        # Compute 1 / x
        self.xi = np.reciprocal(self.x)

        # Gradient of F = -log(t - Phi) - log(u) - Σ log(x)
        self.zi = np.reciprocal(self.z)
        self.grad = [
            -self.zi,
            self.zi * self.DPhiu - self.ui,
            self.zi * self.DPhiX - self.xi,
        ]

        self.grad_updated = True

    def hess_prod_ip(self, out, H):
        """Write the Hessian product D2F(t, u, x)[H] into ``out`` in place.

        Parameters
        ----------
        out : list of ndarray
            ``[t, u, x]`` buffers to overwrite.
        H : list of ndarray
            Direction ``(Ht, Hu, Hx)`` with shapes (1, 1), (1, 1), (n, 1).

        Returns
        -------
        list of ndarray
            ``out``.
        """
        assert self.grad_updated
        if not self.hess_aux_updated:
            self.update_hessprod_aux()

        (Ht, Hu, Hx) = H

        # Hessian product of Phi
        # D2_u Phi[H] = sum(x) Hu / u^2 - sum(Hx) / u
        # D2_x Phi[H] = -Hu / u + Hx / x
        D2PhiuH = Hu * self.sum_x * self.ui2 - np.sum(Hx) * self.ui
        D2PhixH = -Hu * self.ui + Hx * self.xi

        # Hessian product of barrier function:
        # D2F[H] = chi/z^2 (1, -DPhi) + (0, D2Phi[H]) / z + (0, Hu/u^2, Hx/x^2)
        # with chi = Ht - DPhi[(Hu, Hx)]
        out[0][:] = (Ht - Hu * self.DPhiu - Hx.T @ self.DPhiX) * self.zi2
        out[1][:] = -out[0] * self.DPhiu + self.zi * D2PhiuH + Hu * self.ui2
        out[2][:] = -out[0] * self.DPhiX + self.zi * D2PhixH + Hx * self.xi2

        return out

    def hess_congr(self, A):
        """Return the congruence product A D2F(t, u, x) A'.

        Each row of ``A`` is a direction laid out as [t, u, x_1, ..., x_n].
        """
        assert self.grad_updated
        if not self.hess_aux_updated:
            self.update_hessprod_aux()
        if not self.congr_aux_updated:
            self.congr_aux(A)

        p = A.shape[0]
        lhs = np.zeros((p, sum(self.dim)))
        work1, work2 = self.work1, self.work2

        # Scalars, so all per-row results below stay 1-D of length p
        DPhiu = self.DPhiu[0, 0]
        Huu = self.Huu[0, 0]

        # ======================================================================
        # Hessian products with respect to t
        # ======================================================================
        # D2_t F(t, u, x)[Ht, Hu, Hx]
        #     = (Ht - D_u Phi(u, x)[Hu] - D_x Phi(u, x)[Hx]) / z^2
        out_t = self.At - self.Au * DPhiu
        out_t -= (self.Ax @ self.DPhiX).ravel()
        out_t *= self.zi2
        lhs[:, 0] = out_t

        # ======================================================================
        # Hessian products with respect to u
        # ======================================================================
        # D2_u F(t, u, x)[Ht, Hu, Hx]
        #     = -D2_t F(t, u, x)[Ht, Hu, Hx] * D_u Phi(u, x)
        #       + (D2_uu Phi(u, x)[Hu] + D2_ux Phi(u, x)[Hx]) / z
        #       + Hu / u^2
        # Huu = (sum(x)/z + 1) / u^2 and Hux = -1/(zu) already fold in the
        # 1/z scaling and the -log(u) term.
        out_u = Huu * self.Au
        out_u += self.Hux * np.sum(self.Ax, axis=1)
        out_u -= out_t * DPhiu
        lhs[:, 1] = out_u

        # ======================================================================
        # Hessian products with respect to x
        # ======================================================================
        # D2_x F(t, u, x)[Ht, Hu, Hx]
        #     = -D2_t F(t, u, x)[Ht, Hu, Hx] * D_x Phi(u, x)
        #       + (D2_xu Phi(u, x)[Hu] + D2_xx Phi(u, x)[Hx]) / z
        #       + Hx / x^2
        # Hxx = 1/(zx) + 1/x^2 folds in the -Σ log(x) term.
        np.multiply(self.Hxx.T, self.Ax, out=work1)
        work1 += self.Hux * self.Au.reshape(-1, 1)
        np.outer(out_t, self.DPhiX, out=work2)
        work1 -= work2
        lhs[:, 2:] = work1

        # Multiply A (H A')
        return dense_dot_x(lhs, A.T)

    def invhess_prod_ip(self, out, H):
        """Write the inverse Hessian product D2F(t, u, x)^-1 [H] into ``out``.

        See ``invhess_congr`` for the block-elimination formulas used.

        Returns
        -------
        list of ndarray
            ``out``.
        """
        assert self.grad_updated
        if not self.hess_aux_updated:
            self.update_hessprod_aux()
        if not self.invhess_aux_updated:
            self.update_invhessprod_aux()

        (Ht, Hu, Hx) = H

        Wu = Hu + Ht * self.DPhiu
        Wx = Hx + Ht * self.DPhiX

        # Solve M (u, x) = (Wu, Wx) by block elimination
        out_u = self.rho * (Wu - self.Hxx_inv_Hux.T @ Wx)
        out_x = self.Hxx_inv * Wx - out_u * self.Hxx_inv_Hux

        # t = z^2 Ht + <DPhi(u, x), (u, x)>
        out[0][:] = Ht * self.z2 + out_u * self.DPhiu + out_x.T @ self.DPhiX
        out[1][:] = out_u
        out[2][:] = out_x

        return out

    def invhess_congr(self, A):
        """Return the congruence product A D2F(t, u, x)^-1 A'.

        Each row of ``A`` is a direction laid out as [t, u, x_1, ..., x_n].
        """
        assert self.grad_updated
        if not self.hess_aux_updated:
            self.update_hessprod_aux()
        if not self.invhess_aux_updated:
            self.update_invhessprod_aux()
        if not self.congr_aux_updated:
            self.congr_aux(A)

        # The inverse Hessian product applied on (Ht, Hu, Hx) for the CE
        # barrier is
        #     (u, x) = M \ (Wu, Wx)
        #         t  = z^2 Ht + <DPhi(u, x), (u, x)>
        # where (Wu, Wx) = (Hu, Hx) + Ht DPhi(u, x) and
        #     M = [ (1 + sum(x)/z) / u^2   -1' / (zu)           ] = [ a  b' ]
        #         [ -1 / (zu)              diag(1/(zx) + 1/x^2) ]   [ b  D  ]
        #
        # To solve linear systems with M, we simplify it by doing block
        # elimination, in which case we get
        #     u = (Wu - b' D^-1 Wx) / (a - b' D^-1 b)
        #     x = D^-1 (Wx - u b)
        p = A.shape[0]
        lhs = np.zeros((p, sum(self.dim)))

        # Compute Wu
        Wu = self.Au + self.At * self.DPhiu[0, 0]
        # Compute Wx
        np.outer(self.At, self.DPhiX, out=self.work1)
        self.work1 += self.Ax

        # ======================================================================
        # Inverse Hessian products with respect to u
        # ======================================================================
        # u = (Wu - b' D^-1 Wx) / (a - b' D^-1 b)
        out_u = self.rho * (Wu - (self.work1 @ self.Hxx_inv_Hux).ravel())
        lhs[:, 1] = out_u

        # ======================================================================
        # Inverse Hessian products with respect to x
        # ======================================================================
        # x = D^-1 (Wx - u b)
        self.work1 *= self.Hxx_inv.T
        np.outer(out_u, self.Hxx_inv_Hux, out=self.work2)
        self.work1 -= self.work2
        lhs[:, 2:] = self.work1

        # ======================================================================
        # Inverse Hessian products with respect to t
        # ======================================================================
        # t = z^2 Ht + <DPhi(u, x), (u, x)>
        out_t = self.z2 * self.At
        out_t += out_u * self.DPhiu[0, 0]
        out_t += (self.work1 @ self.DPhiX).ravel()
        lhs[:, 0] = out_t

        # Multiply A (H A')
        return dense_dot_x(lhs, A.T)

    def third_dir_deriv_axpy(self, out, H, a=True):
        """Add ``a * D3F(t, u, x)[H, H]`` to ``out`` in place.

        Parameters
        ----------
        out : list of ndarray
            ``[t, u, x]`` buffers accumulated into.
        H : list of ndarray
            Direction ``(Ht, Hu, Hx)``.
        a : scalar, optional
            Scaling factor; the default ``True`` multiplies as 1.

        Returns
        -------
        list of ndarray
            ``out``.
        """
        assert self.grad_updated
        if not self.hess_aux_updated:
            self.update_hessprod_aux()
        if not self.dder3_aux_updated:
            self.update_dder3_aux()

        (Ht, Hu, Hx) = H

        Hu2 = Hu * Hu
        Hx2 = Hx * Hx
        sum_Hx = np.sum(Hx)

        # chi = Dz[H] = Ht - DPhi[(Hu, Hx)]
        chi = (Ht - self.DPhiu * Hu - self.DPhiX.T @ Hx)[0, 0]
        chi2 = chi * chi

        # Hessian products of Phi, and D2Phi[H, H] split into u and x parts
        D2PhiuH = Hu * self.sum_x * self.ui2 - sum_Hx * self.ui
        D2PhixH = -Hu * self.ui + Hx * self.xi
        D2PhiuHH = Hu * D2PhiuH
        D2PhixHH = Hx.T @ D2PhixH

        # Third order derivatives of Phi, D3Phi[H, H]
        # u part: -2 sum(x) Hu^2 / u^3 + 2 Hu sum(Hx) / u^2
        D3PhiuHH = -2 * Hu2 * self.sum_x * self.ui3
        D3PhiuHH += 2 * Hu * sum_Hx * self.ui2
        # x part: -Hx^2 / x^2 + Hu^2 / u^2
        D3PhixHH = -Hx2 * self.xi2
        D3PhixHH += Hu2 * self.ui2

        # Third derivatives of barrier
        dder3_t = -2 * self.zi3 * chi2 - self.zi2 * (D2PhixHH + D2PhiuHH)

        dder3_u = -dder3_t * self.DPhiu
        dder3_u -= 2 * self.zi2 * chi * D2PhiuH
        dder3_u += self.zi * D3PhiuHH
        dder3_u -= 2 * Hu2 * self.ui3

        dder3_x = -dder3_t * self.DPhiX
        dder3_x -= 2 * self.zi2 * chi * D2PhixH
        dder3_x += self.zi * D3PhixHH
        dder3_x -= 2 * Hx2 * self.xi3

        out[0][:] += dder3_t * a
        out[1][:] += dder3_u * a
        out[2][:] += dder3_x * a

        return out

    # ==========================================================================
    # Auxiliary functions
    # ==========================================================================
    def congr_aux(self, A):
        """Split ``A`` into its t, u and x columns and allocate scratch space.

        Sparse ``A`` is densified first. The work buffers are always float64,
        so an integer-typed ``A`` does not break the in-place float updates
        done in ``hess_congr`` and ``invhess_congr``.
        """
        assert not self.congr_aux_updated

        if sp.sparse.issparse(A):
            A = A.toarray()

        self.At = A[:, 0]
        self.Au = A[:, 1]
        self.Ax = A[:, 2:]

        self.work1 = np.empty(self.Ax.shape, dtype=np.float64)
        self.work2 = np.empty(self.Ax.shape, dtype=np.float64)

        self.congr_aux_updated = True

    def update_hessprod_aux(self):
        """Cache squared reciprocals and the (u, x) Hessian blocks.

        Huu = (sum(x)/z + 1) / u^2, Hux = -1/(zu) (a scalar) and
        Hxx = 1/(zx) + 1/x^2 (a vector holding the diagonal).
        """
        assert not self.hess_aux_updated
        assert self.grad_updated

        self.zi2 = self.zi * self.zi
        self.ui2 = self.ui * self.ui
        self.xi2 = self.xi * self.xi

        self.Huu = (self.zi * self.sum_x + 1.0) * self.ui2
        self.Hux = -self.zi * self.ui[0, 0]
        self.Hxx = self.zi * self.xi + self.xi2

        self.hess_aux_updated = True

    def update_invhessprod_aux(self):
        """Cache the Schur-complement quantities used by inverse Hessians.

        With M = [[a, b'], [b, D]], a = Huu, b = Hux * ones, D = diag(Hxx):
        Hxx_inv holds the diagonal of D^-1, Hxx_inv_Hux = D^-1 b and
        rho = 1 / (a - b' D^-1 b).
        """
        assert not self.invhess_aux_updated
        assert self.grad_updated
        # Huu, Hux and Hxx are produced by update_hessprod_aux
        assert self.hess_aux_updated

        self.Hxx_inv = np.reciprocal(self.Hxx)
        self.Hxx_inv_Hux = self.Hxx_inv * self.Hux
        self.rho = 1.0 / (self.Huu - np.sum(self.Hxx_inv_Hux) * self.Hux)[0, 0]

        self.z2 = self.z * self.z

        self.invhess_aux_updated = True

    def update_dder3_aux(self):
        """Cache the cubed reciprocals 1/z^3, 1/u^3 and 1/x^3."""
        assert not self.dder3_aux_updated
        assert self.hess_aux_updated

        self.zi3 = self.zi * self.zi2
        self.ui3 = self.ui * self.ui2
        self.xi3 = self.xi * self.xi2

        self.dder3_aux_updated = True