# Copyright (c) 2024, Kerry He, James Saunderson, and Hamza Fawzi
# This Python package QICS is licensed under the MIT license; see LICENSE.md
# file in the root directory or at https://github.com/kerry-he/qics
import numpy as np
import scipy as sp
from qics._utils.linalg import dense_dot_x
from qics.cones.base import Cone, get_central_ray_entr
[docs]
class ClassEntr(Cone):
    r"""A class representing a (homogenized) classical entropy cone
    .. math::
        \mathcal{CE}_{n} = \text{cl}\{ (t, u, x) \in \mathbb{R} \times
        \mathbb{R}_{++} \times \mathbb{R}^n_{++} : t \geq -u H(u^{-1}x) \},
    where
    .. math::
        H(x) = -\sum_{i=1}^n x_i \log(x_i),
    is the classical (Shannon) entropy function.
    Parameters
    ----------
    n : :obj:`int`
        Dimension of the vector :math:`x`, i.e., how many terms are in the
        classical entropy function.
    See also
    --------
    ClassRelEntr : Classical relative entropy cone
    QuantEntr : (Homogenized) quantum entropy cone
    Notes
    -----
    The epigraph of the classical entropy can be obtained by enforcing the
    linear constraint :math:`u=1`.
    Additionally, the exponential cone
    .. math::
        \mathcal{E}=\{ (x,y,z)\in\mathbb{R}_+\times\mathbb{R}_+
        \times\mathbb{R} : y \geq x \exp(z/x) \},
    can be modelled by realizing that if :math:`(x,y,z)\in\mathcal{E}`,
    then :math:`(-z, y, x)\in\mathcal{CE}_1`.
    """
    def __init__(self, n):
        self.n = n
        self.nu = 2 + self.n  # Barrier parameter
        self.dim = [1, 1, n]
        self.type = ["r", "r", "r"]
        # Update flags
        self.feas_updated = False
        self.grad_updated = False
        self.hess_aux_updated = False
        self.invhess_aux_updated = False
        self.dder3_aux_updated = False
        self.congr_aux_updated = False
        return
    def get_init_point(self, out):
        (t0, u0, x0) = get_central_ray_entr(self.n)
        point = [
            np.array([[t0]]),
            np.array([[u0]]),
            np.ones((self.n, 1)) * x0,
        ]
        self.set_point(point, point)
        out[0][:] = point[0]
        out[1][:] = point[1]
        out[2][:] = point[2]
        return out
    def get_feas(self):
        if self.feas_updated:
            return self.feas
        self.feas_updated = True
        (self.t, self.u, self.x) = self.primal
        # Check that u and x are strictly positive
        if (self.u <= 0) or any(self.x <= 0):
            self.feas = False
            return self.feas
        # Check that t > -u H(x/u) = Σ_i xi log(xi) - (Σ_i xi) log(u)
        self.sum_x = np.sum(self.x)
        self.log_x = np.log(self.x)
        self.log_u = np.log(self.u[0, 0])
        entr_x = self.x.T @ self.log_x
        entr_xu = self.sum_x * self.log_u
        self.z = (self.t - (entr_x - entr_xu))[0, 0]
        self.feas = self.z > 0
        return self.feas
    def get_val(self):
        return -np.log(self.z) - np.sum(self.log_u) - np.sum(self.log_x)
    def update_grad(self):
        assert self.feas_updated
        assert not self.grad_updated
        # Compute gradients of classical entropy
        # D_u H(u, x) = -Σ_i xi / u
        self.ui = np.reciprocal(self.u)
        self.DPhiu = -self.sum_x * self.ui
        # D_x H(u, x) = log(x) + (1 - log(u))
        self.DPhiX = self.log_x + (1.0 - self.log_u)
        # Compute 1 / x
        self.xi = np.reciprocal(self.x)
        # Compute gradient of barrier function
        self.zi = np.reciprocal(self.z)
        self.grad = [
            -self.zi,
            self.zi * self.DPhiu - self.ui,
            self.zi * self.DPhiX - self.xi,
        ]
        self.grad_updated = True
    def hess_prod_ip(self, out, H):
        assert self.grad_updated
        if not self.hess_aux_updated:
            self.update_hessprod_aux()
        (Ht, Hu, Hx) = H
        # Hessian product of classical entropy
        D2PhiuH = Hu * self.sum_x * self.ui2 - np.sum(Hx) * self.ui
        D2PhixH = -Hu * self.ui + Hx * self.xi
        # Hessian product of barrier function
        out[0][:] = (Ht - Hu * self.DPhiu - Hx.T @ self.DPhiX) * self.zi2
        out[1][:] = -out[0] * self.DPhiu + self.zi * D2PhiuH + Hu * self.ui2
        out[2][:] = -out[0] * self.DPhiX + self.zi * D2PhixH + Hx * self.xi2
        return out
    def hess_congr(self, A):
        assert self.grad_updated
        if not self.hess_aux_updated:
            self.update_hessprod_aux()
        if not self.congr_aux_updated:
            self.congr_aux(A)
        p = A.shape[0]
        lhs = np.zeros((p, sum(self.dim)))
        work1, work2 = self.work1, self.work2
        # ======================================================================
        # Hessian products with respect to t
        # ======================================================================
        # D2_t F(t, u, x)[Ht, Hu, Hx]
        #         = (Ht - D_u H(u, x)[Hu] - D_x H(u, x)[Hx]) / z^2
        out_t = self.At - self.Au * self.DPhiu[0, 0]
        out_t -= (self.Ax @ self.DPhiX).ravel()
        out_t *= self.zi2
        lhs[:, 0] = out_t
        # ======================================================================
        # Hessian products with respect to u
        # ======================================================================
        # Hessian products for classical entropy
        # D2_uu Phi(u, x) [Hu] =  sum(x) Hu / u^2
        D2PhiuH = self.Huu * self.Au
        # D2_ux Phi(u, x) [Hx] = -sum(Hx) / u
        D2PhiuH += self.Hux * np.sum(self.Ax, axis=1)
        # Hessian product of barrier function
        # D2_u F(t, u, x)[Ht, Hu, Hx]
        #         = -D2_t F(t, u, x)[Ht, Hu, Hx] * D_u H(u, x)
        #           + (D2_uu H(u, x)[Hu] + D2_ux H(u, x)[Hx]) / z
        #           + Hu / u^2
        out_u = -out_t * self.DPhiu
        out_u += D2PhiuH
        lhs[:, 1] = out_u
        # ======================================================================
        # Hessian products with respect to x
        # ======================================================================
        # Hessian products for classical entropy
        # D2_xx Phi(u, x) [Hx] =  Hx / x
        np.multiply(self.Hxx.T, self.Ax, out=work1)
        # D2_xu Phi(u, x) [Hu] = -Hu / u
        work1 += self.Hux * self.Au.reshape(-1, 1)
        # Hessian product of barrier function
        # D2_x F(t, u, x)[Ht, Hu, Hx]
        #         = -D2_t F(t, u, x)[Ht, Hu, Hx] * D_x H(u, x)
        #           + (D2_xu H(u, x)[Hu] + D2_xx H(u, x)[Hx]) / z
        #           + Hx / x^2
        np.outer(out_t, self.DPhiX, out=work2)
        work1 -= work2
        lhs[:, 2:] = work1
        # Multiply A (H A')
        return dense_dot_x(lhs, A.T)
    def invhess_prod_ip(self, out, H):
        assert self.grad_updated
        if not self.hess_aux_updated:
            self.update_hessprod_aux()
        if not self.invhess_aux_updated:
            self.update_invhessprod_aux()
        (Ht, Hu, Hx) = H
        Wu = Hu + Ht * self.DPhiu
        Wx = Hx + Ht * self.DPhiX
        # Inverse Hessian product of classical entropy
        out_u = self.rho * (Wu - self.Hxx_inv_Hux.T @ Wx)
        out_x = self.Hxx_inv * Wx - out_u * self.Hxx_inv_Hux
        # Inverse Hessian product of barrier function
        out[0][:] = Ht * self.z2 + out_u * self.DPhiu + out_x.T @ self.DPhiX
        out[1][:] = out_u
        out[2][:] = out_x
        return out
    def invhess_congr(self, A):
        assert self.grad_updated
        if not self.hess_aux_updated:
            self.update_hessprod_aux()
        if not self.invhess_aux_updated:
            self.update_invhessprod_aux()
        if not self.congr_aux_updated:
            self.congr_aux(A)
        # The inverse Hessian product applied on (Ht, Hu, Hx) for the CE barrier
        # is
        #     (u, x) =  M \ (Wu, Wx)
        #         t  =  z^2 Ht + <DPhi(u, x), (u, x)>
        # where (Wu, Wx) = [(Hu, Hx) + Ht DPhi(u, x)]
        #     M = [ (1 + sum(x)/z) / u^2        -1' / zu       ] = [ a  b']
        #         [        -1 / zu         diag(1/zx + 1/x^2)  ]   [ b  D ]
        #
        # To solve linear systems with M, we simplify it by doing block
        # elimination, in which case we get
        #     u = (Wu - b' D^-1 Wx) / (a - b' D^-1 b)
        #     x = D^-1 (Wx - Wu b)
        p = A.shape[0]
        lhs = np.zeros((p, sum(self.dim)))
        # Compute Wu
        Wu = self.Au + self.At * self.DPhiu[0, 0]
        # Compute Wx
        np.outer(self.At, self.DPhiX, out=self.work1)
        self.work1 += self.Ax
        # ======================================================================
        # Inverse Hessian products with respect to u
        # ======================================================================
        # u = (Wu - b' D^-1 Wx) / (a - b' D^-1 b)
        out_u = self.rho * (Wu - (self.work1 @ self.Hxx_inv_Hux).ravel())
        lhs[:, 1] = out_u
        # ======================================================================
        # Inverse Hessian products with respect to x
        # ======================================================================
        # x = D^-1 (Wx - Wu b)
        self.work1 *= self.Hxx_inv.T
        np.outer(out_u, self.Hxx_inv_Hux, out=self.work2)
        self.work1 -= self.work2
        lhs[:, 2:] = self.work1
        # ======================================================================
        # Inverse Hessian products with respect to t
        # ======================================================================
        # t = z^2 Ht + <DH(u, x), (u, x)>
        out_t = self.z2 * self.At
        out_t += out_u * self.DPhiu[0, 0]
        out_t += (self.work1 @ self.DPhiX).ravel()
        lhs[:, 0] = out_t
        # Multiply A (H A')
        return dense_dot_x(lhs, A.T)
    def third_dir_deriv_axpy(self, out, H, a=True):
        assert self.grad_updated
        if not self.hess_aux_updated:
            self.update_hessprod_aux()
        if not self.dder3_aux_updated:
            self.update_dder3_aux()
        (Ht, Hu, Hx) = H
        Hu2 = Hu * Hu
        Hx2 = Hx * Hx
        sum_Hx = np.sum(Hx)
        chi = (Ht - self.DPhiu * Hu - self.DPhiX.T @ Hx)[0, 0]
        chi2 = chi * chi
        # Classical entropy Hessians
        D2PhiuH = Hu * self.sum_x * self.ui2 - np.sum(Hx) * self.ui
        D2PhixH = -Hu * self.ui + Hx * self.xi
        D2PhiuHH = Hu * D2PhiuH
        D2PhixHH = Hx.T @ D2PhixH
        # Classical entropy third order derivatives
        D3PhiuHH = -2 * Hu2 * self.sum_x * self.ui3
        D3PhiuHH += 2 * Hu * sum_Hx * self.ui2
        D3PhixHH = -Hx2 * self.xi2
        D3PhixHH += Hu2 * self.ui2
        # Third derivatives of barrier
        dder3_t = -2 * self.zi3 * chi2 - self.zi2 * (D2PhixHH + D2PhiuHH)
        dder3_u = -dder3_t * self.DPhiu
        dder3_u -= 2 * self.zi2 * chi * D2PhiuH
        dder3_u += self.zi * D3PhiuHH
        dder3_u -= 2 * Hu2 * self.ui3
        dder3_x = -dder3_t * self.DPhiX
        dder3_x -= 2 * self.zi2 * chi * D2PhixH
        dder3_x += self.zi * D3PhixHH
        dder3_x -= 2 * Hx2 * self.xi3
        out[0][:] += dder3_t * a
        out[1][:] += dder3_u * a
        out[2][:] += dder3_x * a
        return out
    # ==========================================================================
    # Auxilliary functions
    # ==========================================================================
    def congr_aux(self, A):
        assert not self.congr_aux_updated
        if sp.sparse.issparse(A):
            A = A.toarray()
        self.At = A[:, 0]
        self.Au = A[:, 1]
        self.Ax = A[:, 2:]
        self.work1 = np.empty_like(self.Ax)
        self.work2 = np.empty_like(self.Ax)
        self.congr_aux_updated = True
    def update_hessprod_aux(self):
        assert not self.hess_aux_updated
        assert self.grad_updated
        self.zi2 = self.zi * self.zi
        self.ui2 = self.ui * self.ui
        self.xi2 = self.xi * self.xi
        self.Huu = (self.zi * self.sum_x + 1.0) * self.ui2
        self.Hux = -self.zi * self.ui[0, 0]
        self.Hxx = self.zi * self.xi + self.xi2
        self.hess_aux_updated = True
    def update_invhessprod_aux(self):
        assert not self.invhess_aux_updated
        assert self.grad_updated
        self.Hxx_inv = np.reciprocal(self.Hxx)
        self.Hxx_inv_Hux = self.Hxx_inv * self.Hux
        self.rho = 1.0 / (self.Huu - np.sum(self.Hxx_inv_Hux) * self.Hux)[0, 0]
        self.z2 = self.z * self.z
        self.invhess_aux_updated = True
    def update_dder3_aux(self):
        assert not self.dder3_aux_updated
        assert self.hess_aux_updated
        self.zi3 = self.zi * self.zi2
        self.ui3 = self.ui * self.ui2
        self.xi3 = self.xi * self.xi2
        self.dder3_aux_updated = True