Coverage for picos/expressions/exp_oprelentr.py: 86.67%

150 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-12 07:53 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2024 Kerry He 

3# 

4# This file is part of PICOS. 

5# 

6# PICOS is free software: you can redistribute it and/or modify it under the 

7# terms of the GNU General Public License as published by the Free Software 

8# Foundation, either version 3 of the License, or (at your option) any later 

9# version. 

10# 

11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

13# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

14# 

15# You should have received a copy of the GNU General Public License along with 

16# this program. If not, see <http://www.gnu.org/licenses/>. 

17# ------------------------------------------------------------------------------ 

18 

"""Implements :class:`OperatorRelativeEntropy` and its trace variant.

The trace variant is :class:`TrOperatorRelativeEntropy`.
"""

20 

21import operator 

22from collections import namedtuple 

23 

24import cvxopt 

25import numpy 

26 

27from .. import glyphs 

28from ..apidoc import api_end, api_start 

29from ..caching import cached_unary_operator 

30from ..constraints import ( 

31 OpRelEntropyConstraint, 

32 ComplexOpRelEntropyConstraint, 

33 TrOpRelEntropyConstraint, 

34 ComplexTrOpRelEntropyConstraint, 

35) 

36from .data import convert_and_refine_arguments, convert_operands, cvx2np 

37from .exp_affine import AffineExpression, ComplexAffineExpression 

38from .expression import Expression, refine_operands, validate_prediction 

39 

# Mark the start of this module's public API section. It is paired with the
# api_end call at the bottom of the module, whose result is assigned to
# __all__.
# NOTE(review): presumably api_start snapshots the names already present in
# globals() so that api_end collects only definitions made in between -
# confirm against picos.apidoc.
_API_START = api_start(globals())
# -------------------------------

42 

43 

class OperatorRelativeEntropy(Expression):
    r"""Operator relative entropy of an affine expression.

    :Definition:

    For :math:`n \times n`-dimensional symmetric or Hermitian matrices
    :math:`X` and :math:`Y`, this is defined as

    .. math::

        X^{1/2} \log(X^{1/2}Y^{-1}X^{1/2}) X^{1/2}.

    .. warning::

        When you pose an upper bound on this expression, then PICOS enforces
        :math:`X \succeq 0` and :math:`Y \succeq 0` through an auxiliary
        constraint during solution search.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    @convert_and_refine_arguments("X", "Y")
    def __init__(self, X, Y):
        """Construct an :class:`OperatorRelativeEntropy`.

        :param X: The affine expression :math:`X`.
        :type X: ~picos.expressions.AffineExpression
        :param Y: The affine expression :math:`Y`. This should have the same
            dimensions as :math:`X`.
        :type Y: ~picos.expressions.AffineExpression

        :raises TypeError: If :math:`X` or :math:`Y` is not a real or complex
            affine expression, is not symmetric/Hermitian, or if the two
            shapes differ.
        """
        if not isinstance(X, ComplexAffineExpression):
            raise TypeError(
                "Can only take the matrix logarithm of a real "
                "or complex affine expression, not of {}.".format(
                    type(X).__name__
                )
            )
        if not X.hermitian:
            raise TypeError(
                "Can only take the matrix logarithm of a symmetric "
                "or Hermitian expression, not of {}.".format(type(X).__name__)
            )

        if not isinstance(Y, ComplexAffineExpression):
            raise TypeError(
                "The additional parameter Y must be a real "
                "or complex affine expression, not {}.".format(type(Y).__name__)
            )
        if not Y.hermitian:
            raise TypeError(
                "Can only take the matrix logarithm of a symmetric "
                "or Hermitian expression, not of {}.".format(type(Y).__name__)
            )
        if X.shape != Y.shape:
            # Report the offending shapes (the previous message glued
            # "shape" and "as" together and printed a type name instead).
            raise TypeError(
                "The additional parameter Y must be the same shape as X "
                "({}), not {}.".format(X.shape, Y.shape)
            )

        self._X = X
        self._Y = Y

        # AffineExpression is the real specialization, so the expression is
        # complex as soon as either argument is not a real AffineExpression.
        self._iscomplex = not (
            isinstance(X, AffineExpression) and isinstance(Y, AffineExpression)
        )

        typeStr = "Operator Relative Entropy"
        rtxStr = glyphs.power(X.string, "(1/2)")
        invyStr = glyphs.inverse(Y.string)
        xyxStr = glyphs.mul(rtxStr, glyphs.mul(invyStr, rtxStr))
        symbStr = glyphs.mul(rtxStr, glyphs.mul(glyphs.log(xyxStr), rtxStr))

        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_refined(self):
        """Collapse to a constant affine expression if X and Y are constant.

        The constant keeps this expression's shape: an n-by-n matrix here,
        and a scalar for subclasses whose ``_get_shape`` returns ``(1, 1)``.
        It is complex exactly when the computed value is complex.
        """
        if not (self._X.constant and self._Y.constant):
            return self

        value = self._get_value()

        # A complex value (from complex X or Y) cannot be stored in the real
        # AffineExpression. A real value keeps the real type, so constant
        # scalars remain usable on the left of <=.
        if value.typecode == "z":
            constant_type = ComplexAffineExpression
        else:
            constant_type = AffineExpression

        return constant_type.from_constant(
            value, self._get_shape(), self._symbStr
        )

    Subtype = namedtuple("Subtype", ("argdim", "iscomplex"))

    def _get_subtype(self):
        return self.Subtype(len(self._X), self._iscomplex)

    def _get_value(self):
        """Numerically evaluate the operator relative entropy.

        Matrix square root and logarithm are formed through Hermitian
        eigendecompositions (``numpy.linalg.eigh``). Negative eigenvalues of
        X (or non-positive ones of the middle term) produce NaN/inf entries.
        """
        X = cvx2np(self._X._get_value())
        Y = cvx2np(self._Y._get_value())

        # X^{1/2} via the eigendecomposition X = Ux diag(Dx) Ux^H.
        Dx, Ux = numpy.linalg.eigh(X)
        rtX = Ux @ numpy.diag(numpy.sqrt(Dx)) @ Ux.conj().T
        invY = numpy.linalg.inv(Y)

        # log(X^{1/2} Y^{-1} X^{1/2}) via its own eigendecomposition.
        XYX = rtX @ invY @ rtX
        Dxyx, Uxyx = numpy.linalg.eigh(XYX)
        logXYX = Uxyx @ numpy.diag(numpy.log(Dxyx)) @ Uxyx.conj().T

        S = rtX @ logXYX @ rtX

        return cvxopt.matrix(S)

    def _get_shape(self):
        return self._X.shape

    @cached_unary_operator
    def _get_mutables(self):
        return self._X.mutables.union(self._Y.mutables)

    def _is_convex(self):
        return True

    def _is_concave(self):
        return False

    def _replace_mutables(self, mapping):
        return self.__class__(
            self._X._replace_mutables(mapping),
            self._Y._replace_mutables(mapping),
        )

    def _freeze_mutables(self, freeze):
        return self.__class__(
            self._X._freeze_mutables(freeze), self._Y._freeze_mutables(freeze)
        )

    # --------------------------------------------------------------------------
    # Methods and properties that return expressions.
    # --------------------------------------------------------------------------

    @property
    def X(self):
        """The expression :math:`X`."""
        return self._X

    @property
    def Y(self):
        """The additional expression :math:`Y`."""
        return self._Y

    # --------------------------------------------------------------------------
    # Methods and properties that describe the expression.
    # --------------------------------------------------------------------------

    @property
    def n(self):
        """Lengths of :attr:`X` and :attr:`Y`."""
        return len(self._X)

    @property
    def iscomplex(self):
        """Whether :attr:`X` and :attr:`Y` are complex expressions or not."""
        return self._iscomplex

    @property
    def tr(self):
        """Trace of the operator relative entropy."""
        return TrOperatorRelativeEntropy(self.X, self.Y)

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        assert isinstance(subtype, cls.Subtype)

        # Only the matrix upper bound (<<) against an affine expression with
        # matching number of elements is supported.
        if relation == operator.__lshift__:
            if (
                issubclass(other.clstype, ComplexAffineExpression)
                and other.subtype.dim == subtype.argdim
            ):
                if subtype.iscomplex or not issubclass(
                    other.clstype, AffineExpression
                ):
                    return ComplexOpRelEntropyConstraint.make_type(
                        argdim=subtype.argdim
                    )
                else:
                    return OpRelEntropyConstraint.make_type(
                        argdim=subtype.argdim
                    )
        return NotImplemented

    def _lshift_implementation(self, other):
        if isinstance(other, ComplexAffineExpression):
            if self.iscomplex or not isinstance(other, AffineExpression):
                return ComplexOpRelEntropyConstraint(self, other)
            else:
                return OpRelEntropyConstraint(self, other)
        else:
            return NotImplemented

241 

242 

class TrOperatorRelativeEntropy(OperatorRelativeEntropy):
    r"""Trace operator relative entropy of an affine expression.

    :Definition:

    For :math:`n \times n`-dimensional symmetric or Hermitian matrices
    :math:`X` and :math:`Y`, this is defined as

    .. math::

        \operatorname{Tr}(X^{1/2} \log(X^{1/2}Y^{-1}X^{1/2}) X^{1/2}).

    .. warning::

        When you pose an upper bound on this expression, then PICOS enforces
        :math:`X \succeq 0` and :math:`Y \succeq 0` through an auxiliary
        constraint during solution search.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    @convert_and_refine_arguments("X", "Y")
    def __init__(self, X, Y):
        """Construct a :class:`TrOperatorRelativeEntropy`.

        :param X: The affine expression :math:`X`.
        :type X: ~picos.expressions.AffineExpression
        :param Y: The affine expression :math:`Y`. This should have the same
            dimensions as :math:`X`.
        :type Y: ~picos.expressions.AffineExpression

        :raises TypeError: If :math:`X` or :math:`Y` is not a real or complex
            affine expression, is not symmetric/Hermitian, or if the two
            shapes differ.
        """
        if not isinstance(X, ComplexAffineExpression):
            raise TypeError(
                "Can only take the matrix logarithm of a real "
                "or complex affine expression, not of {}.".format(
                    type(X).__name__
                )
            )
        if not X.hermitian:
            raise TypeError(
                "Can only take the matrix logarithm of a symmetric "
                "or Hermitian expression, not of {}.".format(type(X).__name__)
            )

        if not isinstance(Y, ComplexAffineExpression):
            raise TypeError(
                "The additional parameter Y must be a real "
                "or complex affine expression, not {}.".format(type(Y).__name__)
            )
        if not Y.hermitian:
            raise TypeError(
                "Can only take the matrix logarithm of a symmetric "
                "or Hermitian expression, not of {}.".format(type(Y).__name__)
            )
        if X.shape != Y.shape:
            # Report the offending shapes (the previous message glued
            # "shape" and "as" together and printed a type name instead).
            raise TypeError(
                "The additional parameter Y must be the same shape as X "
                "({}), not {}.".format(X.shape, Y.shape)
            )

        self._X = X
        self._Y = Y

        # AffineExpression is the real specialization, so the expression is
        # complex as soon as either argument is not a real AffineExpression.
        self._iscomplex = not (
            isinstance(X, AffineExpression) and isinstance(Y, AffineExpression)
        )

        typeStr = "Trace Operator Relative Entropy"
        rtxStr = glyphs.power(X.string, "(1/2)")
        invyStr = glyphs.inverse(Y.string)
        xyxStr = glyphs.mul(rtxStr, glyphs.mul(invyStr, rtxStr))
        oprelStr = glyphs.mul(rtxStr, glyphs.mul(glyphs.log(xyxStr), rtxStr))
        symbStr = glyphs.trace(oprelStr)

        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_value(self):
        """Numerically evaluate the trace operator relative entropy.

        Uses the cyclic property of the trace:
        Tr(X^{1/2} L X^{1/2}) = Tr(X L) with L = log(X^{1/2} Y^{-1} X^{1/2}).
        Because L is Hermitian, Tr(X L) equals the elementwise sum of
        X * conj(L). Only the real part is kept, which discards round-off
        imaginary parts.
        """
        X = cvx2np(self._X._get_value())
        Y = cvx2np(self._Y._get_value())

        # X^{1/2} via the eigendecomposition X = Ux diag(Dx) Ux^H.
        Dx, Ux = numpy.linalg.eigh(X)
        rtX = Ux @ numpy.diag(numpy.sqrt(Dx)) @ Ux.conj().T
        invY = numpy.linalg.inv(Y)

        # log(X^{1/2} Y^{-1} X^{1/2}) via its own eigendecomposition.
        XYX = rtX @ invY @ rtX
        Dxyx, Uxyx = numpy.linalg.eigh(XYX)
        logXYX = Uxyx @ numpy.diag(numpy.log(Dxyx)) @ Uxyx.conj().T

        s = numpy.sum(X * logXYX.conj()).real

        return cvxopt.matrix(s)

    def _get_shape(self):
        return (1, 1)

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        assert isinstance(subtype, cls.Subtype)

        # Only a scalar upper bound (<=) by a real affine expression is
        # supported.
        if relation == operator.__le__:
            if (
                issubclass(other.clstype, AffineExpression)
                and other.subtype.dim == 1
            ):
                if subtype.iscomplex:
                    return ComplexTrOpRelEntropyConstraint.make_type(
                        argdim=subtype.argdim
                    )
                else:
                    return TrOpRelEntropyConstraint.make_type(
                        argdim=subtype.argdim
                    )
        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __le__(self, other):
        if isinstance(other, AffineExpression):
            if self.iscomplex:
                return ComplexTrOpRelEntropyConstraint(self, other)
            else:
                return TrOpRelEntropyConstraint(self, other)
        else:
            return NotImplemented

377 

378 

# --------------------------------------
# Close the public-API section opened by _API_START near the top of the
# module. The names returned by api_end become this module's __all__.
# NOTE(review): presumably these are the classes defined between the two
# markers - confirm against picos.apidoc.
__all__ = api_end(_API_START, globals())