Coverage for picos/expressions/exp_sumxtr.py: 84.49%

187 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-26 07:46 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2019 Maximilian Stahlberg 

3# Based on the original picos.expressions module by Guillaume Sagnol. 

4# 

5# This file is part of PICOS. 

6# 

7# PICOS is free software: you can redistribute it and/or modify it under the 

8# terms of the GNU General Public License as published by the Free Software 

9# Foundation, either version 3 of the License, or (at your option) any later 

10# version. 

11# 

12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

14# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

15# 

16# You should have received a copy of the GNU General Public License along with 

17# this program. If not, see <http://www.gnu.org/licenses/>. 

18# ------------------------------------------------------------------------------ 

19 

20"""Implements :class:`SumExtremes`.""" 

21 

22import operator 

23from collections import namedtuple 

24 

25import cvxopt 

26import numpy 

27 

28from .. import glyphs 

29from ..apidoc import api_end, api_start 

30from ..constraints import Constraint, SumExtremesConstraint 

31from .data import convert_and_refine_arguments, convert_operands, cvx2np 

32from .exp_affine import AffineExpression, ComplexAffineExpression 

33from .expression import Expression, refine_operands, validate_prediction 

34 

# Marker for the start of this module's public API section. It is consumed by
# the api_end() call at the end of the module, whose result becomes __all__.
_API_START = api_start(globals())
# -------------------------------

37 

38 

class SumExtremes(Expression):
    r"""Sum of the :math:`k` largest or smallest elements or eigenvalues.

    :Definition:

    Let :math:`k \in \mathbb{Z}_{\geq 1}`.

    1.  If :math:`x` is an :math:`n`-dimensional real vector or matrix and
        ``eigenvalues == False``, then this is the sum of the :math:`k \leq n`
        largest or smallest scalar elements of :math:`x`, depending on the
        truth value of ``largest``.

        Special cases:

        -   If :math:`k = 1`, this is either the largest element
            :math:`\max_{i = 1}^n \operatorname{vec}(x)_i` or the smallest
            element :math:`\min_{i = 1}^n \operatorname{vec}(x)_i` of
            :math:`x`.
        -   If :math:`k = n`, this is the sum of all elements
            :math:`\langle x, 1 \rangle` of :math:`x`.

    2.  If :math:`X` is an :math:`n \times n` hermitian matrix and
        ``eigenvalues == True``, then this is the sum of the :math:`k \leq n`
        largest or smallest eigenvalues of :math:`X`, depending on the truth
        value of ``largest``. Recall that the eigenvalues of a hermitian
        matrix are real.

        Special cases:

        -   If :math:`k = 1`, this is either the largest eigenvalue
            :math:`\lambda_{\max}(X)` or the smallest eigenvalue
            :math:`\lambda_{\min}(X)` of :math:`X`.
        -   If :math:`k = n`, this equals the trace
            :math:`\operatorname{tr}(X)`.

    If the given :math:`k` exceeds the :math:`n` of either case, then
    :math:`k` is silently clipped to :math:`n`.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    @convert_and_refine_arguments("x")
    def __init__(self, x, k, largest, eigenvalues=False):
        """Construct a :class:`SumExtremes`.

        :param x: The affine expression to take a sum over.
        :type x: ~picos.expressions.ComplexAffineExpression
        :param int k: Number of summands. Values exceeding the number of
            available (eigen)values are clipped to that number.
        :param bool largest: Whether to sum over the largest (eigen)values as
            opposed to the smallest.
        :param bool eigenvalues: Whether to sum eigenvalues instead of elements.

        :raises TypeError: If ``x`` is not a (complex) affine expression, if
            ``eigenvalues`` is set and ``x`` is not square, or if
            ``eigenvalues`` is not set and ``x`` is properly complex.
        :raises NotImplementedError: If ``eigenvalues`` is set and ``x`` is
            not necessarily hermitian.
        :raises ValueError: If ``k`` is not integral or is smaller than one.
        """
        largest = bool(largest)
        eigenvalues = bool(eigenvalues)

        lStr = "largest" if largest else "smallest"
        eStr = "eigenvalues" if eigenvalues else "scalar elements"
        what = "{} {}".format(lStr, eStr)

        # Validate x.
        if not isinstance(x, ComplexAffineExpression):
            raise TypeError("Can only sum {} of an affine expression, not of "
                "{}.".format(what, type(x).__name__))

        # Further validate x: eigenvalues need a hermitian (hence square)
        # argument while elements need a real one so that they can be ordered.
        if eigenvalues:
            if not x.square:
                raise TypeError("Cannot sum {} of {} as its shape of {} is not "
                    "square.".format(what, x.string, glyphs.shape(x.shape)))
            elif not x.hermitian:
                raise NotImplementedError(
                    "Summing the {0} of {1} is not supported as {1} is not "
                    "necessarily hermitian.".format(what, x.string))
        else:
            if not isinstance(x, AffineExpression):
                raise TypeError("Can only sum {} of a real-valued expression "
                    "but {} is properly complex.".format(what, x.string))

        # Validate k.
        if int(k) != k:
            raise ValueError(
                "Conversion of k = {} to an integer is ambiguous.".format(k))
        k = int(k)
        if k < 1:
            raise ValueError(
                "Number of {} to sum must be positive.".format(what))

        # Number of available (eigen)values: the side length of the square
        # matrix for eigenvalues, the number of scalar elements otherwise.
        n = x.shape[0] if eigenvalues else len(x)

        # Clip k to be at most n.
        k = min(k, n)

        # Find out if all (eigen)values are summed.
        full = k == n
        assert len(x) != 1 or full

        self._x = x
        self._k = k
        self._largest = largest
        self._eigenvalues = eigenvalues
        self._full = full

        s, lbd = x.string, glyphs.lambda_()
        if full:
            if eigenvalues:
                typeStr = "Sum of Eigenvalues"
                symbStr = glyphs.trace(s)
            else:
                typeStr = "Sum of Elements"
                symbStr = glyphs.sum(s)
        elif k > 1:
            if eigenvalues and largest:
                typeStr = "Sum of Largest Eigenvalues"
                symbStr = glyphs.make_function(
                    "sum_{}_largest_{}".format(k, lbd))(s)
            elif eigenvalues and not largest:
                typeStr = "Sum of Smallest Eigenvalues"
                symbStr = glyphs.make_function(
                    "sum_{}_smallest_{}".format(k, lbd))(s)
            elif not eigenvalues and largest:
                typeStr = "Sum of Largest Elements"
                symbStr = glyphs.make_function("sum_{}_largest".format(k))(s)
            else:
                typeStr = "Sum of Smallest Elements"
                symbStr = glyphs.make_function("sum_{}_smallest".format(k))(s)
        else:
            if eigenvalues and largest:
                typeStr = "Largest Eigenvalue"
                symbStr = glyphs.make_function("{}_max".format(lbd))(s)
            elif eigenvalues and not largest:
                typeStr = "Smallest Eigenvalue"
                symbStr = glyphs.make_function("{}_min".format(lbd))(s)
            elif not eigenvalues and largest:
                typeStr = "Largest Element"
                symbStr = glyphs.max(s)
            else:
                typeStr = "Smallest Element"
                symbStr = glyphs.min(s)

        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_refined(self):
        """Return an equivalent, simpler expression if one exists.

        A constant argument yields a constant scalar. A sum over all
        (eigen)values is affine: the trace for eigenvalues, the inner product
        with a vector of ones for elements.
        """
        if self._x.constant:
            return AffineExpression.from_constant(self.value, 1, self._symbStr)
        elif self._full:
            if len(self._x) == 1:
                return self._x  # Don't carry the string for an identity.
            if self._eigenvalues:
                return self._x.tr  # Symbolic strings already match.
            else:
                return (1 | self._x).renamed(self._symbStr)
        else:
            return self

    Subtype = namedtuple("Subtype",
        ("argdim", "k", "largest", "eigenvalues", "complex"))

    def _get_subtype(self):
        """Return the :attr:`Subtype` describing this expression."""
        return self.Subtype(len(self._x), self._k, self._largest,
            self._eigenvalues, self._x.complex)

    def _get_value(self):
        """Evaluate the sum numerically as a 1x1 CVXOPT matrix."""
        value = self._x._get_value()

        if self._eigenvalues:
            # eigvalsh is numpy's routine for hermitian matrices; the real
            # eigenvalues are sorted here in ascending order.
            value = sorted(numpy.linalg.eigvalsh(cvx2np(value)))
        else:
            value = sorted(value)

        # Values are in ascending order, so the k largest are at the end.
        value = sum(value[-self._k:] if self._largest else value[:self._k])
        value = cvxopt.matrix(value)

        return value

    def _get_mutables(self):
        """Return the mutables of the argument."""
        return self._x._get_mutables()

    def _is_convex(self):
        # Sums of largest (eigen)values are convex; a full sum is affine.
        return self._largest or self._full

    def _is_concave(self):
        # Sums of smallest (eigen)values are concave; a full sum is affine.
        return not self._largest or self._full

    def _replace_mutables(self, mapping):
        return self.__class__(self._x._replace_mutables(mapping),
            self._k, self._largest, self._eigenvalues)

    def _freeze_mutables(self, freeze):
        return self.__class__(self._x._freeze_mutables(freeze),
            self._k, self._largest, self._eigenvalues)

    # --------------------------------------------------------------------------
    # Python special method implementations, except constraint-creating ones.
    # --------------------------------------------------------------------------

    @classmethod
    def _mul(cls, self, other, forward):
        """Multiply with ``other``; ``forward`` tells if ``self`` is on the left.

        Scaling by a constant zero yields zero, by one yields ``self``, and by a
        positive constant yields a new :class:`SumExtremes` over the scaled
        argument. All other cases are delegated to :class:`Expression`.
        """
        if isinstance(other, AffineExpression) and other.constant:
            factor = other.safe_value

            if not factor:
                return AffineExpression.zero()
            elif factor == 1:
                return self
            elif factor > 0:
                # A positive factor preserves the ordering of (eigen)values, so
                # it can be moved inside the sum.
                if forward:
                    string = glyphs.clever_mul(self.string, other.string)
                else:
                    string = glyphs.clever_mul(other.string, self.string)

                product = cls(
                    factor*self._x, self._k, self._largest, self._eigenvalues)
                product._typeStr = "Scaled " + product._typeStr
                product._symbStr = string

                return product

        if forward:
            return Expression.__mul__(self, other)
        else:
            return Expression.__rmul__(self, other)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __mul__(self, other):
        return SumExtremes._mul(self, other, True)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __rmul__(self, other):
        return SumExtremes._mul(self, other, False)

    # --------------------------------------------------------------------------
    # Methods and properties that return expressions.
    # --------------------------------------------------------------------------

    @property
    def x(self):
        """The expression under the sum."""
        return self._x

    # --------------------------------------------------------------------------
    # Methods and properties that describe the expression.
    # --------------------------------------------------------------------------

    @property
    def k(self):
        """Number of (eigen)values to sum."""
        return self._k

    @property
    def largest(self):
        """Whether the sum concerns largest values as opposed to smallest."""
        return self._largest

    @property
    def eigenvalues(self):
        """Whether the sum concerns eigenvalues as opposed to elements."""
        return self._eigenvalues

    @property
    def full(self):
        """Whether the sum concerns *all* (eigen)values of the expression."""
        return self._full

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        """Predict the constraint type of ``self <relation> other``."""
        assert isinstance(subtype, cls.Subtype)

        n = subtype.argdim
        k = subtype.k
        e = subtype.eigenvalues
        c = subtype.complex

        # For eigenvalues, argdim is the number of entries of a square matrix,
        # so the number of eigenvalues is its integer square root. Round before
        # truncating so that a floating-point result slightly below the exact
        # root cannot lose one.
        kmax = int(round(n**0.5)) if e else n
        full = k == kmax

        convex = subtype.largest or full
        concave = not subtype.largest or full

        if relation == operator.__le__:
            if not convex:
                return NotImplemented

            if issubclass(other.clstype, AffineExpression) \
                    and other.subtype.dim == 1:
                return SumExtremesConstraint.make_type(n, k, e, c)
        elif relation == operator.__ge__:
            if not concave:
                return NotImplemented

            if issubclass(other.clstype, AffineExpression) \
                    and other.subtype.dim == 1:
                return SumExtremesConstraint.make_type(n, k, e, c)

        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __le__(self, other):
        """Upper-bound this (convex) expression by a real scalar.

        :raises TypeError: If the expression is not convex.
        """
        if not self.convex:
            raise TypeError("Cannot upper-bound the nonconvex expression {}."
                .format(self._symbStr))

        if isinstance(other, AffineExpression):
            return SumExtremesConstraint(self, Constraint.LE, other)
        else:
            return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __ge__(self, other):
        """Lower-bound this (concave) expression by a real scalar.

        :raises TypeError: If the expression is not concave.
        """
        if not self.concave:
            # A ">=" constraint lower-bounds the expression.
            raise TypeError("Cannot lower-bound the nonconcave expression {}."
                .format(self._symbStr))

        if isinstance(other, AffineExpression):
            return SumExtremesConstraint(self, Constraint.GE, other)
        else:
            return NotImplemented

368 

369 

# --------------------------------------
# Build the module's public API list from the _API_START marker and the current
# module globals (see picos.apidoc.api_end).
__all__ = api_end(_API_START, globals())