Coverage for picos/expressions/exp_renyientr.py: 75.50%
298 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-12 07:53 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2024 Kerry He
3#
4# This file is part of PICOS.
5#
6# PICOS is free software: you can redistribute it and/or modify it under the
7# terms of the GNU General Public License as published by the Free Software
8# Foundation, either version 3 of the License, or (at your option) any later
9# version.
10#
11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License along with
16# this program. If not, see <http://www.gnu.org/licenses/>.
17# ------------------------------------------------------------------------------
19"""Implements Renyi entropy expressions."""
21import operator
22from collections import namedtuple
24import cvxopt
25import numpy
27from .. import glyphs
28from ..apidoc import api_end, api_start
29from ..caching import cached_unary_operator
30from ..constraints import (
31 QuasiEntrEpiConstraint,
32 ComplexQuasiEntrEpiConstraint,
33 QuasiEntrHypoConstraint,
34 ComplexQuasiEntrHypoConstraint,
35 RenyiEntrConstraint,
36 ComplexRenyiEntrConstraint,
37 SandRenyiEntrConstraint,
38 ComplexSandRenyiEntrConstraint,
39 SandQuasiEntrEpiConstraint,
40 ComplexSandQuasiEntrEpiConstraint,
41 SandQuasiEntrHypoConstraint,
42 ComplexSandQuasiEntrHypoConstraint,
44)
45from .data import convert_and_refine_arguments, convert_operands, cvx2np
46from .exp_affine import AffineExpression, ComplexAffineExpression
47from .expression import Expression, refine_operands, validate_prediction
# Marker for the start of this module's public API; it is consumed by the
# api_end(_API_START, globals()) call at the bottom of the module.
_API_START = api_start(globals())
# -------------------------------
class BaseRenyiEntropy(Expression):
    r"""Base class used to define a general Renyi entropy expression.

    This base class validates the arguments, tracks mutables and creates
    upper-bound constraints. The concrete subclasses in this module supply
    ``_is_valid_alpha``, ``_get_strings``, ``_get_value``,
    ``_RealConstraint`` and ``_ComplexConstraint``.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    @convert_and_refine_arguments("X", "Y", "u", allowNone=True)
    def __init__(self, X, Y, alpha, u=None):
        r"""Construct an :class:`BaseRenyiEntropy`.

        :param X: The affine expression :math:`X`.
        :type X: ~picos.expressions.AffineExpression
        :param Y: The affine expression :math:`Y`. This should have the same
            dimensions as :math:`X`.
        :type Y: ~picos.expressions.AffineExpression
        :param alpha: The parameter :math:`\alpha`. Its admissible range is
            checked by the subclass' ``_is_valid_alpha``.
        :type alpha: float
        :param u: An additional scalar affine expression :math:`u`. If
            specified, then this defines the perspective of the Renyi entropy.
        :type u: ~picos.expressions.AffineExpression

        :raises TypeError: If :math:`X` or :math:`Y` is not a symmetric or
            Hermitian affine expression, if their shapes differ, if :math:`u`
            is not a real scalar affine expression, or if :math:`\alpha` is
            rejected by ``_is_valid_alpha``.
        """
        if not isinstance(X, ComplexAffineExpression):
            raise TypeError(
                "Can only take the matrix powers of a real "
                "or complex affine expression, not of {}.".format(
                    type(X).__name__
                )
            )
        if not X.hermitian:
            raise TypeError(
                "Can only take the matrix powers of a symmetric "
                "or Hermitian expression, not of {}.".format(type(X).__name__)
            )

        if not isinstance(Y, ComplexAffineExpression):
            raise TypeError(
                "The additional parameter Y must be a real "
                "or complex affine expression, not {}.".format(type(Y).__name__)
            )
        if not Y.hermitian:
            raise TypeError(
                "Can only take the matrix powers of a symmetric "
                "or Hermitian expression, not of {}.".format(type(Y).__name__)
            )
        if X.shape != Y.shape:
            raise TypeError(
                "The additional parameter Y must be the same shape "
                "as X ({}), not {}.".format(X.shape, Y.shape)
            )

        if u is not None:
            # Check the type before the shape, so that the shape is only
            # queried on an affine expression.
            if not isinstance(u, AffineExpression) or u.shape != (1, 1):
                raise TypeError(
                    "The additional parameter u must be a real scalar affine "
                    "expression, not {}.".format(type(u).__name__)
                )
            # A perspective of (presumably the constant) one is dropped; the
            # subclasses' value computation treats u = None as u = 1.
            if u.is1:
                u = None

        self._is_valid_alpha(alpha)

        self._X = X
        self._Y = Y
        self._u = u
        self._alpha = alpha

        # Real expressions are AffineExpression instances; anything else that
        # passed the checks above is a complex affine expression.
        self._iscomplex = not isinstance(X, AffineExpression) or \
            not isinstance(Y, AffineExpression)

        typeStr, symbStr = self._get_strings()

        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_refined(self):
        # Collapse to a constant once every operand (including u) is constant.
        if self._X.constant and self._Y.constant:
            if (self._u is None or self._u.constant):
                return AffineExpression.from_constant(
                    self.value, 1, self._symbStr
                )

        return self

    Subtype = namedtuple("Subtype", ("argdim", "alpha", "iscomplex"))

    def _get_subtype(self):
        return self.Subtype(len(self._X), self._alpha, self._iscomplex)

    @cached_unary_operator
    def _get_mutables(self):
        mutables = self._X._get_mutables().union(self._Y.mutables)
        # The perspective variable u is an operand too and must be reported.
        if self._u is not None:
            mutables = mutables.union(self._u.mutables)
        return mutables

    def _is_convex(self):
        return True

    def _is_concave(self):
        return False

    def _replace_mutables(self, mapping):
        # alpha must be passed positionally before u to match __init__.
        return self.__class__(
            self._X._replace_mutables(mapping),
            self._Y._replace_mutables(mapping),
            self._alpha,
            None if self._u is None else self._u._replace_mutables(mapping),
        )

    def _freeze_mutables(self, freeze):
        # alpha must be passed positionally before u to match __init__.
        return self.__class__(
            self._X._freeze_mutables(freeze),
            self._Y._freeze_mutables(freeze),
            self._alpha,
            None if self._u is None else self._u._freeze_mutables(freeze),
        )

    # --------------------------------------------------------------------------
    # Methods and properties that return expressions.
    # --------------------------------------------------------------------------

    @property
    def X(self):
        """The expression :math:`X`."""
        return self._X

    @property
    def Y(self):
        """The additional expression :math:`Y`."""
        return self._Y

    @property
    def u(self):
        """The additional expression :math:`u`."""
        return self._u

    @property
    def alpha(self):
        r"""The alpha :math:`\alpha`."""
        return self._alpha

    # --------------------------------------------------------------------------
    # Methods and properties that describe the expression.
    # --------------------------------------------------------------------------

    @property
    def n(self):
        """Lengths of :attr:`X` and :attr:`Y`."""
        return len(self._X)

    @property
    def iscomplex(self):
        """Whether :attr:`X` and :attr:`Y` are complex expressions or not."""
        return self._iscomplex

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        assert isinstance(subtype, cls.Subtype)

        if not issubclass(other.clstype, AffineExpression):
            return NotImplemented

        # Only upper bounds are supported since the expression is convex.
        if relation == operator.__le__:
            if subtype.iscomplex:
                return cls._ComplexConstraint().make_type(argdim=subtype.argdim)
            else:
                return cls._RealConstraint().make_type(argdim=subtype.argdim)

        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __le__(self, other):
        if not isinstance(other, AffineExpression):
            return NotImplemented

        if self.iscomplex:
            return self._ComplexConstraint()(self, other)
        else:
            return self._RealConstraint()(self, other)
class RenyiEntropy(BaseRenyiEntropy):
    r"""Renyi entropy of an affine expression.

    :Definition:

    For :math:`N \times N` symmetric or Hermitian matrices :math:`X` and
    :math:`Y`, and a fixed exponent :math:`\alpha\in[0, 1)`, this expression
    denotes

    .. math::

        \frac{1}{\alpha-1}\log(\operatorname{Tr}[ X^\alpha Y^{1-\alpha} ]).

    .. warning::

        Bounding this expression makes PICOS add an auxiliary constraint that
        enforces :math:`X \succeq 0` and :math:`Y \succeq 0` while solving.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    def _is_valid_alpha(self, alpha):
        # Accept only scalar exponents from the half-open interval [0, 1).
        in_range = numpy.isscalar(alpha) and 0 <= alpha < 1
        if not in_range:
            raise TypeError("The exponent alpha must be a scalar in [0, 1)")

    def _get_strings(self):
        symbol = glyphs.renyi(
            str(self._alpha), self._X.string, self._Y.string
        )
        return "Renyi Entropy", symbol

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_value(self):
        alpha = self._alpha
        X = cvx2np(self._X._get_value())
        Y = cvx2np(self._Y._get_value())
        # Without a perspective variable the value is scaled by one.
        u = 1 if self._u is None else cvx2np(self._u._get_value())

        def mpow(A, p):
            # Hermitian matrix power through an eigendecomposition.
            D, U = numpy.linalg.eigh(A)
            return U @ numpy.diag(numpy.power(D, p)) @ U.conj().T

        # Tr[A B] for Hermitian B is the entrywise sum of A * conj(B).
        trace = numpy.sum(mpow(X, alpha) * mpow(Y, 1 - alpha).conj()).real
        return cvxopt.matrix(u * numpy.log(trace / u) / (alpha - 1))

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _ComplexConstraint(cls):
        return ComplexRenyiEntrConstraint

    @classmethod
    def _RealConstraint(cls):
        return RenyiEntrConstraint
class SandRenyiEntropy(BaseRenyiEntropy):
    r"""Sandwiched Renyi entropy of an affine expression.

    :Definition:

    For :math:`N \times N` symmetric or Hermitian matrices :math:`X` and
    :math:`Y`, and a fixed exponent :math:`\alpha\in[1/2, 1)`, this
    expression denotes

    .. math::

        \frac{1}{\alpha-1}\log(\operatorname{Tr}[ (Y^{\frac{1-\alpha}{2\alpha}}
        X Y^{\frac{1-\alpha}{2\alpha}})^\alpha ]).

    .. warning::

        Bounding this expression makes PICOS add an auxiliary constraint that
        enforces :math:`X \succeq 0` and :math:`Y \succeq 0` while solving.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    def _is_valid_alpha(self, alpha):
        # Accept only scalar exponents from the half-open interval [1/2, 1).
        in_range = numpy.isscalar(alpha) and 0.5 <= alpha < 1
        if not in_range:
            raise TypeError("The exponent alpha must be a scalar in [1/2, 1)")

    def _get_strings(self):
        symbol = glyphs.renyi(
            str(self._alpha), self._X.string, self._Y.string
        )
        return "Sandwiched Renyi Entropy", symbol

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_value(self):
        a = self._alpha
        X = cvx2np(self._X._get_value())
        Y = cvx2np(self._Y._get_value())
        # Without a perspective variable the value is scaled by one.
        u = 1 if self._u is None else cvx2np(self._u._get_value())

        # Y^((1-a)/(2a)) via an eigendecomposition of Y.
        eig_y, vec_y = numpy.linalg.eigh(Y)
        root_diag = numpy.diag(numpy.power(eig_y, (1 - a) / (2 * a)))
        y_root = vec_y @ root_diag @ vec_y.conj().T

        # The trace of the a-th power equals the sum of powered eigenvalues.
        sandwich_eigs = numpy.linalg.eigvalsh(y_root @ X @ y_root)
        trace = numpy.sum(numpy.power(sandwich_eigs, a))

        return cvxopt.matrix(u * numpy.log(trace / u) / (a - 1))

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _ComplexConstraint(cls):
        return ComplexSandRenyiEntrConstraint

    @classmethod
    def _RealConstraint(cls):
        return SandRenyiEntrConstraint
class BaseQuasiEntropy(Expression):
    r"""Base class defining a general quasi-relative entropy expression.

    This base class validates the arguments, tracks mutables and creates
    epigraph/hypograph constraints depending on whether the expression is
    convex or concave for the given :math:`\alpha`. The concrete subclasses in
    this module supply ``_is_valid_alpha``, ``_get_strings``, ``_get_value``
    and the ``_Real/_Complex`` ``Epi/Hypo`` constraint accessors.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    @convert_and_refine_arguments("X", "Y")
    def __init__(self, X, Y, alpha):
        r"""Construct an :class:`BaseQuasiEntropy`.

        :param X: The affine expression :math:`X`.
        :type X: ~picos.expressions.AffineExpression
        :param Y: The affine expression :math:`Y`. This should have the same
            dimensions as :math:`X`.
        :type Y: ~picos.expressions.AffineExpression
        :param alpha: The parameter :math:`\alpha`. Its admissible range is
            checked by the subclass' ``_is_valid_alpha``.
        :type alpha: float

        :raises TypeError: If :math:`X` or :math:`Y` is not a symmetric or
            Hermitian affine expression, if their shapes differ, or if
            :math:`\alpha` is rejected by ``_is_valid_alpha``.
        """
        if not isinstance(X, ComplexAffineExpression):
            raise TypeError(
                "Can only take the matrix powers of a real "
                "or complex affine expression, not of {}.".format(
                    type(X).__name__
                )
            )
        if not X.hermitian:
            raise TypeError(
                "Can only take the matrix powers of a symmetric "
                "or Hermitian expression, not of {}.".format(type(X).__name__)
            )

        if not isinstance(Y, ComplexAffineExpression):
            raise TypeError(
                "The additional parameter Y must be a real "
                "or complex affine expression, not {}.".format(type(Y).__name__)
            )
        if not Y.hermitian:
            raise TypeError(
                "Can only take the matrix powers of a symmetric "
                "or Hermitian expression, not of {}.".format(type(Y).__name__)
            )
        if X.shape != Y.shape:
            raise TypeError(
                "The additional parameter Y must be the same shape "
                "as X ({}), not {}.".format(X.shape, Y.shape)
            )

        self._is_valid_alpha(alpha)

        self._X = X
        self._Y = Y
        self._alpha = alpha

        # Real expressions are AffineExpression instances; anything else that
        # passed the checks above is a complex affine expression.
        self._iscomplex = not isinstance(X, AffineExpression) or \
            not isinstance(Y, AffineExpression)

        typeStr, symbStr = self._get_strings()

        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Convexity helpers shared by the instance checks and _predict.
    # --------------------------------------------------------------------------

    @staticmethod
    def _alpha_is_convex(alpha):
        """Whether the expression is treated as convex for ``alpha``."""
        return -1 <= alpha <= 0 or 1 <= alpha <= 2

    @staticmethod
    def _alpha_is_concave(alpha):
        """Whether the expression is treated as concave for ``alpha``."""
        return 0 <= alpha <= 1

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_refined(self):
        # Collapse to a constant once both operands are constant.
        if self._X.constant and self._Y.constant:
            return AffineExpression.from_constant(self.value, 1, self._symbStr)
        else:
            return self

    Subtype = namedtuple("Subtype", ("argdim", "alpha", "iscomplex"))

    def _get_subtype(self):
        return self.Subtype(len(self._X), self._alpha, self._iscomplex)

    @cached_unary_operator
    def _get_mutables(self):
        return self._X._get_mutables().union(self._Y.mutables)

    def _is_convex(self):
        return self._alpha_is_convex(self._alpha)

    def _is_concave(self):
        return self._alpha_is_concave(self._alpha)

    def _replace_mutables(self, mapping):
        # alpha is a required constructor argument and must be carried over.
        return self.__class__(
            self._X._replace_mutables(mapping),
            self._Y._replace_mutables(mapping),
            self._alpha,
        )

    def _freeze_mutables(self, freeze):
        # alpha is a required constructor argument and must be carried over.
        return self.__class__(
            self._X._freeze_mutables(freeze),
            self._Y._freeze_mutables(freeze),
            self._alpha,
        )

    # --------------------------------------------------------------------------
    # Methods and properties that return expressions.
    # --------------------------------------------------------------------------

    @property
    def X(self):
        """The expression :math:`X`."""
        return self._X

    @property
    def Y(self):
        """The additional expression :math:`Y`."""
        return self._Y

    @property
    def alpha(self):
        r"""The alpha :math:`\alpha`."""
        return self._alpha

    # --------------------------------------------------------------------------
    # Methods and properties that describe the expression.
    # --------------------------------------------------------------------------

    @property
    def n(self):
        """Lengths of :attr:`X` and :attr:`Y`."""
        return len(self._X)

    @property
    def iscomplex(self):
        """Whether :attr:`X` and :attr:`Y` are complex expressions or not."""
        return self._iscomplex

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        assert isinstance(subtype, cls.Subtype)

        if not issubclass(other.clstype, AffineExpression):
            return NotImplemented

        if other.subtype.dim != 1:
            return NotImplemented

        isconvex = cls._alpha_is_convex(subtype.alpha)
        isconcave = cls._alpha_is_concave(subtype.alpha)
        argdim = subtype.argdim

        # Upper bounds on a convex expression give an epigraph constraint.
        if relation == operator.__le__ and isconvex:
            if subtype.iscomplex:
                return cls._ComplexEpiConstraint().make_type(argdim=argdim)
            else:
                return cls._RealEpiConstraint().make_type(argdim=argdim)

        # Lower bounds on a concave expression give a hypograph constraint.
        if relation == operator.__ge__ and isconcave:
            if subtype.iscomplex:
                return cls._ComplexHypoConstraint().make_type(argdim=argdim)
            else:
                return cls._RealHypoConstraint().make_type(argdim=argdim)

        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __le__(self, other):
        if self.convex and isinstance(other, AffineExpression):
            if self.iscomplex:
                return self._ComplexEpiConstraint()(self, other)
            else:
                return self._RealEpiConstraint()(self, other)
        else:
            return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __ge__(self, other):
        if self.concave and isinstance(other, AffineExpression):
            if self.iscomplex:
                return self._ComplexHypoConstraint()(self, other)
            else:
                return self._RealHypoConstraint()(self, other)
        else:
            return NotImplemented
class QuasiEntropy(BaseQuasiEntropy):
    r"""Quasi-relative entropy of an affine expression.

    :Definition:

    For :math:`N \times N` symmetric or Hermitian matrices :math:`X` and
    :math:`Y`, and a fixed exponent :math:`\alpha\in[-1, 2]`, this expression
    denotes

    .. math::

        \operatorname{Tr}[ X^\alpha Y^{1-\alpha} ].

    .. warning::

        Bounding this expression makes PICOS add an auxiliary constraint that
        enforces :math:`X \succeq 0` and :math:`Y \succeq 0` while solving.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    def _is_valid_alpha(self, alpha):
        # Accept only scalar exponents from the closed interval [-1, 2].
        in_range = numpy.isscalar(alpha) and -1 <= alpha <= 2
        if not in_range:
            raise TypeError("The exponent alpha must be a scalar in [-1, 2]")

    def _get_strings(self):
        product = glyphs.mul(
            glyphs.power(self._X.string, "a"),
            glyphs.power(self._Y.string, "1-a"),
        )
        return "Quasi-Relative Entropy", glyphs.trace(product)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_value(self):
        alpha = self._alpha
        X = cvx2np(self._X._get_value())
        Y = cvx2np(self._Y._get_value())

        def mpow(A, p):
            # Hermitian matrix power through an eigendecomposition.
            D, U = numpy.linalg.eigh(A)
            return U @ numpy.diag(numpy.power(D, p)) @ U.conj().T

        # Tr[A B] for Hermitian B is the entrywise sum of A * conj(B).
        trace = numpy.sum(mpow(X, alpha) * mpow(Y, 1 - alpha).conj()).real
        return cvxopt.matrix(trace)

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _RealEpiConstraint(cls):
        return QuasiEntrEpiConstraint

    @classmethod
    def _ComplexEpiConstraint(cls):
        return ComplexQuasiEntrEpiConstraint

    @classmethod
    def _RealHypoConstraint(cls):
        return QuasiEntrHypoConstraint

    @classmethod
    def _ComplexHypoConstraint(cls):
        return ComplexQuasiEntrHypoConstraint
class SandQuasiEntropy(BaseQuasiEntropy):
    r"""Sandwiched quasi-relative entropy of an affine expression.

    :Definition:

    For :math:`N \times N` symmetric or Hermitian matrices :math:`X` and
    :math:`Y`, and a fixed exponent :math:`\alpha\in[1/2, 2]`, this
    expression denotes

    .. math::

        \operatorname{Tr}[ (Y^{\frac{1-\alpha}{2\alpha}}
        X Y^{\frac{1-\alpha}{2\alpha}})^\alpha ].

    .. warning::

        Bounding this expression makes PICOS add an auxiliary constraint that
        enforces :math:`X \succeq 0` and :math:`Y \succeq 0` while solving.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    def _is_valid_alpha(self, alpha):
        # Accept only scalar exponents from the closed interval [1/2, 2].
        in_range = numpy.isscalar(alpha) and 0.5 <= alpha <= 2
        if not in_range:
            raise TypeError("The exponent alpha must be a scalar in [1/2, 2]")

    def _get_strings(self):
        x_str = self._X.string
        y_root = glyphs.power(self._Y.string, "(1-a)/2a")
        sandwich = glyphs.mul(glyphs.mul(y_root, x_str), y_root)
        symbol = glyphs.trace(glyphs.power(sandwich, "a"))
        return "Sandwiched Quasi-Relative Entropy", symbol

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_value(self):
        a = self._alpha
        X = cvx2np(self._X._get_value())
        Y = cvx2np(self._Y._get_value())

        # Y^((1-a)/(2a)) via an eigendecomposition of Y.
        eig_y, vec_y = numpy.linalg.eigh(Y)
        root_diag = numpy.diag(numpy.power(eig_y, (1 - a) / (2 * a)))
        y_root = vec_y @ root_diag @ vec_y.conj().T

        # The trace of the a-th power equals the sum of powered eigenvalues.
        sandwich_eigs = numpy.linalg.eigvalsh(y_root @ X @ y_root)
        return cvxopt.matrix(numpy.sum(numpy.power(sandwich_eigs, a)))

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _RealEpiConstraint(cls):
        return SandQuasiEntrEpiConstraint

    @classmethod
    def _ComplexEpiConstraint(cls):
        return ComplexSandQuasiEntrEpiConstraint

    @classmethod
    def _RealHypoConstraint(cls):
        return SandQuasiEntrHypoConstraint

    @classmethod
    def _ComplexHypoConstraint(cls):
        return ComplexSandQuasiEntrHypoConstraint
# --------------------------------------
# Public API of this module, derived from the _API_START marker and the current
# globals (presumably the names defined after the marker — confirm in apidoc).
__all__ = api_end(_API_START, globals())