r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# ------------------------------------------------------------------------------

2# Copyright (C) 2019 Maximilian Stahlberg

3# Based on the original picos.expressions module by Guillaume Sagnol.

4#

5# This file is part of PICOS.

6#

7# PICOS is free software: you can redistribute it and/or modify it under the

8# terms of the GNU General Public License as published by the Free Software

9# Foundation, either version 3 of the License, or (at your option) any later

10# version.

11#

12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY

13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR

14# A PARTICULAR PURPOSE. See the GNU General Public License for more details.

15#

16# You should have received a copy of the GNU General Public License along with

17# this program. If not, see <http://www.gnu.org/licenses/>.

18# ------------------------------------------------------------------------------

20"""Implements :class:SumExponentials."""

22import math

23import operator

24from collections import namedtuple

26import cvxopt

27import numpy

29from .. import glyphs

30from ..apidoc import api_end, api_start

31from ..caching import cached_property, cached_unary_operator

32from ..constraints import LogSumExpConstraint, SumExponentialsConstraint

33from .data import convert_and_refine_arguments, convert_operands, cvx2np

34from .exp_affine import AffineExpression

35from .expression import Expression, refine_operands, validate_prediction

# Marker taken before the module's definitions; it is handed back to
# api_end() at the bottom of the file to build __all__.
_API_START = api_start(globals())

38# -------------------------------

41class SumExponentials(Expression):

42 r"""Sum of elementwise exponentials of an affine expression.

44 :Definition:

46 Let :math:x be an :math:n-dimensional real affine expression.

48 1. If no additional expression :math:y is given, this is the sum of

49 elementwise exponentials

51 .. math::

53 \sum_{i = 1}^n \exp(\operatorname{vec}(x)_i).

55 2. If an additional affine expression :math:y of same shape as :math:x

56 is given, this is the sum of elementwise perspectives of exponentials

58 .. math::

60 \sum_{i = 1}^n \operatorname{vec}(y)_i \exp\left(

61 \frac{\operatorname{vec}(x)_i}{\operatorname{vec}(y)_i}\right).

63 .. warning::

65 When you pose an upper bound :math:t on a sum of elementwise

66 exponentials, then PICOS enforces :math:t \geq 0 through an auxiliary

67 constraint during solution search. When an additional expression

68 :math:y is given, PICOS enforces :math:y \geq 0 as well.

69 """

71 # --------------------------------------------------------------------------

72 # Initialization and factory methods.

73 # --------------------------------------------------------------------------

75 @convert_and_refine_arguments("x", "y", allowNone=True)

76 def __init__(self, x, y=None):

77 """Construct a :class:SumExponentials.

79 :param x: The affine expression :math:x.

80 :type x: ~picos.expressions.AffineExpression

81 :param y: An additional affine expression :math:y. If necessary, PICOS

82 will attempt to reshape or broadcast it to the shape of :math:x.

83 :type y: ~picos.expressions.AffineExpression

84 """

85 if not isinstance(x, AffineExpression):

86 raise TypeError("Can only sum the elementwise exponentials of a "

87 "real affine expression, not of {}.".format(x.string))

89 if y is not None:

90 if not isinstance(y, AffineExpression):

91 raise TypeError("The additional parameter y must be a real "

92 "affine expression, not {}.".format(y.string))

93 elif x.shape != y.shape:

96 if y.is1:

97 y = None

99 self._x = x

100 self._y = y

102 if len(x) == 1:

103 if y is None:

104 typeStr = "Exponential"

105 symbStr = glyphs.exp(x.string)

106 else:

107 typeStr = "Exponential Perspective"

108 symbStr = glyphs.mul(

109 y.string, glyphs.exp(glyphs.div(x.string, y.string)))

110 else:

111 if y is None:

112 typeStr = "Sum of Exponentials"

113 symbStr = glyphs.make_function("sum", "exp")(x.string)

114 else:

115 typeStr = "Sum of Exponential Perspectives"

116 symbStr = glyphs.sum(glyphs.mul(glyphs.slice(y.string, "i"),

117 glyphs.exp(glyphs.div(glyphs.slice(x.string, "i"),

118 glyphs.slice(y.string, "i")))))

120 Expression.__init__(self, typeStr, symbStr)

122 # --------------------------------------------------------------------------

123 # Abstract method implementations and method overridings, except _predict.

124 # --------------------------------------------------------------------------

def _get_refined(self):
    """Return a constant affine expression if all arguments are constant.

    If :math:`x` is constant and :math:`y` is either unset or constant, the
    expression is replaced by a constant :class:`AffineExpression` built
    from its value and symbolic string. Otherwise the expression itself is
    returned.
    """
    arguments_constant = self._x.constant and (
        self._y is None or self._y.constant)

    if not arguments_constant:
        return self

    return AffineExpression.from_constant(self.value, 1, self._symbStr)

# Subtype record: ``argdim`` is the length of x, ``y`` tells whether the
# additional perspective argument is set (see _get_subtype).
Subtype = namedtuple("Subtype", ("argdim", "y"))

def _get_subtype(self):
    """Return the subtype: length of :math:`x` and whether :math:`y` is set."""
    has_perspective = self._y is not None
    return self.Subtype(len(self._x), has_perspective)

def _get_value(self):
    """Evaluate the expression numerically.

    The value of :math:`x` (and of :math:`y`, if set) is converted with
    ``cvx2np`` and flattened. Without :math:`y` the result is
    ``sum(exp(x))``; with :math:`y` it is ``y . exp(x / y)``.

    :returns: The result wrapped in a :class:`cvxopt.matrix`.
    """
    x_flat = numpy.ravel(cvx2np(self._x._get_value()))

    if self._y is not None:
        y_flat = numpy.ravel(cvx2np(self._y._get_value()))
        # Sum of perspectives: y_i * exp(x_i / y_i), as a dot product.
        total = y_flat.dot(numpy.exp(x_flat / y_flat))
    else:
        total = numpy.sum(numpy.exp(x_flat))

    return cvxopt.matrix(total)

@cached_unary_operator
def _get_mutables(self):
    """Return the mutables of :math:`x`, united with those of :math:`y` if set."""
    mutables = self._x._get_mutables()

    if self._y is not None:
        mutables = mutables.union(self._y.mutables)

    return mutables

def _is_convex(self):
    """Report the expression as convex (always :obj:`True`)."""
    return True

def _is_concave(self):
    """Report the expression as not concave (always :obj:`False`)."""
    return False

def _replace_mutables(self, mapping):
    """Rebuild the expression with ``mapping`` applied to its arguments.

    ``mapping`` is forwarded unchanged to ``_replace_mutables`` of
    :math:`x` and, if set, of :math:`y`; a new instance of the same class
    is constructed from the results.
    """
    new_x = self._x._replace_mutables(mapping)
    new_y = None if self._y is None else self._y._replace_mutables(mapping)

    return self.__class__(new_x, new_y)

def _freeze_mutables(self, freeze):
    """Rebuild the expression with ``freeze`` applied to its arguments.

    ``freeze`` is forwarded unchanged to ``_freeze_mutables`` of :math:`x`
    and, if set, of :math:`y`; a new instance of the same class is
    constructed from the results.
    """
    frozen_x = self._x._freeze_mutables(freeze)
    frozen_y = None if self._y is None else self._y._freeze_mutables(freeze)

    return self.__class__(frozen_x, frozen_y)

169 # --------------------------------------------------------------------------

170 # Python special method implementations, except constraint-creating ones.

171 # --------------------------------------------------------------------------

173 @convert_operands(scalarRHS=True)

174 @refine_operands()

176 if isinstance(other, AffineExpression):

177 if not other.constant:

178 raise NotImplementedError("You may only add a constant term to "

179 "a nonconstant PICOS sum of exponentials.")

181 value = other.value

183 if value < 0:

184 raise NotImplementedError("You may only add a nonnegative term "

185 "to a nonconstant PICOS sum of exponentials.")

187 if value == 0:

188 # NOTE: We could return self here, but this is more consistent

189 # with other expressions' __add__ methods.

190 sumexp = SumExponentials(self._x)

191 elif self._y is None:

192 sumexp = SumExponentials(self._x // math.log(value))

193 else:

194 sumexp = SumExponentials(self._x // value, self._y // 1)

196 sumexp._typeStr = "Offset " + sumexp._typeStr

199 return sumexp

200 elif isinstance(other, SumExponentials):

201 if self._y is None and other._y is None:

202 sumexp = SumExponentials(self._x.vec // other._x.vec)

203 elif self._y is not None and other._y is None:

204 one = AffineExpression.from_constant(1.0, (other.n, 1))

205 sumexp = SumExponentials(

206 self._x.vec // other._x.vec, self._y.vec // one)

207 elif self._y is None and other._y is not None:

208 one = AffineExpression.from_constant(1.0, (self.n, 1))

209 sumexp = SumExponentials(

210 self._x.vec // other._x.vec, one // other._y.vec)

211 else:

212 sumexp = SumExponentials(

213 self._x.vec // other._x.vec, self._y.vec // other._y.vec)

217 return sumexp

218 else:

219 return NotImplemented

221 @convert_operands(scalarRHS=True)

222 @refine_operands()

224 if isinstance(other, (AffineExpression, SumExponentials)):

226 # NOTE: __add__ always creates a fresh expression.

228 return sumexp

229 else:

230 return NotImplemented

@convert_operands(scalarRHS=True)
@refine_operands()
def __mul__(self, other):
    """Scale the expression by a nonnegative constant.

    Without :math:`y`, the factor is absorbed as ``x + log(value)``, since
    ``c * sum(exp(x)) == sum(exp(x + log(c)))``. With :math:`y`, both
    :math:`x` and :math:`y` are multiplied by the factor, since
    ``(c*y) * exp((c*x) / (c*y)) == c * y * exp(x / y)``.

    :returns: A fresh :class:`SumExponentials`, ``AffineExpression.zero()``
        for a zero factor, or :obj:`NotImplemented` if ``other`` is not an
        affine expression.
    :raises NotImplementedError: If ``other`` is not constant or negative.
    """
    if not isinstance(other, AffineExpression):
        return NotImplemented

    if not other.constant:
        raise NotImplementedError("You may only multiply a nonconstant "
            "PICOS sum of exponentials with a constant term.")

    value = other.value

    if value < 0:
        raise NotImplementedError("You may only multiply a nonconstant "
            "PICOS sum of exponential with a nonnegative term.")

    if value == 0:
        # log(0) is undefined, so the zero product is returned directly.
        return AffineExpression.zero()

    if self._y is None:
        sumexp = SumExponentials(self._x + math.log(value))
    else:
        sumexp = SumExponentials(self._x * value, self._y * value)

    sumexp._typeStr = "Scaled " + sumexp._typeStr
    sumexp._symbStr = glyphs.clever_mul(self.string, other.string)

    return sumexp

@convert_operands(scalarRHS=True)
@refine_operands()
def __rmul__(self, other):
    """Scale the expression from the left; see :meth:`__mul__`.

    Delegates to :meth:`__mul__` and then rewrites the symbolic string so
    that ``other`` appears as the left factor.
    """
    if not isinstance(other, AffineExpression):
        return NotImplemented

    sumexp = self.__mul__(other)
    # NOTE: __mul__ always creates a fresh expression.
    sumexp._symbStr = glyphs.clever_mul(other.string, self.string)

    return sumexp

@convert_operands(scalarRHS=True)
@refine_operands()
def __truediv__(self, other):
    """Divide the expression by a positive constant.

    Implemented as multiplication by ``1.0 / value``; the symbolic string
    is then rewritten as a division.

    :returns: The scaled expression, or :obj:`NotImplemented` if ``other``
        is not an affine expression.
    :raises NotImplementedError: If ``other`` is not constant or not
        strictly positive.
    """
    if not isinstance(other, AffineExpression):
        return NotImplemented

    if not other.constant:
        raise NotImplementedError("You may only divide a nonconstant "
            "PICOS sum of exponentials by a constant term.")

    value = other.value

    if value <= 0:
        raise NotImplementedError("You may only divide a nonconstant "
            "PICOS sum of exponential by a positive term.")

    sumexp = self * (1.0 / value)
    # NOTE: __mul__ always creates a fresh expression.
    sumexp._symbStr = glyphs.div(self.string, other.string)

    # The original fell off the end here and so returned None.
    return sumexp

292 # --------------------------------------------------------------------------

293 # Methods and properties that return expressions.

294 # --------------------------------------------------------------------------

@property
def x(self):
    """The expression :math:`x`."""
    return self._x

@property
def y(self):
    """The additional expression :math:`y`, or :obj:`None`."""
    return self._y

@cached_property
def log(self):
    """The logarithm of the expression.

    :returns: ``LogSumExp`` of :math:`x`.
    :raises NotImplementedError: If the additional expression :math:`y`
        is set.
    """
    # Imported locally, as in the original code.
    from . import LogSumExp

    if self._y is not None:
        raise NotImplementedError("May only take the logarithm of a sum of"
            " exponentials, not of a sum of exponential perspectives.")

    return LogSumExp(self._x)

317 # --------------------------------------------------------------------------

318 # Methods and properties that describe the expression.

319 # --------------------------------------------------------------------------

@property
def n(self):
    """Length of :attr:`x`."""
    return len(self._x)

326 # --------------------------------------------------------------------------

327 # Constraint-creating operators, and _predict.

328 # --------------------------------------------------------------------------

@classmethod
def _predict(cls, subtype, relation, other):
    """Predict the constraint type produced by ``self <relation> other``.

    Only ``<=`` is supported:

    - Against an affine expression of dimension one, a
      ``SumExponentialsConstraint`` type is predicted. It is flagged
      ``lse_representable`` when no :math:`y` is set and the right hand
      side subtype is ``nonneg``.
    - Against another sum of exponentials, a ``LogSumExpConstraint`` type
      is predicted, provided neither side has :math:`y` set and the right
      hand side has ``argdim`` one.

    :returns: The predicted constraint type, or :obj:`NotImplemented`.
    """
    assert isinstance(subtype, cls.Subtype)

    if relation != operator.__le__:
        return NotImplemented

    if issubclass(other.clstype, AffineExpression) \
            and other.subtype.dim == 1:
        return SumExponentialsConstraint.make_type(
            argdim=subtype.argdim,
            lse_representable=(not subtype.y and other.subtype.nonneg))

    if issubclass(other.clstype, SumExponentials):
        if subtype.y or other.subtype.y:
            return NotImplemented

        if other.subtype.argdim != 1:
            return NotImplemented

        return LogSumExpConstraint.make_type(argdim=subtype.argdim)

    return NotImplemented

@convert_operands(scalarRHS=True)
@validate_prediction
@refine_operands()
def __le__(self, other):
    """Create an upper-bound constraint on the expression.

    - An affine right hand side yields a ``SumExponentialsConstraint``.
    - A right hand side that is itself a single exponential (``n == 1``,
      no :math:`y` on either side) yields ``LogSumExp(x) <= other.x``.

    :returns: The constraint, or :obj:`NotImplemented` for any other type.
    :raises NotImplementedError: If two sums of exponentials are compared
        and either has :math:`y` set, or the right hand side has more than
        one term.
    """
    # Imported locally, as in the original code.
    from . import LogSumExp

    if isinstance(other, AffineExpression):
        return SumExponentialsConstraint(self, other)

    if not isinstance(other, SumExponentials):
        return NotImplemented

    if self._y is not None or other._y is not None:
        raise NotImplementedError("Comparing two sums of exponentials "
            "is not supported if either expression has the additional "
            "perspectives parameter y set.")

    if other.n != 1:
        raise NotImplementedError("You may only upper bound a sum of "
            "exponentials by a single exponential, not by another sum.")

    return LogSumExp(self._x) <= other._x

374# --------------------------------------

# Public API list, built by api_end() from the marker taken at the top of
# the module.
__all__ = api_end(_API_START, globals())