Coverage for picos/expressions/exp_sumexp.py: 73.30%
191 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-15 14:21 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-15 14:21 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2019 Maximilian Stahlberg
3# Based on the original picos.expressions module by Guillaume Sagnol.
4#
5# This file is part of PICOS.
6#
7# PICOS is free software: you can redistribute it and/or modify it under the
8# terms of the GNU General Public License as published by the Free Software
9# Foundation, either version 3 of the License, or (at your option) any later
10# version.
11#
12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
14# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License along with
17# this program. If not, see <http://www.gnu.org/licenses/>.
18# ------------------------------------------------------------------------------
20"""Implements :class:`SumExponentials`."""
22import math
23import operator
24from collections import namedtuple
26import cvxopt
27import numpy
29from .. import glyphs
30from ..apidoc import api_end, api_start
31from ..caching import cached_property, cached_unary_operator
32from ..constraints import LogSumExpConstraint, SumExponentialsConstraint
33from .data import convert_and_refine_arguments, convert_operands, cvx2np
34from .exp_affine import AffineExpression
35from .expression import Expression, refine_operands, validate_prediction
# Snapshot the module namespace at this point, after the imports above and
# before any definitions below. The snapshot is passed to api_end() at the
# bottom of the module to build __all__.
# NOTE(review): exact semantics of api_start/api_end live in picos.apidoc and
# are inferred here from their usage only — confirm there.
37_API_START = api_start(globals())
38# -------------------------------
class SumExponentials(Expression):
    r"""Sum of elementwise exponentials of an affine expression.

    :Definition:

    Let :math:`x` be an :math:`n`-dimensional real affine expression.

    1.  If no additional expression :math:`y` is given, this is the sum of
        elementwise exponentials

        .. math::

            \sum_{i = 1}^n \exp(\operatorname{vec}(x)_i).

    2.  If an additional affine expression :math:`y` of same shape as :math:`x`
        is given, this is the sum of elementwise perspectives of exponentials

        .. math::

            \sum_{i = 1}^n \operatorname{vec}(y)_i \exp\left(
                \frac{\operatorname{vec}(x)_i}{\operatorname{vec}(y)_i}\right).

    .. warning::

        When you pose an upper bound :math:`t` on a sum of elementwise
        exponentials, then PICOS enforces :math:`t \geq 0` through an auxiliary
        constraint during solution search. When an additional expression
        :math:`y` is given, PICOS enforces :math:`y \geq 0` as well.
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    @convert_and_refine_arguments("x", "y", allowNone=True)
    def __init__(self, x, y=None):
        """Construct a :class:`SumExponentials`.

        :param x: The affine expression :math:`x`.
        :type x: ~picos.expressions.AffineExpression
        :param y: An additional affine expression :math:`y`. If necessary, PICOS
            will attempt to reshape or broadcast it to the shape of :math:`x`.
            If :math:`y` is identically one, it is dropped, as the perspective
            then coincides with the plain sum of exponentials.
        :type y: ~picos.expressions.AffineExpression

        :raises TypeError: If :math:`x`, or :math:`y` when given, is not a real
            affine expression.
        """
        if not isinstance(x, AffineExpression):
            raise TypeError("Can only sum the elementwise exponentials of a "
                "real affine expression, not of {}.".format(x.string))

        if y is not None:
            if not isinstance(y, AffineExpression):
                raise TypeError("The additional parameter y must be a real "
                    "affine expression, not {}.".format(y.string))
            elif x.shape != y.shape:
                y = y.reshaped_or_broadcasted(x.shape)

            # For y = 1 elementwise, y*exp(x/y) = exp(x): drop the perspective.
            if y.is1:
                y = None

        self._x = x
        self._y = y

        if len(x) == 1:
            if y is None:
                typeStr = "Exponential"
                symbStr = glyphs.exp(x.string)
            else:
                typeStr = "Exponential Perspective"
                symbStr = glyphs.mul(
                    y.string, glyphs.exp(glyphs.div(x.string, y.string)))
        else:
            if y is None:
                typeStr = "Sum of Exponentials"
                symbStr = glyphs.make_function("sum", "exp")(x.string)
            else:
                typeStr = "Sum of Exponential Perspectives"
                symbStr = glyphs.sum(glyphs.mul(glyphs.slice(y.string, "i"),
                    glyphs.exp(glyphs.div(glyphs.slice(x.string, "i"),
                    glyphs.slice(y.string, "i")))))

        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_refined(self):
        # If every argument is constant, replace the expression by a constant
        # affine expression holding its value; the current symbolic string is
        # passed along (presumably as its display string — confirm).
        if self._x.constant and (self._y is None or self._y.constant):
            return AffineExpression.from_constant(self.value, 1, self._symbStr)
        else:
            return self

    # argdim: number of entries of x; y: whether a perspective y is present.
    Subtype = namedtuple("Subtype", ("argdim", "y"))

    def _get_subtype(self):
        return self.Subtype(len(self._x), self._y is not None)

    def _get_value(self):
        x = numpy.ravel(cvx2np(self._x._get_value()))

        if self._y is None:
            s = numpy.sum(numpy.exp(x))
        else:
            # Sum of y_i * exp(x_i / y_i), computed as a dot product.
            # NOTE(review): where y_i == 0 and x_i >= 0 numpy yields nan here,
            # not the closure value of the perspective (0 or +inf).
            y = numpy.ravel(cvx2np(self._y._get_value()))
            s = y.dot(numpy.exp(x / y))

        return cvxopt.matrix(s)

    @cached_unary_operator
    def _get_mutables(self):
        # Both operands are affine expressions; query them the same way.
        if self._y is None:
            return self._x._get_mutables()
        else:
            return self._x._get_mutables().union(self._y._get_mutables())

    def _is_convex(self):
        return True

    def _is_concave(self):
        return False

    def _replace_mutables(self, mapping):
        return self.__class__(self._x._replace_mutables(mapping),
            None if self._y is None else self._y._replace_mutables(mapping))

    def _freeze_mutables(self, freeze):
        return self.__class__(self._x._freeze_mutables(freeze),
            None if self._y is None else self._y._freeze_mutables(freeze))

    # --------------------------------------------------------------------------
    # Python special method implementations, except constraint-creating ones.
    # --------------------------------------------------------------------------

    @classmethod
    def _add(cls, self, other, forward):
        """Implement ``self + other`` if ``forward``, else ``other + self``.

        Two cases are handled here; all others are delegated to
        :class:`~picos.expressions.Expression`:

        - Adding a constant zero returns ``self`` and adding a positive
          constant :math:`c` appends one more term that evaluates to :math:`c`.
        - Adding another :class:`SumExponentials` concatenates both into one
          sum, padding a missing :math:`y` with ones.
        """
        if isinstance(other, AffineExpression) and other.constant:
            value = other.value

            if not value:
                return self
            elif value > 0:
                # Append exp(log(value)) = value, or for perspectives
                # 1*exp(log(value)/1) = value. Vectorize first so that exactly
                # one term is appended regardless of the shape of x.
                log_value = math.log(value)

                if self._y is None:
                    result = cls(self._x.vec // log_value)
                else:
                    result = cls(self._x.vec // log_value, self._y.vec // 1)

                if forward:
                    string = glyphs.clever_add(self.string, other.string)
                else:
                    string = glyphs.clever_add(other.string, self.string)

                result._typeStr = "Offset " + result._typeStr
                result._symbStr = string

                return result
        elif isinstance(other, cls):
            assert forward, "Encountered __radd__ on equal types."

            # A missing y is equivalent to a y of all ones.
            if self._y is None and other._y is None:
                result = cls(self._x.vec // other._x.vec)
            elif self._y is not None and other._y is None:
                one = AffineExpression.from_constant(1.0, (other.n, 1))
                result = cls(self._x.vec // other._x.vec, self._y.vec // one)
            elif self._y is None and other._y is not None:
                one = AffineExpression.from_constant(1.0, (self.n, 1))
                result = cls(self._x.vec // other._x.vec, one // other._y.vec)
            else:
                result = cls(
                    self._x.vec // other._x.vec, self._y.vec // other._y.vec)

            result._symbStr = glyphs.clever_add(self.string, other.string)

            return result

        if forward:
            return Expression.__add__(self, other)
        else:
            return Expression.__radd__(self, other)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __add__(self, other):
        """Denote addition from the right hand side."""
        return SumExponentials._add(self, other, True)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __radd__(self, other):
        """Denote addition from the left hand side."""
        return SumExponentials._add(self, other, False)

    @classmethod
    def _mul_div(cls, self, other, div, forward):
        """Implement scaling by, or division by, a constant scalar ``other``.

        Scaling by a positive constant stays within this class; all other
        cases are delegated to :class:`~picos.expressions.Expression`.

        :param bool div: Whether this is ``self / other``.
        :param bool forward: Whether ``self`` is the left hand side operand.

        :raises ZeroDivisionError: When dividing by a constant zero.
        """
        assert not div or forward

        if isinstance(other, AffineExpression) and other.constant:
            factor = other.safe_value

            if not factor:
                if div:
                    raise ZeroDivisionError(
                        "Cannot divide {} by zero.".format(self.string))
                else:
                    return AffineExpression.zero()
            elif factor == 1:
                return self
            elif factor > 0:
                if div:
                    factor = 1 / factor
                    string = glyphs.div(self.string, other.string)
                elif forward:
                    string = glyphs.clever_mul(self.string, other.string)
                else:
                    string = glyphs.clever_mul(other.string, self.string)

                if self._y is None:
                    # c*exp(x) = exp(x + log(c)).
                    result = cls(self._x + math.log(factor))
                elif div:
                    # y*exp(x/y) / c = (y/c)*exp((x/c)/(y/c)).
                    result = cls(self._x / other, self._y / other)
                else:
                    # c*y*exp(x/y) = (c*y)*exp((c*x)/(c*y)).
                    result = cls(other*self._x, other*self._y)

                result._typeStr = "Scaled " + result._typeStr
                result._symbStr = string

                return result

        if div:
            # In Python 3 the "/" operator dispatches to __truediv__.
            return Expression.__truediv__(self, other)
        elif forward:
            return Expression.__mul__(self, other)
        else:
            return Expression.__rmul__(self, other)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __mul__(self, other):
        """Denote scaling from the right hand side."""
        return SumExponentials._mul_div(self, other, div=False, forward=True)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __rmul__(self, other):
        """Denote scaling from the left hand side."""
        return SumExponentials._mul_div(self, other, div=False, forward=False)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __truediv__(self, other):
        """Denote division by a constant scalar."""
        return SumExponentials._mul_div(self, other, div=True, forward=True)

    # --------------------------------------------------------------------------
    # Methods and properties that return expressions.
    # --------------------------------------------------------------------------

    @property
    def x(self):
        """The expression :math:`x`."""
        return self._x

    @property
    def y(self):
        """The additional expression :math:`y`, or :obj:`None`."""
        return self._y

    @cached_property
    def log(self):
        """The logarithm of the expression.

        :raises NotImplementedError: If :math:`y` is set.
        """
        from . import LogSumExp

        if self._y is not None:
            raise NotImplementedError("May only take the logarithm of a sum of"
                " exponentials, not of a sum of exponential perspectives.")

        return LogSumExp(self._x)

    # --------------------------------------------------------------------------
    # Methods and properties that describe the expression.
    # --------------------------------------------------------------------------

    @property
    def n(self):
        """Length of :attr:`x`."""
        return len(self._x)

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        # Predict the constraint type that comparing an expression of the given
        # subtype with ``other`` via ``relation`` would create. This mirrors the
        # dispatch in __le__ below.
        assert isinstance(subtype, cls.Subtype)

        if relation == operator.__le__:
            if issubclass(other.clstype, AffineExpression) \
                    and other.subtype.dim == 1:
                # NOTE(review): presumably "log-sum-exp representable" means
                # the bound may be logarithmized (no perspective and a bound
                # known to be nonnegative) — confirm in the constraint module.
                return SumExponentialsConstraint.make_type(
                    argdim=subtype.argdim,
                    lse_representable=(not subtype.y and other.subtype.nonneg))
            elif issubclass(other.clstype, SumExponentials):
                if subtype.y or other.subtype.y:
                    return NotImplemented

                if other.subtype.argdim != 1:
                    return NotImplemented

                return LogSumExpConstraint.make_type(argdim=subtype.argdim)

        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __le__(self, other):
        """Upper bound the expression.

        An affine upper bound yields a :class:`SumExponentialsConstraint`.
        A single exponential :math:`\\exp(z)` as upper bound is handled as
        :math:`\\log\\sum\\exp(x) \\leq z`.

        :raises NotImplementedError: When comparing with another
            :class:`SumExponentials` that is a perspective or that has more
            than one term.
        """
        from . import LogSumExp

        if isinstance(other, AffineExpression):
            return SumExponentialsConstraint(self, other)
        elif isinstance(other, SumExponentials):
            if self._y is not None or other._y is not None:
                raise NotImplementedError("Comparing two sums of exponentials "
                    "is not supported if either expression has the additional "
                    "perspectives parameter y set.")

            if other.n != 1:
                raise NotImplementedError("You may only upper bound a sum of "
                    "exponentials by a single exponential, not by another sum.")

            # sum(exp(x)) <= exp(z)  <=>  log(sum(exp(x))) <= z.
            return LogSumExp(self._x) <= other._x
        else:
            return NotImplemented
370# --------------------------------------
# Declare the module's public API. api_end() receives the snapshot taken by
# api_start() near the top of the module together with the current namespace.
# NOTE(review): presumably it exports the names defined in between — confirm
# in picos.apidoc.
371__all__ = api_end(_API_START, globals())