Coverage for picos/expressions/exp_mtxgeomean.py: 88.52%

183 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-12 07:53 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2024 Kerry He 

3# 

4# This file is part of PICOS. 

5# 

6# PICOS is free software: you can redistribute it and/or modify it under the 

7# terms of the GNU General Public License as published by the Free Software 

8# Foundation, either version 3 of the License, or (at your option) any later 

9# version. 

10# 

11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

13# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

14# 

15# You should have received a copy of the GNU General Public License along with 

16# this program. If not, see <http://www.gnu.org/licenses/>. 

17# ------------------------------------------------------------------------------ 

18 

19"""Implements :class:`MatrixGeometricMean`.""" 

20 

21import operator 

22from collections import namedtuple 

23 

24import cvxopt 

25import numpy 

26 

27from .. import glyphs 

28from ..apidoc import api_end, api_start 

29from ..caching import cached_unary_operator 

30from ..constraints import ( 

31 MatrixGeoMeanEpiConstraint, 

32 ComplexMatrixGeoMeanEpiConstraint, 

33 MatrixGeoMeanHypoConstraint, 

34 ComplexMatrixGeoMeanHypoConstraint, 

35 TrMatrixGeoMeanEpiConstraint, 

36 ComplexTrMatrixGeoMeanEpiConstraint, 

37 TrMatrixGeoMeanHypoConstraint, 

38 ComplexTrMatrixGeoMeanHypoConstraint, 

39) 

40from .data import convert_and_refine_arguments, convert_operands, cvx2np 

41from .exp_affine import AffineExpression, ComplexAffineExpression 

42from .expression import Expression, refine_operands, validate_prediction 

43 

# NOTE(review): presumably records the module namespace as it is before the
# public definitions below, for use by the matching api_end() call at the end
# of the file -- confirm against the apidoc module.
44_API_START = api_start(globals()) 

45# ------------------------------- 

46 

47 

class MatrixGeometricMean(Expression):
    r"""Matrix geometric mean of an affine expression.

    :Definition:

    For :math:`n \times n`-dimensional symmetric or Hermitian matrices
    :math:`X` and :math:`Y`, this is defined as

    .. math::

        X^{1/2} (X^{-1/2} Y X^{-1/2})^p X^{1/2}

    for a given scalar :math:`p\in[-1, 2]`, where :math:`p=1/2` by default.

    .. warning::

        When you pose an upper or lower bound on this expression, then PICOS
        enforces :math:`X \succeq 0` and :math:`Y \succeq 0` through an
        auxiliary constraint during solution search.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    @convert_and_refine_arguments("X", "Y")
    def __init__(self, X, Y, power=0.5):
        """Construct a :class:`MatrixGeometricMean`.

        :param X: The affine expression :math:`X`.
        :type X: ~picos.expressions.AffineExpression
        :param Y: The affine expression :math:`Y`. This should have the same
            dimensions as :math:`X`.
        :type Y: ~picos.expressions.AffineExpression
        :param power: The scalar exponent :math:`p`, which must lie in the
            interval :math:`[-1, 2]`.
        :type power: float

        :raises TypeError: If :math:`X` or :math:`Y` is not a real or complex
            affine expression, is not symmetric/Hermitian, if their shapes
            differ, or if ``power`` is not a scalar in :math:`[-1, 2]`.
        """
        if not isinstance(X, ComplexAffineExpression):
            raise TypeError(
                "Can only take the matrix powers of a real "
                "or complex affine expression, not of {}.".format(
                    type(X).__name__
                )
            )
        if not X.hermitian:
            raise TypeError(
                "Can only take the matrix powers of a symmetric "
                "or Hermitian expression, not of {}.".format(type(X).__name__)
            )

        if not isinstance(Y, ComplexAffineExpression):
            raise TypeError(
                "The additional parameter Y must be a real "
                "or complex affine expression, not {}.".format(type(Y).__name__)
            )
        if not Y.hermitian:
            raise TypeError(
                "Can only take the matrix powers of a symmetric "
                "or Hermitian expression, not of {}.".format(type(Y).__name__)
            )
        if X.shape != Y.shape:
            # Report the offending shapes; the type name says nothing about
            # why the operands are incompatible.
            raise TypeError(
                "The additional parameter Y must be the same shape "
                "as X, but X has shape {} and Y has shape {}.".format(
                    X.shape, Y.shape
                )
            )

        if not (numpy.isscalar(power) and -1 <= power <= 2):
            raise TypeError("The exponent p must be a scalar between [-1, 2]")

        self._X = X
        self._Y = Y
        self._power = power

        # The expression counts as complex as soon as one operand is not a
        # real affine expression.
        self._iscomplex = not isinstance(X, AffineExpression) or \
            not isinstance(Y, AffineExpression)

        typeStr = "Matrix Geometric Mean"
        if power == 0.5:
            symbStr = glyphs.geomean(X.string, Y.string)
        else:
            symbStr = glyphs.wgeomean(X.string, str(power), Y.string)

        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_refined(self):
        """Return a constant expression if both operands are constant.

        Otherwise the expression itself is returned.
        """
        if self._X.constant and self._Y.constant:
            # Load the value with this expression's own shape: the value is a
            # full matrix here, so a hard-coded scalar shape does not fit it.
            # NOTE(review): for complex operands the value may be complex;
            # confirm that a real AffineExpression is the right target type.
            return AffineExpression.from_constant(
                self.value, self._get_shape(), self._symbStr
            )
        else:
            return self

    Subtype = namedtuple("Subtype", ("argdim", "power", "iscomplex"))

    def _get_subtype(self):
        """Return the subtype built from ``len(X)``, the power and complexity."""
        return self.Subtype(len(self._X), self._power, self._iscomplex)

    def _get_value(self):
        """Evaluate the matrix geometric mean numerically.

        Computes :math:`X^{1/2} (X^{-1/2} Y X^{-1/2})^p X^{1/2}` from the
        current values of the operands using two Hermitian
        eigendecompositions. Positive definiteness of the operands is not
        checked here.
        """
        X = cvx2np(self._X._get_value())
        Y = cvx2np(self._Y._get_value())

        # X^{1/2} and X^{-1/2} from the eigendecomposition of X.
        Dx, Ux = numpy.linalg.eigh(X)
        sqrtDx = numpy.sqrt(Dx)
        rtX = Ux @ numpy.diag(sqrtDx) @ Ux.conj().T
        irtX = Ux @ numpy.diag(numpy.reciprocal(sqrtDx)) @ Ux.conj().T

        # (X^{-1/2} Y X^{-1/2})^p from the eigendecomposition of the
        # congruence-transformed Y.
        XYX = irtX @ Y @ irtX
        Dxyx, Uxyx = numpy.linalg.eigh(XYX)
        XYX_p = Uxyx @ numpy.diag(numpy.power(Dxyx, self.power)) @ Uxyx.conj().T

        S = rtX @ XYX_p @ rtX

        return cvxopt.matrix(S)

    def _get_shape(self):
        """Return the shape of :math:`X`, which equals that of :math:`Y`."""
        return self._X.shape

    @cached_unary_operator
    def _get_mutables(self):
        """Return the union of the mutables of both operands."""
        return self._X._get_mutables().union(self._Y.mutables)

    def _is_convex(self):
        """Report convexity, which holds for powers in [-1, 0] and [1, 2]."""
        return -1 <= self._power <= 0 or 1 <= self._power <= 2

    def _is_concave(self):
        """Report concavity, which holds for powers in [0, 1]."""
        return 0 <= self._power <= 1

    def _replace_mutables(self, mapping):
        """Rebuild the expression with mutables replaced as per ``mapping``."""
        # Forward the power explicitly; otherwise the copy would silently
        # fall back to the default exponent of 1/2.
        return self.__class__(
            self._X._replace_mutables(mapping),
            self._Y._replace_mutables(mapping),
            self._power,
        )

    def _freeze_mutables(self, freeze):
        """Rebuild the expression with the mutables in ``freeze`` frozen."""
        # Forward the power explicitly; otherwise the copy would silently
        # fall back to the default exponent of 1/2.
        return self.__class__(
            self._X._freeze_mutables(freeze),
            self._Y._freeze_mutables(freeze),
            self._power,
        )

    # --------------------------------------------------------------------------
    # Methods and properties that return expressions.
    # --------------------------------------------------------------------------

    @property
    def X(self):
        """The expression :math:`X`."""
        return self._X

    @property
    def Y(self):
        """The additional expression :math:`Y`."""
        return self._Y

    @property
    def power(self):
        """The power :math:`p`."""
        return self._power

    @property
    def tr(self):
        """Trace of the matrix geometric mean."""
        return TrMatrixGeometricMean(self.X, self.Y, self.power)

    # --------------------------------------------------------------------------
    # Methods and properties that describe the expression.
    # --------------------------------------------------------------------------

    @property
    def n(self):
        """Lengths of :attr:`X` and :attr:`Y`."""
        return len(self._X)

    @property
    def iscomplex(self):
        """Whether :attr:`X` and :attr:`Y` are complex expressions or not."""
        return self._iscomplex

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        """Predict the constraint type produced by ``relation`` with ``other``.

        Mirrors :meth:`_lshift_implementation` and
        :meth:`_rshift_implementation`: ``<<`` gives an epigraph constraint
        for convex powers and ``>>`` gives a hypograph constraint for concave
        powers. The complex variant is chosen if either side is complex.
        Returns :obj:`NotImplemented` for every other combination.
        """
        assert isinstance(subtype, cls.Subtype)

        isconvex = -1 <= subtype.power <= 0 or 1 <= subtype.power <= 2
        isconcave = 0 <= subtype.power <= 1

        if relation == operator.__lshift__:
            if (
                isconvex
                and issubclass(other.clstype, ComplexAffineExpression)
                and other.subtype.dim == subtype.argdim
            ):
                if subtype.iscomplex or not issubclass(
                    other.clstype, AffineExpression
                ):
                    return ComplexMatrixGeoMeanEpiConstraint.make_type(
                        argdim=subtype.argdim
                    )
                else:
                    return MatrixGeoMeanEpiConstraint.make_type(
                        argdim=subtype.argdim
                    )

        if relation == operator.__rshift__:
            if (
                isconcave
                and issubclass(other.clstype, ComplexAffineExpression)
                and other.subtype.dim == subtype.argdim
            ):
                if subtype.iscomplex or not issubclass(
                    other.clstype, AffineExpression
                ):
                    return ComplexMatrixGeoMeanHypoConstraint.make_type(
                        argdim=subtype.argdim
                    )
                else:
                    return MatrixGeoMeanHypoConstraint.make_type(
                        argdim=subtype.argdim
                    )

        return NotImplemented

    def _lshift_implementation(self, other):
        """Create the epigraph constraint for ``self << other``.

        Only supported for convex powers and affine ``other``; otherwise
        :obj:`NotImplemented` is returned.
        """
        if self.convex and isinstance(other, ComplexAffineExpression):
            if self.iscomplex or not isinstance(other, AffineExpression):
                return ComplexMatrixGeoMeanEpiConstraint(self, other)
            else:
                return MatrixGeoMeanEpiConstraint(self, other)
        else:
            return NotImplemented

    def _rshift_implementation(self, other):
        """Create the hypograph constraint for ``self >> other``.

        Only supported for concave powers and affine ``other``; otherwise
        :obj:`NotImplemented` is returned.
        """
        if self.concave and isinstance(other, ComplexAffineExpression):
            if self.iscomplex or not isinstance(other, AffineExpression):
                return ComplexMatrixGeoMeanHypoConstraint(self, other)
            else:
                return MatrixGeoMeanHypoConstraint(self, other)
        else:
            return NotImplemented

289 

290 

class TrMatrixGeometricMean(MatrixGeometricMean):
    r"""Trace matrix geometric mean of an affine expression.

    :Definition:

    For :math:`n \times n`-dimensional symmetric or Hermitian matrices
    :math:`X` and :math:`Y`, this is defined as

    .. math::

        \operatorname{Tr}(X^{1/2} (X^{-1/2} Y X^{-1/2})^p X^{1/2})

    for a given scalar :math:`p\in[-1, 2]`, where :math:`p=1/2` by default.

    .. warning::

        When you pose an upper or lower bound on this expression, then PICOS
        enforces :math:`X \succeq 0` and :math:`Y \succeq 0` through an
        auxiliary constraint during solution search.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    @convert_and_refine_arguments("X", "Y")
    def __init__(self, X, Y, power=0.5):
        """Construct a :class:`TrMatrixGeometricMean`.

        :param X: The affine expression :math:`X`.
        :type X: ~picos.expressions.AffineExpression
        :param Y: The affine expression :math:`Y`. This should have the same
            dimensions as :math:`X`.
        :type Y: ~picos.expressions.AffineExpression
        :param power: The scalar exponent :math:`p`, which must lie in the
            interval :math:`[-1, 2]`.
        :type power: float

        :raises TypeError: If :math:`X` or :math:`Y` is not a real or complex
            affine expression, is not symmetric/Hermitian, if their shapes
            differ, or if ``power`` is not a scalar in :math:`[-1, 2]`.
        """
        if not isinstance(X, ComplexAffineExpression):
            raise TypeError(
                "Can only take the matrix powers of a real "
                "or complex affine expression, not of {}.".format(
                    type(X).__name__
                )
            )
        if not X.hermitian:
            raise TypeError(
                "Can only take the matrix powers of a symmetric "
                "or Hermitian expression, not of {}.".format(type(X).__name__)
            )

        if not isinstance(Y, ComplexAffineExpression):
            raise TypeError(
                "The additional parameter Y must be a real "
                "or complex affine expression, not {}.".format(type(Y).__name__)
            )
        if not Y.hermitian:
            raise TypeError(
                "Can only take the matrix powers of a symmetric "
                "or Hermitian expression, not of {}.".format(type(Y).__name__)
            )
        if X.shape != Y.shape:
            # Report the offending shapes; the type name says nothing about
            # why the operands are incompatible.
            raise TypeError(
                "The additional parameter Y must be the same shape "
                "as X, but X has shape {} and Y has shape {}.".format(
                    X.shape, Y.shape
                )
            )

        # Same exponent check as in the parent class, so that an invalid
        # power cannot slip in through the trace variant.
        if not (numpy.isscalar(power) and -1 <= power <= 2):
            raise TypeError("The exponent p must be a scalar between [-1, 2]")

        self._X = X
        self._Y = Y
        self._power = power

        # The expression counts as complex as soon as one operand is not a
        # real affine expression.
        self._iscomplex = not isinstance(X, AffineExpression) or \
            not isinstance(Y, AffineExpression)

        typeStr = "Trace Matrix Geometric Mean"
        if power == 0.5:
            symbStr = glyphs.trace(glyphs.geomean(X.string, Y.string))
        else:
            pStr = str(power)
            symbStr = glyphs.trace(glyphs.wgeomean(X.string, pStr, Y.string))

        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_value(self):
        """Evaluate the trace of the matrix geometric mean numerically.

        Positive definiteness of the operands is not checked here.
        """
        X = cvx2np(self._X._get_value())
        Y = cvx2np(self._Y._get_value())

        # X^{-1/2} from the eigendecomposition of X.
        Dx, Ux = numpy.linalg.eigh(X)
        irtX = Ux @ numpy.diag(numpy.reciprocal(numpy.sqrt(Dx))) @ Ux.conj().T

        # (X^{-1/2} Y X^{-1/2})^p from a second eigendecomposition.
        XYX = irtX @ Y @ irtX
        Dxyx, Uxyx = numpy.linalg.eigh(XYX)
        XYX_p = Uxyx @ numpy.diag(numpy.power(Dxyx, self.power)) @ Uxyx.conj().T

        # The entrywise sum of X * conj(P) is the inner product of X with P,
        # which avoids forming X^{1/2} P X^{1/2} just to take its trace.
        s = numpy.sum(X * XYX_p.conj()).real

        return cvxopt.matrix(s)

    def _get_shape(self):
        """Return the scalar shape ``(1, 1)``."""
        return (1, 1)

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        """Predict the constraint type produced by ``relation`` with ``other``.

        Mirrors :meth:`__le__` and :meth:`__ge__`: ``<=`` gives an epigraph
        constraint for convex powers and ``>=`` gives a hypograph constraint
        for concave powers, in both cases only against a real scalar affine
        expression. Returns :obj:`NotImplemented` otherwise.
        """
        assert isinstance(subtype, cls.Subtype)

        isconvex = -1 <= subtype.power <= 0 or 1 <= subtype.power <= 2
        isconcave = 0 <= subtype.power <= 1

        if relation == operator.__le__:
            if (
                isconvex
                and issubclass(other.clstype, AffineExpression)
                and other.subtype.dim == 1
            ):
                # The other side is known to be real at this point, so only
                # the operands decide between the real and complex variant.
                if subtype.iscomplex:
                    return ComplexTrMatrixGeoMeanEpiConstraint.make_type(
                        argdim=subtype.argdim
                    )
                else:
                    return TrMatrixGeoMeanEpiConstraint.make_type(
                        argdim=subtype.argdim
                    )

        if relation == operator.__ge__:
            if (
                isconcave
                and issubclass(other.clstype, AffineExpression)
                and other.subtype.dim == 1
            ):
                # The other side is known to be real at this point, so only
                # the operands decide between the real and complex variant.
                if subtype.iscomplex:
                    return ComplexTrMatrixGeoMeanHypoConstraint.make_type(
                        argdim=subtype.argdim
                    )
                else:
                    return TrMatrixGeoMeanHypoConstraint.make_type(
                        argdim=subtype.argdim
                    )

        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __le__(self, other):
        """Create the epigraph constraint for ``self <= other``.

        Only supported for convex powers and real affine ``other``; otherwise
        :obj:`NotImplemented` is returned.
        """
        if self.convex and isinstance(other, AffineExpression):
            if self.iscomplex:
                return ComplexTrMatrixGeoMeanEpiConstraint(self, other)
            else:
                return TrMatrixGeoMeanEpiConstraint(self, other)
        else:
            return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __ge__(self, other):
        """Create the hypograph constraint for ``self >= other``.

        Only supported for concave powers and real affine ``other``; otherwise
        :obj:`NotImplemented` is returned.
        """
        if self.concave and isinstance(other, AffineExpression):
            if self.iscomplex:
                return ComplexTrMatrixGeoMeanHypoConstraint(self, other)
            else:
                return TrMatrixGeoMeanHypoConstraint(self, other)
        else:
            return NotImplemented

464 

465 

466# -------------------------------------- 

# NOTE(review): presumably derives the public API from the names defined since
# the api_start() call near the top of the file -- confirm against the apidoc
# module.
467__all__ = api_end(_API_START, globals())