Coverage for picos/expressions/exp_sumexp.py: 73.30%

191 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-26 07:46 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2019 Maximilian Stahlberg 

3# Based on the original picos.expressions module by Guillaume Sagnol. 

4# 

5# This file is part of PICOS. 

6# 

7# PICOS is free software: you can redistribute it and/or modify it under the 

8# terms of the GNU General Public License as published by the Free Software 

9# Foundation, either version 3 of the License, or (at your option) any later 

10# version. 

11# 

12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

14# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

15# 

16# You should have received a copy of the GNU General Public License along with 

17# this program. If not, see <http://www.gnu.org/licenses/>. 

18# ------------------------------------------------------------------------------ 

19 

20"""Implements :class:`SumExponentials`.""" 

21 

22import math 

23import operator 

24from collections import namedtuple 

25 

26import cvxopt 

27import numpy 

28 

29from .. import glyphs 

30from ..apidoc import api_end, api_start 

31from ..caching import cached_property, cached_unary_operator 

32from ..constraints import LogSumExpConstraint, SumExponentialsConstraint 

33from .data import convert_and_refine_arguments, convert_operands, cvx2np 

34from .exp_affine import AffineExpression 

35from .expression import Expression, refine_operands, validate_prediction 

36 

# Marker taken before the public definitions of this module; it is handed to
# api_end() at the bottom of the file to build __all__.
# NOTE(review): presumably a snapshot of the module namespace - confirm
# against ..apidoc.api_start.
37_API_START = api_start(globals()) 

38# ------------------------------- 

39 

40 

class SumExponentials(Expression):
    r"""Sum of elementwise exponentials of an affine expression.

    :Definition:

    Let :math:`x` be an :math:`n`-dimensional real affine expression.

    1.  If no additional expression :math:`y` is given, this is the sum of
        elementwise exponentials

        .. math::

            \sum_{i = 1}^n \exp(\operatorname{vec}(x)_i).

    2.  If an additional affine expression :math:`y` of same shape as :math:`x`
        is given, this is the sum of elementwise perspectives of exponentials

        .. math::

            \sum_{i = 1}^n \operatorname{vec}(y)_i \exp\left(
                \frac{\operatorname{vec}(x)_i}{\operatorname{vec}(y)_i}\right).

    .. warning::

        When you pose an upper bound :math:`t` on a sum of elementwise
        exponentials, then PICOS enforces :math:`t \geq 0` through an auxiliary
        constraint during solution search. When an additional expression
        :math:`y` is given, PICOS enforces :math:`y \geq 0` as well.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    @convert_and_refine_arguments("x", "y", allowNone=True)
    def __init__(self, x, y=None):
        """Construct a :class:`SumExponentials`.

        :param x: The affine expression :math:`x`.
        :type x: ~picos.expressions.AffineExpression
        :param y: An additional affine expression :math:`y`. If necessary, PICOS
            will attempt to reshape or broadcast it to the shape of :math:`x`.
        :type y: ~picos.expressions.AffineExpression

        :raises TypeError: If ``x`` or ``y`` is not an
            :class:`AffineExpression`.
        """
        if not isinstance(x, AffineExpression):
            raise TypeError("Can only sum the elementwise exponentials of a "
                "real affine expression, not of {}.".format(x.string))

        if y is not None:
            if not isinstance(y, AffineExpression):
                raise TypeError("The additional parameter y must be a real "
                    "affine expression, not {}.".format(y.string))
            elif x.shape != y.shape:
                y = y.reshaped_or_broadcasted(x.shape)

            # With y = 1 the perspective y*exp(x/y) is just exp(x), so the
            # expression is stored in its simpler form without y.
            if y.is1:
                y = None

        self._x = x
        self._y = y

        # Pick a type name and a symbolic string depending on whether there is
        # a single summand and on whether the perspective parameter is set.
        if len(x) == 1:
            if y is None:
                typeStr = "Exponential"
                symbStr = glyphs.exp(x.string)
            else:
                typeStr = "Exponential Perspective"
                symbStr = glyphs.mul(
                    y.string, glyphs.exp(glyphs.div(x.string, y.string)))
        else:
            if y is None:
                typeStr = "Sum of Exponentials"
                symbStr = glyphs.make_function("sum", "exp")(x.string)
            else:
                typeStr = "Sum of Exponential Perspectives"
                symbStr = glyphs.sum(glyphs.mul(glyphs.slice(y.string, "i"),
                    glyphs.exp(glyphs.div(glyphs.slice(x.string, "i"),
                    glyphs.slice(y.string, "i")))))

        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_refined(self):
        """Return a constant affine expression if both x and y are constant."""
        if self._x.constant and (self._y is None or self._y.constant):
            return AffineExpression.from_constant(self.value, 1, self._symbStr)
        else:
            return self

    # argdim: number of summands (length of x); y: whether y is set.
    Subtype = namedtuple("Subtype", ("argdim", "y"))

    def _get_subtype(self):
        """Return the :attr:`Subtype` describing this expression."""
        return self.Subtype(len(self._x), self._y is not None)

    def _get_value(self):
        """Evaluate the expression numerically as a CVXOPT matrix."""
        x = numpy.ravel(cvx2np(self._x._get_value()))

        if self._y is None:
            s = numpy.sum(numpy.exp(x))
        else:
            # NOTE(review): entries of y equal to zero divide by zero here.
            y = numpy.ravel(cvx2np(self._y._get_value()))
            s = y.dot(numpy.exp(x / y))

        return cvxopt.matrix(s)

    @cached_unary_operator
    def _get_mutables(self):
        """Return the mutables of x, united with those of y if y is set."""
        if self._y is None:
            return self._x._get_mutables()
        else:
            return self._x._get_mutables().union(self._y.mutables)

    def _is_convex(self):
        """Report the expression as convex."""
        return True

    def _is_concave(self):
        """Report the expression as not concave."""
        return False

    def _replace_mutables(self, mapping):
        """Rebuild the expression with mutables in x and y replaced."""
        return self.__class__(self._x._replace_mutables(mapping),
            None if self._y is None else self._y._replace_mutables(mapping))

    def _freeze_mutables(self, freeze):
        """Rebuild the expression with mutables in x and y frozen."""
        return self.__class__(self._x._freeze_mutables(freeze),
            None if self._y is None else self._y._freeze_mutables(freeze))

    # --------------------------------------------------------------------------
    # Python special method implementations, except constraint-creating ones.
    # --------------------------------------------------------------------------

    @classmethod
    def _add(cls, self, other, forward):
        """Implement addition for both operand orders.

        A positive constant :math:`c` is absorbed as an extra summand
        :math:`\\exp(\\log(c))`, and another :class:`SumExponentials` is
        absorbed by concatenating the arguments. Every other case is delegated
        to :class:`Expression`.

        :param self: The :class:`SumExponentials` operand.
        :param other: The other operand.
        :param bool forward: Whether ``self`` is the left hand side operand.
        """
        if isinstance(other, AffineExpression) and other.constant:
            value = other.value

            if not value:
                return self
            elif value > 0:
                # c = exp(log(c)) = 1*exp(log(c)/1): in either form the new
                # summand has the argument log(c), with y = 1 for perspectives.
                # Vectorizing first (as is done for sums of two instances
                # below) makes the appended scalar one extra entry whatever
                # the shape of x is.
                if self._y is None:
                    result = cls(self._x.vec // math.log(value))
                else:
                    result = cls(
                        self._x.vec // math.log(value), self._y.vec // 1)

                if forward:
                    string = glyphs.clever_add(self.string, other.string)
                else:
                    string = glyphs.clever_add(other.string, self.string)

                result._typeStr = "Offset " + result._typeStr
                result._symbStr = string

                return result
        elif isinstance(other, cls):
            assert forward, "Encountered __radd__ on equal types."

            # Stack the vectorized arguments. Where only one operand has y
            # set, the other is padded with ones as 1*exp(x/1) = exp(x).
            if self._y is None and other._y is None:
                result = cls(self._x.vec // other._x.vec)
            elif self._y is not None and other._y is None:
                one = AffineExpression.from_constant(1.0, (other.n, 1))
                result = cls(self._x.vec // other._x.vec, self._y.vec // one)
            elif self._y is None and other._y is not None:
                one = AffineExpression.from_constant(1.0, (self.n, 1))
                result = cls(self._x.vec // other._x.vec, one // other._y.vec)
            else:
                result = cls(
                    self._x.vec // other._x.vec, self._y.vec // other._y.vec)

            result._symbStr = glyphs.clever_add(self.string, other.string)

            return result

        # Negative constants and all other operand types are not absorbed.
        if forward:
            return Expression.__add__(self, other)
        else:
            return Expression.__radd__(self, other)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __add__(self, other):
        """Denote addition from the right hand side."""
        return SumExponentials._add(self, other, True)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __radd__(self, other):
        """Denote addition from the left hand side."""
        return SumExponentials._add(self, other, False)

    @classmethod
    def _mul_div(cls, self, other, div, forward):
        """Implement scaling by and division by a constant scalar.

        A positive constant factor :math:`c` is absorbed into the expression:
        without :math:`y` by adding :math:`\\log(c)` to :math:`x`, and with
        :math:`y` by scaling both :math:`x` and :math:`y` by :math:`c`. Every
        other case is delegated to :class:`Expression`.

        :param self: The :class:`SumExponentials` operand.
        :param other: The other operand.
        :param bool div: Whether to divide by ``other`` instead of
            multiplying with it. Requires ``forward``.
        :param bool forward: Whether ``self`` is the left hand side operand.

        :raises ZeroDivisionError: When dividing by a constant zero.
        """
        assert not div or forward

        if isinstance(other, AffineExpression) and other.constant:
            factor = other.safe_value

            if not factor:
                if div:
                    raise ZeroDivisionError(
                        "Cannot divide {} by zero.".format(self.string))
                else:
                    return AffineExpression.zero()
            elif factor == 1:
                return self
            elif factor > 0:
                if div:
                    factor = 1 / factor
                    string = glyphs.div(self.string, other.string)
                elif forward:
                    string = glyphs.clever_mul(self.string, other.string)
                else:
                    string = glyphs.clever_mul(other.string, self.string)

                if self._y is None:
                    # c*exp(x) = exp(x + log(c)).
                    result = cls(self._x + math.log(factor))
                else:
                    # c*(y*exp(x/y)) = (c*y)*exp((c*x)/(c*y)). When dividing,
                    # c is the inverted factor and not the divisor itself.
                    scale = factor if div else other
                    result = cls(scale*self._x, scale*self._y)

                result._typeStr = "Scaled " + result._typeStr
                result._symbStr = string

                return result

        # Negative factors and all other operand types are not absorbed.
        if div:
            # NOTE(review): __div__ is the Python 2 name of the division
            # operator; confirm that Expression defines it.
            return Expression.__div__(self, other)
        elif forward:
            return Expression.__mul__(self, other)
        else:
            return Expression.__rmul__(self, other)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __mul__(self, other):
        """Denote scaling from the right hand side."""
        return SumExponentials._mul_div(self, other, div=False, forward=True)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __rmul__(self, other):
        """Denote scaling from the left hand side."""
        return SumExponentials._mul_div(self, other, div=False, forward=False)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __truediv__(self, other):
        """Denote division by a constant scalar."""
        return SumExponentials._mul_div(self, other, div=True, forward=True)

    # --------------------------------------------------------------------------
    # Methods and properties that return expressions.
    # --------------------------------------------------------------------------

    @property
    def x(self):
        """The expression :math:`x`."""
        return self._x

    @property
    def y(self):
        """The additional expression :math:`y`, or :obj:`None`."""
        return self._y

    @cached_property
    def log(self):
        """The logarithm of the expression.

        :raises NotImplementedError: If :attr:`y` is set.
        """
        from . import LogSumExp

        if self._y is not None:
            raise NotImplementedError("May only take the logarithm of a sum of"
                " exponentials, not of a sum of exponential perspectives.")

        return LogSumExp(self._x)

    # --------------------------------------------------------------------------
    # Methods and properties that describe the expression.
    # --------------------------------------------------------------------------

    @property
    def n(self):
        """Length of :attr:`x`."""
        return len(self._x)

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        """Predict the constraint type produced by ``relation``.

        Mirrors the case distinction made in :meth:`__le__`; all other
        relations and operand types yield :obj:`NotImplemented`.
        """
        assert isinstance(subtype, cls.Subtype)

        if relation == operator.__le__:
            if issubclass(other.clstype, AffineExpression) \
            and other.subtype.dim == 1:
                return SumExponentialsConstraint.make_type(
                    argdim=subtype.argdim,
                    lse_representable=(not subtype.y and other.subtype.nonneg))
            elif issubclass(other.clstype, SumExponentials):
                if subtype.y or other.subtype.y:
                    return NotImplemented

                if other.subtype.argdim != 1:
                    return NotImplemented

                return LogSumExpConstraint.make_type(argdim=subtype.argdim)

        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __le__(self, other):
        """Pose an upper bound on the expression.

        The bound may be an affine expression or a single exponential without
        the perspective parameter :math:`y`.

        :raises NotImplementedError: If ``other`` is a
            :class:`SumExponentials` and either operand has :math:`y` set or
            ``other`` has more than one summand.
        """
        from . import LogSumExp

        if isinstance(other, AffineExpression):
            return SumExponentialsConstraint(self, other)
        elif isinstance(other, SumExponentials):
            if self._y is not None or other._y is not None:
                raise NotImplementedError("Comparing two sums of exponentials "
                    "is not supported if either expression has the additional "
                    "perspectives parameter y set.")

            if other.n != 1:
                raise NotImplementedError("You may only upper bound a sum of "
                    "exponentials by a single exponential, not by another sum.")

            # sum(exp(x)) <= exp(t) is posed as log(sum(exp(x))) <= t.
            return LogSumExp(self._x) <= other._x
        else:
            return NotImplemented

368 

369 

370# -------------------------------------- 

# Public API of this module, built by api_end() from the marker taken at the
# top of the file and the current module namespace.
# NOTE(review): presumably lists the names defined since _API_START was
# taken - confirm against ..apidoc.api_end.
371__all__ = api_end(_API_START, globals())