Coverage for picos/constraints/con_wsum.py: 93.51%
77 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-26 07:46 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2021 Maximilian Stahlberg
3#
4# This file is part of PICOS.
5#
6# PICOS is free software: you can redistribute it and/or modify it under the
7# terms of the GNU General Public License as published by the Free Software
8# Foundation, either version 3 of the License, or (at your option) any later
9# version.
10#
11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License along with
16# this program. If not, see <http://www.gnu.org/licenses/>.
17# ------------------------------------------------------------------------------
19"""Implementation of :class:`WeightedSumConstraint`."""
21import operator
22from collections import namedtuple
24from .. import glyphs
25from ..apidoc import api_end, api_start
26from .constraint import Constraint, ConstraintConversion
# Store the value returned by api_start() for the module globals; it is passed
# to api_end() at the end of this module to compute __all__.
28_API_START = api_start(globals())
29# -------------------------------
class WeightedSumConstraint(Constraint):
    """Upper bound on a convex or lower bound on a concave weighted sum.

    The left hand side is a weighted sum of expressions and the right hand
    side is a scalar affine expression.
    """

    class Conversion(ConstraintConversion):
        """Split the weighted sum bound into one constraint per summand.

        An auxiliary real variable with one entry per summand is introduced.
        The weighted sum of its entries is bounded by the right hand side and
        every summand is bounded by its entry, in a direction that depends on
        the sign of the associated weight.
        """

        @classmethod
        def predict(cls, subtype, options):
            """Implement :meth:`~.constraint.ConstraintConversion.predict`."""
            from ..expressions import AffineExpression, RealVariable
            from . import AffineConstraint

            # "Same" is the direction of the outer constraint, used for
            # summands with a nonnegative weight; "flipped" is used otherwise.
            if subtype.relation == Constraint.LE:
                same_dir, flipped_dir = operator.__le__, operator.__ge__
            else:
                same_dir, flipped_dir = operator.__ge__, operator.__le__

            num_terms = len(subtype.lhs_types)

            assert num_terms > 0

            # The auxiliary variable and the affine bound on its weighted sum.
            yield (
                "var", RealVariable.make_var_type(dim=num_terms, bnd=0), 1)
            yield ("con", AffineConstraint.make_type(dim=1, eq=False), 1)

            # Type of a single element of the auxiliary variable.
            element_type = AffineExpression.make_type(
                shape=(1, 1), constant=False, nonneg=False
            )

            pairs = zip(subtype.lhs_types, subtype.nonneg_weights)
            for term_type, weight_is_nonneg in pairs:
                direction = same_dir if weight_is_nonneg else flipped_dir
                yield ("con", term_type.predict(direction, element_type), 1)

        @classmethod
        def convert(cls, con, options):
            """Implement :meth:`~.constraint.ConstraintConversion.convert`."""
            from ..expressions import RealVariable
            from ..modeling import Problem

            terms = con.wsum.expressions
            num_terms = len(terms)
            weights = con.wsum.weights

            assert num_terms > 0

            # One auxiliary entry per summand of the weighted sum.
            aux = RealVariable("__t", num_terms)

            problem = Problem()
            is_upper_bound = con.relation == Constraint.LE

            if is_upper_bound:
                problem += weights.T * aux <= con.rhs
            else:
                problem += weights.T * aux >= con.rhs

            for index, term in enumerate(terms):
                weight_is_nonneg = bool(weights[index].value >= 0)

                # The summand is bounded from above by its auxiliary entry
                # exactly when the outer bound is an upper bound and the
                # weight is nonnegative, or when both of these are false.
                if weight_is_nonneg == is_upper_bound:
                    problem.add_constraint(term <= aux[index])
                else:
                    problem.add_constraint(term >= aux[index])

            return problem

    def __init__(self, wsum, relation, rhs):
        """Construct a :class:`WeightedSumConstraint`.

        :param ~picos.expressions.WeightedSum wsum:
            Left hand side expression. Must be convex for an upper bound and
            concave for a lower bound.
        :param str relation:
            Constraint relation symbol.
        :param ~picos.expressions.AffineExpression rhs:
            Right hand side expression of length one.
        """
        from ..expressions import AffineExpression, WeightedSum

        assert isinstance(wsum, WeightedSum)
        assert isinstance(rhs, AffineExpression)
        assert relation in self.LE + self.GE
        assert wsum.convex if relation == self.LE else wsum.concave
        assert len(rhs) == 1

        # Attributes are set before the base class constructor is invoked.
        self.wsum = wsum
        self.relation = relation
        self.rhs = rhs

        super().__init__(wsum._typeStr)

    Subtype = namedtuple(
        "Subtype", ("lhs_types", "relation", "rhs_type", "nonneg_weights"))

    def _subtype(self):
        summand_types = self.wsum.subtype.types
        relation = self.relation
        rhs_type = self.rhs.type
        # One flag per weight, true where the weight is nonnegative.
        sign_flags = tuple(self.wsum.weights.np >= 0)

        return self.Subtype(summand_types, relation, rhs_type, sign_flags)

    @classmethod
    def _cost(cls, subtype):
        # The cost is the number of summands.
        return len(subtype.lhs_types)

    def _expression_names(self):
        yield from ("wsum", "rhs")

    def _str(self):
        render = glyphs.le if self.relation == self.LE else glyphs.ge
        return render(self.wsum.string, self.rhs.string)

    def _get_slack(self):
        # Slack is the right hand side minus the sum for an upper bound and
        # the sum minus the right hand side for a lower bound.
        if self.relation == self.LE:
            return self.rhs.safe_value - self.wsum.safe_value

        return self.wsum.safe_value - self.rhs.safe_value
158# --------------------------------------
# Compute __all__ from the value stored in _API_START and the module globals.
# NOTE(review): presumably this exports the names defined since api_start()
# was called near the top of the module — confirm against picos.apidoc.
159__all__ = api_end(_API_START, globals())