Coverage for picos/expressions/exp_oprelentr.py: 86.67%
150 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-12 07:53 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-12 07:53 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2024 Kerry He
3#
4# This file is part of PICOS.
5#
6# PICOS is free software: you can redistribute it and/or modify it under the
7# terms of the GNU General Public License as published by the Free Software
8# Foundation, either version 3 of the License, or (at your option) any later
9# version.
10#
11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License along with
16# this program. If not, see <http://www.gnu.org/licenses/>.
17# ------------------------------------------------------------------------------
19"""Implements :class:`OperatorRelativeEntropy`."""
21import operator
22from collections import namedtuple
24import cvxopt
25import numpy
27from .. import glyphs
28from ..apidoc import api_end, api_start
29from ..caching import cached_unary_operator
30from ..constraints import (
31 OpRelEntropyConstraint,
32 ComplexOpRelEntropyConstraint,
33 TrOpRelEntropyConstraint,
34 ComplexTrOpRelEntropyConstraint,
35)
36from .data import convert_and_refine_arguments, convert_operands, cvx2np
37from .exp_affine import AffineExpression, ComplexAffineExpression
38from .expression import Expression, refine_operands, validate_prediction
# Marker passed to ``api_end`` at the bottom of this module, where it is used
# to compute ``__all__``. Presumably it records the globals that exist before
# this point, so that only the definitions below become public API
# (see picos.apidoc).
_API_START = api_start(globals())
# -------------------------------
class OperatorRelativeEntropy(Expression):
    r"""Operator relative entropy of two affine expressions.

    :Definition:

    For :math:`n \times n`-dimensional symmetric or Hermitian matrices
    :math:`X` and :math:`Y`, this is defined as

    .. math::

        X^{1/2} \log(X^{1/2}Y^{-1}X^{1/2}) X^{1/2}.

    .. warning::

        When you pose an upper bound on this expression, then PICOS enforces
        :math:`X \succeq 0` and :math:`Y \succeq 0` through an auxiliary
        constraint during solution search.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    @convert_and_refine_arguments("X", "Y")
    def __init__(self, X, Y):
        """Construct an :class:`OperatorRelativeEntropy`.

        :param X: The affine expression :math:`X`.
        :type X: ~picos.expressions.AffineExpression
        :param Y: The affine expression :math:`Y`. This should have the same
            dimensions as :math:`X`.
        :type Y: ~picos.expressions.AffineExpression

        :raises TypeError: If either argument is not a real or complex affine
            expression, is not symmetric/Hermitian, or if the shapes of
            :math:`X` and :math:`Y` differ.
        """
        if not isinstance(X, ComplexAffineExpression):
            raise TypeError(
                "Can only take the matrix logarithm of a real "
                "or complex affine expression, not of {}.".format(
                    type(X).__name__
                )
            )
        if not X.hermitian:
            raise TypeError(
                "Can only take the matrix logarithm of a symmetric "
                "or Hermitian expression, not of {}.".format(type(X).__name__)
            )

        if not isinstance(Y, ComplexAffineExpression):
            raise TypeError(
                "The additional parameter Y must be a real "
                "or complex affine expression, not {}.".format(type(Y).__name__)
            )
        if not Y.hermitian:
            raise TypeError(
                "Can only take the matrix logarithm of a symmetric "
                "or Hermitian expression, not of {}.".format(type(Y).__name__)
            )
        if X.shape != Y.shape:
            # Report the offending shapes, not the type name.
            raise TypeError(
                "The additional parameter Y must be the same shape "
                "as X, which is {}, not {}.".format(X.shape, Y.shape)
            )

        self._X = X
        self._Y = Y

        # The expression is complex as soon as either argument is not a
        # (real) AffineExpression.
        self._iscomplex = not (
            isinstance(X, AffineExpression) and isinstance(Y, AffineExpression)
        )

        typeStr = "Operator Relative Entropy"
        rtxStr = glyphs.power(X.string, "(1/2)")
        invyStr = glyphs.inverse(Y.string)
        xyxStr = glyphs.mul(rtxStr, glyphs.mul(invyStr, rtxStr))
        symbStr = glyphs.mul(rtxStr, glyphs.mul(glyphs.log(xyxStr), rtxStr))

        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_refined(self):
        """Fold into a constant affine expression if X and Y are constant."""
        if not (self._X.constant and self._Y.constant):
            return self

        value = self._get_value()

        # Keep the expression's full shape, since the value is a matrix here
        # and not a scalar. Use the complex container when the evaluated
        # value is complex, because a real AffineExpression cannot hold it.
        if value.typecode == "z":
            return ComplexAffineExpression.from_constant(
                value, self._get_shape(), self._symbStr
            )
        return AffineExpression.from_constant(
            value, self._get_shape(), self._symbStr
        )

    Subtype = namedtuple("Subtype", ("argdim", "iscomplex"))

    def _get_subtype(self):
        return self.Subtype(len(self._X), self._iscomplex)

    def _get_value(self):
        """Evaluate X^(1/2) log(X^(1/2) Y^-1 X^(1/2)) X^(1/2) numerically.

        ``numpy.sqrt``/``numpy.log`` produce NaN (or -inf) for negative (or
        zero) eigenvalues, so the result is only meaningful for positive
        definite arguments.
        """
        X = cvx2np(self._X._get_value())
        Y = cvx2np(self._Y._get_value())

        # Square root of the Hermitian X through its eigendecomposition.
        # Scaling the columns of Ux equals multiplying by diag(sqrt(Dx)).
        Dx, Ux = numpy.linalg.eigh(X)
        rtX = (Ux * numpy.sqrt(Dx)) @ Ux.conj().T
        invY = numpy.linalg.inv(Y)

        # Matrix logarithm of the Hermitian matrix X^(1/2) Y^-1 X^(1/2).
        XYX = rtX @ invY @ rtX
        Dxyx, Uxyx = numpy.linalg.eigh(XYX)
        logXYX = (Uxyx * numpy.log(Dxyx)) @ Uxyx.conj().T

        return cvxopt.matrix(rtX @ logXYX @ rtX)

    def _get_shape(self):
        return self._X.shape

    @cached_unary_operator
    def _get_mutables(self):
        return self._X._get_mutables().union(self._Y._get_mutables())

    def _is_convex(self):
        return True

    def _is_concave(self):
        return False

    def _replace_mutables(self, mapping):
        return self.__class__(
            self._X._replace_mutables(mapping),
            self._Y._replace_mutables(mapping),
        )

    def _freeze_mutables(self, freeze):
        return self.__class__(
            self._X._freeze_mutables(freeze), self._Y._freeze_mutables(freeze)
        )

    # --------------------------------------------------------------------------
    # Methods and properties that return expressions.
    # --------------------------------------------------------------------------

    @property
    def X(self):
        """The expression :math:`X`."""
        return self._X

    @property
    def Y(self):
        """The additional expression :math:`Y`."""
        return self._Y

    # --------------------------------------------------------------------------
    # Methods and properties that describe the expression.
    # --------------------------------------------------------------------------

    @property
    def n(self):
        """Lengths of :attr:`X` and :attr:`Y`."""
        return len(self._X)

    @property
    def iscomplex(self):
        """Whether :attr:`X` and :attr:`Y` are complex expressions or not."""
        return self._iscomplex

    @property
    def tr(self):
        """Trace of the operator relative entropy."""
        return TrOperatorRelativeEntropy(self.X, self.Y)

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        assert isinstance(subtype, cls.Subtype)

        if relation == operator.__lshift__:
            if (
                issubclass(other.clstype, ComplexAffineExpression)
                and other.subtype.dim == subtype.argdim
            ):
                if subtype.iscomplex or not issubclass(
                    other.clstype, AffineExpression
                ):
                    return ComplexOpRelEntropyConstraint.make_type(
                        argdim=subtype.argdim
                    )
                else:
                    return OpRelEntropyConstraint.make_type(
                        argdim=subtype.argdim
                    )
        return NotImplemented

    def _lshift_implementation(self, other):
        # A complex expression or a complex upper bound needs the complex
        # constraint; only the all-real case uses the real one.
        if isinstance(other, ComplexAffineExpression):
            if self.iscomplex or not isinstance(other, AffineExpression):
                return ComplexOpRelEntropyConstraint(self, other)
            else:
                return OpRelEntropyConstraint(self, other)
        else:
            return NotImplemented
class TrOperatorRelativeEntropy(OperatorRelativeEntropy):
    r"""Trace operator relative entropy of two affine expressions.

    :Definition:

    For :math:`n \times n`-dimensional symmetric or Hermitian matrices
    :math:`X` and :math:`Y`, this is defined as

    .. math::

        \operatorname{Tr}(X^{1/2} \log(X^{1/2}Y^{-1}X^{1/2}) X^{1/2}).

    .. warning::

        When you pose an upper bound on this expression, then PICOS enforces
        :math:`X \succeq 0` and :math:`Y \succeq 0` through an auxiliary
        constraint during solution search.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    @convert_and_refine_arguments("X", "Y")
    def __init__(self, X, Y):
        """Construct a :class:`TrOperatorRelativeEntropy`.

        :param X: The affine expression :math:`X`.
        :type X: ~picos.expressions.AffineExpression
        :param Y: The affine expression :math:`Y`. This should have the same
            dimensions as :math:`X`.
        :type Y: ~picos.expressions.AffineExpression

        :raises TypeError: If either argument is not a real or complex affine
            expression, is not symmetric/Hermitian, or if the shapes of
            :math:`X` and :math:`Y` differ.
        """
        if not isinstance(X, ComplexAffineExpression):
            raise TypeError(
                "Can only take the matrix logarithm of a real "
                "or complex affine expression, not of {}.".format(
                    type(X).__name__
                )
            )
        if not X.hermitian:
            raise TypeError(
                "Can only take the matrix logarithm of a symmetric "
                "or Hermitian expression, not of {}.".format(type(X).__name__)
            )

        if not isinstance(Y, ComplexAffineExpression):
            raise TypeError(
                "The additional parameter Y must be a real "
                "or complex affine expression, not {}.".format(type(Y).__name__)
            )
        if not Y.hermitian:
            raise TypeError(
                "Can only take the matrix logarithm of a symmetric "
                "or Hermitian expression, not of {}.".format(type(Y).__name__)
            )
        if X.shape != Y.shape:
            # Report the offending shapes, not the type name.
            raise TypeError(
                "The additional parameter Y must be the same shape "
                "as X, which is {}, not {}.".format(X.shape, Y.shape)
            )

        self._X = X
        self._Y = Y

        # The expression is complex as soon as either argument is not a
        # (real) AffineExpression.
        self._iscomplex = not (
            isinstance(X, AffineExpression) and isinstance(Y, AffineExpression)
        )

        typeStr = "Trace Operator Relative Entropy"
        rtxStr = glyphs.power(X.string, "(1/2)")
        invyStr = glyphs.inverse(Y.string)
        xyxStr = glyphs.mul(rtxStr, glyphs.mul(invyStr, rtxStr))
        oprelStr = glyphs.mul(rtxStr, glyphs.mul(glyphs.log(xyxStr), rtxStr))
        symbStr = glyphs.trace(oprelStr)

        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_value(self):
        """Evaluate Tr(X^(1/2) log(X^(1/2) Y^-1 X^(1/2)) X^(1/2)) numerically.

        ``numpy.sqrt``/``numpy.log`` produce NaN (or -inf) for negative (or
        zero) eigenvalues, so the result is only meaningful for positive
        definite arguments.
        """
        X = cvx2np(self._X._get_value())
        Y = cvx2np(self._Y._get_value())

        # Square root of the Hermitian X through its eigendecomposition.
        # Scaling the columns of Ux equals multiplying by diag(sqrt(Dx)).
        Dx, Ux = numpy.linalg.eigh(X)
        rtX = (Ux * numpy.sqrt(Dx)) @ Ux.conj().T
        invY = numpy.linalg.inv(Y)

        # Matrix logarithm L of the Hermitian matrix X^(1/2) Y^-1 X^(1/2).
        XYX = rtX @ invY @ rtX
        Dxyx, Uxyx = numpy.linalg.eigh(XYX)
        logXYX = (Uxyx * numpy.log(Dxyx)) @ Uxyx.conj().T

        # By cyclicity, Tr(X^(1/2) L X^(1/2)) = Tr(X L). For a Hermitian L,
        # Tr(X L) = sum_ij X_ij conj(L_ij), which avoids two matrix products.
        s = numpy.sum(X * logXYX.conj()).real

        return cvxopt.matrix(s)

    def _get_shape(self):
        # The trace is a scalar.
        return (1, 1)

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        assert isinstance(subtype, cls.Subtype)

        # Only an upper bound by a real scalar affine expression is supported.
        if relation == operator.__le__:
            if (
                issubclass(other.clstype, AffineExpression)
                and other.subtype.dim == 1
            ):
                if subtype.iscomplex:
                    return ComplexTrOpRelEntropyConstraint.make_type(
                        argdim=subtype.argdim
                    )
                else:
                    return TrOpRelEntropyConstraint.make_type(
                        argdim=subtype.argdim
                    )
        return NotImplemented

    def _lshift_implementation(self, other):
        # The parent builds a matrix OpRelEntropyConstraint for ``<<``. That
        # is wrong for this scalar expression and contradicts _predict above,
        # which rejects ``<<``. Decline it so the two agree.
        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __le__(self, other):
        """Upper-bound the trace by a real scalar affine expression."""
        if isinstance(other, AffineExpression):
            if self.iscomplex:
                return ComplexTrOpRelEntropyConstraint(self, other)
            else:
                return TrOpRelEntropyConstraint(self, other)
        else:
            return NotImplemented
# --------------------------------------
# Public API of this module, computed by ``api_end`` from the current globals
# and the ``_API_START`` marker recorded before the definitions above.
# Presumably this exports the names defined in between (see picos.apidoc).
__all__ = api_end(_API_START, globals())