Coverage for picos/expressions/exp_logsumexp.py: 72.28%
101 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-26 07:46 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2019 Maximilian Stahlberg
3# Based on the original picos.expressions module by Guillaume Sagnol.
4#
5# This file is part of PICOS.
6#
7# PICOS is free software: you can redistribute it and/or modify it under the
8# terms of the GNU General Public License as published by the Free Software
9# Foundation, either version 3 of the License, or (at your option) any later
10# version.
11#
12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
14# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License along with
17# this program. If not, see <http://www.gnu.org/licenses/>.
18# ------------------------------------------------------------------------------
20"""Implements :class:`LogSumExp`."""
22import operator
23from collections import namedtuple
25import cvxopt
26import numpy
28from .. import glyphs
29from ..apidoc import api_end, api_start
30from ..caching import cached_property
31from ..constraints import LogSumExpConstraint
32from .data import convert_and_refine_arguments, convert_operands, cvx2np
33from .exp_affine import AffineExpression
34from .expression import Expression, refine_operands, validate_prediction
# Snapshot of the module namespace taken by picos.apidoc.api_start; it is
# passed back to api_end at the bottom of the module to compute __all__.
# NOTE(review): presumably only names defined after this point are exported as
# public API — confirm against picos.apidoc.
_API_START = api_start(globals())
# -------------------------------
class LogSumExp(Expression):
    r"""Logarithm of the sum of elementwise exponentials of an expression.

    :Definition:

    For an :math:`n`-dimensional real affine expression :math:`x`, this is the
    logarithm of the sum of elementwise exponentials

    .. math::

        \log\sum_{i = 1}^n \exp(\operatorname{vec}(x)_i).
    """

    # --------------------------------------------------------------------------
    # Initialization and factory methods.
    # --------------------------------------------------------------------------

    @convert_and_refine_arguments("x")
    def __init__(self, x):
        """Construct a :class:`LogSumExp`.

        :param x: The affine expression :math:`x`.
        :type x: ~picos.expressions.AffineExpression

        :raises TypeError: If :math:`x` is not an
            :class:`~picos.expressions.AffineExpression`.
        """
        if not isinstance(x, AffineExpression):
            raise TypeError("Can only form the logarithm of the sum of "
                "elementwise exponentials of a real affine expression, not of "
                "{}.".format(type(x).__name__))

        self._x = x

        typeStr = "Logarithm of Sum of Exponentials"
        symbStr = glyphs.make_function("log", "sum", "exp")(x.string)

        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    def _get_refined(self):
        if self._x.constant:
            # A constant argument makes the whole expression a constant.
            return AffineExpression.from_constant(self.value, 1, self._symbStr)
        elif len(self._x) == 1:
            # For a scalar x, log(exp(x)) = x.
            return self._x  # Don't carry the string for an identity.
        else:
            return self

    # The subtype records only the dimension of the argument x.
    Subtype = namedtuple("Subtype", ("argdim",))

    def _get_subtype(self):
        return self.Subtype(len(self._x))

    def _get_value(self):
        x = numpy.ravel(cvx2np(self._x._get_value()))

        # Evaluate log(sum(exp(x))) as m + log(sum(exp(x - m))) with m = max(x)
        # so that exp cannot overflow for large entries of x. The initial value
        # makes an empty argument yield -inf, as log(sum([])) = log(0) would.
        m = numpy.max(x, initial=-numpy.inf)

        if numpy.isinf(m):
            # Some entry is +inf (result +inf) or every entry is -inf (result
            # -inf); the maximum is the answer and shifting would produce NaN.
            s = m
        else:
            s = m + numpy.log(numpy.sum(numpy.exp(x - m)))

        return cvxopt.matrix(s)

    def _get_mutables(self):
        return self._x._get_mutables()

    def _is_convex(self):
        return True

    def _is_concave(self):
        return False

    def _replace_mutables(self, mapping):
        return self.__class__(self._x._replace_mutables(mapping))

    def _freeze_mutables(self, freeze):
        return self.__class__(self._x._freeze_mutables(freeze))

    # --------------------------------------------------------------------------
    # Python special method implementations, except constraint-creating ones.
    # --------------------------------------------------------------------------

    @classmethod
    def _add(cls, self, other, forward):
        """Implement ``self + other`` (forward) or ``other + self``.

        A scalar affine summand :math:`c` is absorbed into the argument using
        :math:`\\log\\sum\\exp(x) + c = \\log\\sum\\exp(x + c)`. That identity
        does not hold for a non-scalar summand, so all other operands are
        handed to the generic :class:`Expression` implementation.
        """
        if isinstance(other, AffineExpression) and len(other) == 1:
            if other.is0:
                return self

            lse = cls(self._x + other)

            if forward:
                lse._symbStr = glyphs.clever_add(self.string, other.string)
            else:
                lse._symbStr = glyphs.clever_add(other.string, self.string)

            return lse

        if forward:
            return Expression.__add__(self, other)
        else:
            return Expression.__radd__(self, other)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __add__(self, other):
        return LogSumExp._add(self, other, True)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __radd__(self, other):
        return LogSumExp._add(self, other, False)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __sub__(self, other):
        # For a scalar affine c, log(sum(exp(x))) - c = log(sum(exp(x - c))).
        # Non-scalar operands go to the generic implementation.
        if isinstance(other, AffineExpression) and len(other) == 1:
            if other.is0:
                return self

            lse = LogSumExp(self._x - other)
            lse._symbStr = glyphs.clever_sub(self.string, other.string)

            return lse

        return Expression.__sub__(self, other)

    # --------------------------------------------------------------------------
    # Methods and properties that return expressions.
    # --------------------------------------------------------------------------

    @property
    def x(self):
        """The expression :math:`x`."""
        return self._x

    @cached_property
    def exp(self):
        """The sum of elementwise exponentials of :math:`x`."""
        from . import SumExponentials
        return SumExponentials(self._x)

    # --------------------------------------------------------------------------
    # Methods and properties that describe the expression.
    # --------------------------------------------------------------------------

    @property
    def n(self):
        """Length of :attr:`x`."""
        return len(self._x)

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        assert isinstance(subtype, cls.Subtype)

        # Only an upper bound by a scalar affine expression is supported.
        if relation == operator.__le__:
            if issubclass(other.clstype, AffineExpression) \
                    and other.subtype.dim == 1:
                return LogSumExpConstraint.make_type(subtype.argdim)

        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __le__(self, other):
        if isinstance(other, AffineExpression):
            return LogSumExpConstraint(self, other)
        else:
            return NotImplemented
# --------------------------------------
# Public names of this module, computed by picos.apidoc.api_end from the
# snapshot taken by api_start at the top of the module and the current globals.
# NOTE(review): presumably this exports the names defined between the two
# calls (here: LogSumExp) — confirm against picos.apidoc.
__all__ = api_end(_API_START, globals())