1# ------------------------------------------------------------------------------

2# Copyright (C) 2019 Maximilian Stahlberg

3# Based on the original picos.expressions module by Guillaume Sagnol.

4#

5# This file is part of PICOS.

6#

# PICOS is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later

10# version.

11#

12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY

13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR

14# A PARTICULAR PURPOSE. See the GNU General Public License for more details.

15#

16# You should have received a copy of the GNU General Public License along with

17# this program. If not, see <http://www.gnu.org/licenses/>.

18# ------------------------------------------------------------------------------

"""Implements :class:`SumExtremes`."""

22import operator

23from collections import namedtuple

25import cvxopt

26import numpy

28from .. import glyphs

29from ..apidoc import api_end, api_start

30from ..constraints import Constraint, SumExtremesConstraint

31from .data import convert_and_refine_arguments, convert_operands, cvx2np

32from .exp_affine import AffineExpression, ComplexAffineExpression

33from .expression import Expression, refine_operands, validate_prediction

# Marker handed to the api_end call at the bottom of this module, whose result
# becomes __all__. NOTE(review): presumably api_start snapshots the names in
# globals() so that only names defined after this point are exported; this
# statement must therefore stay after the imports and before the public
# definitions — confirm against picos.apidoc.
_API_START = api_start(globals())
# -------------------------------

class SumExtremes(Expression):
    r"""Sum of the :math:`k` largest or smallest elements or eigenvalues.

    :Definition:

    Let :math:`k \in \mathbb{Z}_{\geq 1}`.

    1.  If :math:`x` is an :math:`n`-dimensional real vector or matrix and
        ``eigenvalues == False``, then this is the sum of the
        :math:`k \leq n` largest or smallest scalar elements of :math:`x`,
        depending on the truth value of ``largest``.

        Special cases:

        -   If :math:`k = 1`, this is either the largest element
            :math:`\max_{i = 1}^n \operatorname{vec}(x)_i` or the smallest
            element :math:`\min_{i = 1}^n \operatorname{vec}(x)_i` of
            :math:`x`.
        -   If :math:`k = n`, this is the sum of all elements
            :math:`\langle x, 1 \rangle` of :math:`x`.

    2.  If :math:`X` is an :math:`n \times n` hermitian matrix and
        ``eigenvalues == True``, then this is the sum of the
        :math:`k \leq n` largest or smallest eigenvalues of :math:`X`,
        depending on the truth value of ``largest``. Recall that the
        eigenvalues of a hermitian matrix are real.

        Special cases:

        -   If :math:`k = 1`, this is either the largest eigenvalue
            :math:`\lambda_{\max}(X)` or the smallest eigenvalue
            :math:`\lambda_{\min}(X)` of :math:`X`.
        -   If :math:`k = n`, this equals the trace
            :math:`\operatorname{tr}(X)`.

    If the given :math:`k` exceeds the :math:`n` of either case, then
    :math:`k` is silently clipped to :math:`n`.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    @convert_and_refine_arguments("x")
    def __init__(self, x, k, largest, eigenvalues=False):
        """Construct a :class:`SumExtremes`.

        :param x: The affine expression to take a sum over.
        :type x: ~picos.expressions.ComplexAffineExpression
        :param int k: Number of summands.
        :param bool largest: Whether to sum over the largest (eigen)values as
            opposed to the smallest.
        :param bool eigenvalues: Whether to sum eigenvalues instead of
            elements.

        :raises TypeError: If ``x`` is not a complex affine expression, if
            ``eigenvalues`` is set and ``x`` is not square, or if
            ``eigenvalues`` is not set and ``x`` is not real-valued.
        :raises NotImplementedError: If ``eigenvalues`` is set and ``x`` is
            not necessarily hermitian.
        :raises ValueError: If ``k`` is not integral or is smaller than one.
        """
        largest = bool(largest)
        eigenvalues = bool(eigenvalues)

        lStr = "largest" if largest else "smallest"
        eStr = "eigenvalues" if eigenvalues else "scalar elements"
        what = "{} {}".format(lStr, eStr)

        # Validate x.
        if not isinstance(x, ComplexAffineExpression):
            raise TypeError("Can only sum {} of an affine expression, not of "
                "{}.".format(what, type(x).__name__))

        # Further validate x.
        if eigenvalues:
            if not x.square:
                raise TypeError("Cannot sum {} of {} as its shape of {} is not "
                    "square.".format(what, x.string, glyphs.shape(x.shape)))
            elif not x.hermitian:
                raise NotImplementedError(
                    "Summing the {0} of {1} is not supported as {1} is not "
                    "necessarily hermitian.".format(what, x.string))
        else:
            if not isinstance(x, AffineExpression):
                raise TypeError("Can only sum {} of a real-valued expression "
                    "but {} is properly complex.".format(what, x.string))

        # Validate k.
        if int(k) != k:
            raise ValueError(
                "Conversion of k = {} to an integer is ambiguous.".format(k))
        k = int(k)
        if k < 1:
            raise ValueError(
                "Number of {} to sum must be positive.".format(what))

        # Number of (eigen)values available: the side length of the (square)
        # matrix for eigenvalues, the total number of elements otherwise.
        n = x.shape[0] if eigenvalues else len(x)

        # Clip k to be at most n.
        k = min(k, n)

        # Find out if all (eigen)values are summed.
        full = k == n
        assert len(x) != 1 or full

        self._x = x
        self._k = k
        self._largest = largest
        self._eigenvalues = eigenvalues
        self._full = full

        s, lbd = x.string, glyphs.lambda_()
        if full:
            if eigenvalues:
                typeStr = "Sum of Eigenvalues"
                symbStr = glyphs.trace(s)
            else:
                typeStr = "Sum of Elements"
                symbStr = glyphs.sum(s)
        elif k > 1:
            if eigenvalues and largest:
                typeStr = "Sum of Largest Eigenvalues"
                symbStr = glyphs.make_function(
                    "sum_{}_largest_{}".format(k, lbd))(s)
            elif eigenvalues and not largest:
                typeStr = "Sum of Smallest Eigenvalues"
                symbStr = glyphs.make_function(
                    "sum_{}_smallest_{}".format(k, lbd))(s)
            elif not eigenvalues and largest:
                typeStr = "Sum of Largest Elements"
                symbStr = glyphs.make_function("sum_{}_largest".format(k))(s)
            else:
                typeStr = "Sum of Smallest Elements"
                symbStr = glyphs.make_function("sum_{}_smallest".format(k))(s)
        else:
            if eigenvalues and largest:
                typeStr = "Largest Eigenvalue"
                symbStr = glyphs.make_function("{}_max".format(lbd))(s)
            elif eigenvalues and not largest:
                typeStr = "Smallest Eigenvalue"
                symbStr = glyphs.make_function("{}_min".format(lbd))(s)
            elif not eigenvalues and largest:
                typeStr = "Largest Element"
                symbStr = glyphs.max(s)
            else:
                typeStr = "Smallest Element"
                symbStr = glyphs.min(s)

        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_refined(self):
        # A constant argument yields a constant scalar.
        if self._x.constant:
            return AffineExpression.from_constant(self.value, 1, self._symbStr)
        # Summing all (eigen)values is affine and has a simpler representation.
        elif self._full:
            if len(self._x) == 1:
                return self._x  # Don't carry the string for an identity.
            if self._eigenvalues:
                return self._x.tr  # Symbolic strings already match.
            else:
                # Inner product with the all-ones vector: sum of elements.
                return (1 | self._x).renamed(self._symbStr)
        else:
            return self

    # Characterizes the expression for _predict: argument length, clipped k,
    # direction, element/eigenvalue mode and complexity of the argument.
    Subtype = namedtuple("Subtype",
        ("argdim", "k", "largest", "eigenvalues", "complex"))

    def _get_subtype(self):
        return self.Subtype(len(self._x), self._k, self._largest,
            self._eigenvalues, self._x.complex)

    def _get_value(self):
        value = self._x._get_value()

        # Sort the (eigen)values in ascending order. For the hermitian
        # argument, numpy's eigvalsh returns real eigenvalues.
        if self._eigenvalues:
            value = sorted(numpy.linalg.eigvalsh(cvx2np(value)))
        else:
            value = sorted(value)

        # The k last entries are the largest, the k first the smallest.
        value = sum(value[-self._k:] if self._largest else value[:self._k])
        value = cvxopt.matrix(value)

        return value

    def _get_mutables(self):
        return self._x._get_mutables()

    def _is_convex(self):
        # A sum of largest (eigen)values is convex; a full sum is affine.
        return self._largest or self._full

    def _is_concave(self):
        # A sum of smallest (eigen)values is concave; a full sum is affine.
        return not self._largest or self._full

    def _replace_mutables(self, mapping):
        return self.__class__(self._x._replace_mutables(mapping),
            self._k, self._largest, self._eigenvalues)

    def _freeze_mutables(self, freeze):
        return self.__class__(self._x._freeze_mutables(freeze),
            self._k, self._largest, self._eigenvalues)

    # --------------------------------------------------------------------------
    # Python special method implementations, except constraint-creating ones.
    # --------------------------------------------------------------------------

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __mul__(self, other):
        # Scaling by a nonnegative constant preserves the sum-of-extremes
        # structure, so the factor is pushed into the argument.
        if isinstance(other, AffineExpression):
            if not other.constant:
                raise NotImplementedError("You may only multiply a nonconstant "
                    "PICOS sum of extremes with a constant term.")

            if other.value < 0:
                raise NotImplementedError("You may only multiply a nonconstant "
                    "PICOS sum of extremes with a nonnegative term.")

            product = SumExtremes(
                other.value*self._x, self._k, self._largest, self._eigenvalues)
            product._typeStr = "Scaled " + product._typeStr
            product._symbStr = glyphs.clever_mul(self.string, other.string)
            return product
        else:
            return NotImplemented

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __rmul__(self, other):
        if isinstance(other, AffineExpression):
            product = self.__mul__(other)
            # NOTE: __mul__ always creates a fresh expression.
            product._symbStr = glyphs.clever_mul(other.string, self.string)
            return product
        else:
            return NotImplemented

    # --------------------------------------------------------------------------
    # Methods and properties that return expressions.
    # --------------------------------------------------------------------------

    @property
    def x(self):
        """The expression under the sum."""
        return self._x

    # --------------------------------------------------------------------------
    # Methods and properties that describe the expression.
    # --------------------------------------------------------------------------

    @property
    def k(self):
        """Number of (eigen)values to sum."""
        return self._k

    @property
    def largest(self):
        """Whether the sum concerns largest values as opposed to smallest."""
        return self._largest

    @property
    def eigenvalues(self):
        """Whether the sum concerns eigenvalues as opposed to elements."""
        return self._eigenvalues

    @property
    def full(self):
        """Whether the sum concerns *all* (eigen)values of the expression."""
        return self._full

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        assert isinstance(subtype, cls.Subtype)

        n = subtype.argdim
        k = subtype.k
        e = subtype.eigenvalues
        c = subtype.complex

        # For eigenvalue sums the argument is square with argdim = m*m
        # entries, so its side length is the square root of argdim. Round
        # before truncating so that a float result just below the exact root
        # cannot produce an off-by-one.
        kmax = int(round(n**0.5)) if e else n
        full = k == kmax

        convex = subtype.largest or full
        concave = not subtype.largest or full

        if relation == operator.__le__:
            if not convex:
                return NotImplemented

            if issubclass(other.clstype, AffineExpression) \
            and other.subtype.dim == 1:
                return SumExtremesConstraint.make_type(n, k, e, c)
        elif relation == operator.__ge__:
            if not concave:
                return NotImplemented

            if issubclass(other.clstype, AffineExpression) \
            and other.subtype.dim == 1:
                return SumExtremesConstraint.make_type(n, k, e, c)

        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __le__(self, other):
        if not self.convex:
            raise TypeError("Cannot upper-bound the nonconvex expression {}."
                .format(self._symbStr))

        if isinstance(other, AffineExpression):
            return SumExtremesConstraint(self, Constraint.LE, other)
        else:
            return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __ge__(self, other):
        # A ">=" constraint lower-bounds this expression, which requires
        # concavity.
        if not self.concave:
            raise TypeError("Cannot lower-bound the nonconcave expression {}."
                .format(self._symbStr))

        if isinstance(other, AffineExpression):
            return SumExtremesConstraint(self, Constraint.GE, other)
        else:
            return NotImplemented

# --------------------------------------
# Declare this module's public names. api_end receives the marker created by
# api_start near the top of the module together with the current globals().
# NOTE(review): presumably it returns the names defined in between (here,
# SumExtremes) — confirm against picos.apidoc.
__all__ = api_end(_API_START, globals())