Coverage for picos/constraints/con_logsumexp.py: 95.45%
66 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-15 14:21 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2018-2019 Maximilian Stahlberg
3#
4# This file is part of PICOS.
5#
6# PICOS is free software: you can redistribute it and/or modify it under the
7# terms of the GNU General Public License as published by the Free Software
8# Foundation, either version 3 of the License, or (at your option) any later
9# version.
10#
11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License along with
16# this program. If not, see <http://www.gnu.org/licenses/>.
17# ------------------------------------------------------------------------------
19"""Implementation of :class:`LogSumExpConstraint`."""
21from collections import namedtuple
23from .. import glyphs
24from ..apidoc import api_end, api_start
25from ..caching import cached_property
26from .constraint import Constraint, ConstraintConversion
# Marker for the start of this module's public API section. The matching
# ``api_end(_API_START, globals())`` call at the bottom of the module uses it
# to compute ``__all__``. Presumably api_start records the globals that exist
# at this point, so that only names defined afterwards get exported (which
# would keep the imports above out of ``__all__``) -- confirm in picos.apidoc.
_API_START = api_start(globals())
# -------------------------------
class LogSumExpConstraint(Constraint):
    """Upper bound on a logarithm of a sum of exponentials.

    Models ``log(sum_i exp(x_i)) <= b``, where ``x`` are the affine exponents
    of a :class:`~picos.expressions.LogSumExp` and ``b`` is a scalar affine
    upper bound.
    """

    class ExpConeConversion(ConstraintConversion):
        """Bound on a log-sum-exp to exponential cone constraint conversion.

        Introduces an auxiliary vector ``u`` with ``sum(u) <= 1`` and places
        each stacked vector ``(u_i, 1, x_i - b)`` in the exponential cone.
        With PICOS' cone convention this presumably encodes
        ``u_i >= exp(x_i - b)``.
        """

        @classmethod
        def predict(cls, subtype, options):
            """Implement :meth:`~.constraint.ConstraintConversion.predict`.

            Announces, in order: one auxiliary variable of dimension
            ``subtype.argdim``, one scalar affine inequality, and
            ``subtype.argdim`` exponential cone constraints.
            """
            # Imported lazily, inside the method.
            from ..expressions import RealVariable
            from . import AffineConstraint, ExpConeConstraint

            dim = subtype.argdim

            # bnd=0: presumably a bound-free variable type, matching the
            # bound-free add_variable call in convert().
            yield "var", RealVariable.make_var_type(dim=dim, bnd=0), 1
            yield "con", AffineConstraint.make_type(dim=1, eq=False), 1
            yield "con", ExpConeConstraint.make_type(), dim

        @classmethod
        def convert(cls, con, options):
            """Implement :meth:`~.constraint.ConstraintConversion.convert`.

            :returns: An auxiliary problem with a variable ``__u``, the
                inequality ``sum(__u) <= 1`` and one exponential cone
                constraint per exponent.
            """
            from ..expressions import ExponentialCone
            from ..modeling import Problem

            exponents, dim, bound = con.lse.x, con.lse.n, con.ub

            aux = Problem()
            u = aux.add_variable("__u", dim)

            # dual() returns auxConDuals[0], which presumably refers to this
            # first auxiliary constraint -- so it must stay the first one.
            aux.add_constraint((u | 1) <= 1)

            for idx in range(dim):
                stacked = u[idx] // 1 // (exponents[idx] - bound)
                aux.add_constraint(stacked << ExponentialCone())

            return aux

        @classmethod
        def dual(cls, auxVarPrimals, auxConDuals, options):
            """Implement :meth:`~.constraint.ConstraintConversion.dual`.

            Reports the first entry of ``auxConDuals`` as the dual of the
            original constraint.
            """
            # TODO: Verify that this is the dual.
            first_dual = auxConDuals[0]
            return first_dual

    def __init__(self, lse, upperBound):
        """Construct a :class:`LogSumExpConstraint`.

        :param ~picos.expressions.LogSumExp lse:
            Constrained expression.
        :param ~picos.expressions.AffineExpression upperBound:
            Upper bound on the expression; must have exactly one entry.
        """
        from ..expressions import AffineExpression, LogSumExp

        assert isinstance(lse, LogSumExp)
        assert isinstance(upperBound, AffineExpression)
        assert len(upperBound) == 1

        self.lse = lse
        self.ub = upperBound

        # The "LSE" fallback is only reachable when asserts are stripped
        # (python -O) and a non-LogSumExp slipped through.
        if isinstance(lse, LogSumExp):
            typeStr = lse._typeStr
        else:
            typeStr = "LSE"

        super().__init__(typeStr)

    @property
    def exponents(self):
        """The affine exponents of the bounded log-sum-exp expression."""
        return self.lse.x

    @cached_property
    def le0(self):
        """The :class:`~.exp_logsumexp.LogSumExp` posed to be at most zero.

        Moves the bound into the exponents, using
        ``log(sum_i exp(x_i)) - b == log(sum_i exp(x_i - b))``.
        """
        from ..expressions import LogSumExp

        if not self.ub.is0:
            return LogSumExp(self.lse.x - self.ub)
        return self.lse

    # Conversion-relevant summary of an instance: the number of exponents.
    Subtype = namedtuple("Subtype", ["argdim"])

    def _subtype(self):
        """Return the :attr:`Subtype` describing this constraint."""
        return self.Subtype(argdim=self.lse.n)

    @classmethod
    def _cost(cls, subtype):
        """Return the cost figure ``argdim + 1`` for the given subtype."""
        return 1 + subtype.argdim

    def _expression_names(self):
        """Yield the attribute names that hold the constrained expressions."""
        yield from ("lse", "ub")

    def _str(self):
        """Render the constraint as ``lse <= ub``."""
        lhs = self.lse.string
        rhs = self.ub.string
        return glyphs.le(lhs, rhs)

    def _get_size(self):
        """The constraint is a single scalar inequality."""
        return 1, 1

    def _get_slack(self):
        """Return ``ub - lse``, which is nonnegative when the bound holds."""
        bound_value = self.ub.safe_value
        lse_value = self.lse.safe_value
        return bound_value - lse_value
# --------------------------------------
# Compute ``__all__`` via picos.apidoc.api_end from the _API_START marker and
# the current module globals. Presumably this exports only the names defined
# between the two markers (here LogSumExpConstraint) -- TODO confirm in
# picos.apidoc.
__all__ = api_end(_API_START, globals())