Coverage for picos/constraints/con_sumxtr.py: 84.62%
130 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-26 07:46 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-26 07:46 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2012-2017 Guillaume Sagnol
3# Copyright (C) 2018-2019 Maximilian Stahlberg
4#
5# This file is part of PICOS.
6#
7# PICOS is free software: you can redistribute it and/or modify it under the
8# terms of the GNU General Public License as published by the Free Software
9# Foundation, either version 3 of the License, or (at your option) any later
10# version.
11#
12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
14# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License along with
17# this program. If not, see <http://www.gnu.org/licenses/>.
18# ------------------------------------------------------------------------------
20"""Implementation of :class:`SumExtremesConstraint`."""
22from collections import namedtuple
24import cvxopt as cvx
26from .. import glyphs
27from ..apidoc import api_end, api_start
28from .constraint import Constraint, ConstraintConversion
# NOTE(review): presumably records the module namespace before any public
# definitions are made, so that the api_end() call at the bottom of this file
# can derive __all__ from what was added in between -- confirm in ..apidoc.
_API_START = api_start(globals())
# -------------------------------
class SumExtremesConstraint(Constraint):
    """Bound on a sum over extreme (eigen)values.

    Relates a :class:`~picos.expressions.SumExtremes` expression to a scalar
    affine right hand side through either the :attr:`LE` or the :attr:`GE`
    relation.
    """

    class Conversion(ConstraintConversion):
        """Sum over extremes to LMI/affine constraint conversion."""

        @classmethod
        def predict(cls, subtype, options):
            """Implement :meth:`~.constraint.ConstraintConversion.predict`.

            Yields ``(kind, type, count)`` triples where ``kind`` is ``"var"``
            or ``"con"``. The case distinction (eigenvalues or not, ``k == 1``,
            ``k == n``, otherwise) mirrors the one made by :meth:`convert`.
            The prediction does not depend on the relation and ``options`` is
            not consulted.

            :param subtype:
                A :attr:`SumExtremesConstraint.Subtype` tuple.
            :param options:
                Unused.
            """
            from ..expressions import (HermitianVariable, RealVariable,
                                       SymmetricVariable)
            from . import AffineConstraint, ComplexLMIConstraint, LMIConstraint

            # The last field is called "complex" in the Subtype tuple; a
            # different local name avoids shadowing the builtin.
            nm, k, eigenvalues, is_complex = subtype

            # Determine matrix variable dimension.
            if eigenvalues:
                # nm counts the entries of a square matrix, so its root must be
                # an integer side length.
                nFloat = nm**0.5
                n = int(nFloat)
                assert n == nFloat
                # Dimension passed to make_var_type for the auxiliary matrix
                # variable: n**2 in the complex case, n*(n + 1)/2 otherwise.
                d = n**2 if is_complex else n*(n + 1) // 2
            else:
                n = nm

            # Validate k.
            assert k > 0 and k <= n

            # Define shorthands for better formatting below.
            RV, AC = RealVariable, AffineConstraint
            LMI = ComplexLMIConstraint if is_complex else LMIConstraint
            MV = HermitianVariable if is_complex else SymmetricVariable

            if eigenvalues:
                if k == 1:
                    yield ("con", LMI.make_type(diag=n), 1)
                elif k == n:
                    # NOTE: Refinement prevents this case from happening.
                    yield ("con", AC.make_type(dim=1, eq=False), 1)
                else:
                    yield ("var", RV.make_var_type(dim=1, bnd=0), 1)
                    yield ("var", MV.make_var_type(dim=d, bnd=0), 1)
                    yield ("con", LMI.make_type(diag=n), 2)
                    yield ("con", AC.make_type(dim=1, eq=False), 1)
            else:
                if k == 1:
                    yield ("con", AC.make_type(dim=n, eq=False), 1)
                elif k == n:
                    # NOTE: Refinement prevents this case from happening.
                    yield ("con", AC.make_type(dim=1, eq=False), 1)
                else:
                    yield ("var", RV.make_var_type(dim=1, bnd=0), 1)
                    yield ("var", RV.make_var_type(dim=n, bnd=n), 1)
                    yield ("con", AC.make_type(dim=n, eq=False), 1)
                    yield ("con", AC.make_type(dim=1, eq=False), 1)

        @classmethod
        def convert(cls, con, options):
            """Implement :meth:`~.constraint.ConstraintConversion.convert`.

            Builds a new problem that holds the auxiliary variables and the
            LMI/affine constraints replacing ``con``. ``options`` is not
            consulted.

            :param con:
                The :class:`SumExtremesConstraint` to convert.
            :param options:
                Unused.
            :returns:
                The problem containing the replacement constraints.
            """
            from ..expressions import (Constant, HermitianVariable,
                                       RealVariable, SymmetricVariable)
            from ..modeling import Problem

            theSum = con.theSum
            relation = con.relation
            rhs = con.rhs

            x = theSum.x
            k = theSum.k

            if theSum.eigenvalues:
                n = x.shape[0]
                # Identity built from a diagonal of n ones.
                I = Constant('I', cvx.spdiag([1.] * n))
            else:
                n = len(x)

            if x.complex:
                MatrixVariable = HermitianVariable
            else:
                MatrixVariable = SymmetricVariable

            P = Problem()

            if relation == Constraint.LE:
                # Upper bound on the sum.
                if theSum.eigenvalues:
                    if k == 1:
                        P.add_constraint(x << rhs * I)
                    elif k == n:
                        # NOTE: Refinement prevents this case from happening.
                        # NOTE(review): the string "I" is presumably expanded
                        # by PICOS to an identity matching x -- confirm.
                        P.add_constraint(("I" | x) <= rhs)
                    else:
                        # x << Z + s*I with Z >> 0 and <I, Z> + k*s <= rhs.
                        s = RealVariable('s')
                        Z = MatrixVariable('Z', n)
                        P.add_constraint(Z >> 0)
                        P.add_constraint(x << Z + s * I)
                        P.add_constraint(rhs >= (I | Z) + (k * s))
                else:
                    if k == 1:
                        P.add_constraint(x <= rhs)
                    elif k == n:
                        P.add_constraint((1 | x) <= rhs)
                    else:
                        # x <= lambda + mu with mu >= 0 and
                        # k*lambda + <1, mu> <= rhs.
                        lbda = RealVariable('lambda')
                        mu = RealVariable('mu', x.shape, lower=0)
                        P.add_constraint(x <= lbda + mu)
                        P.add_constraint(k * lbda + (1 | mu) <= rhs)
            else:
                # Lower bound on the sum.
                if theSum.eigenvalues:
                    if k == 1:
                        P.add_constraint(x >> rhs * I)
                    elif k == n:
                        # NOTE: Refinement prevents this case from happening.
                        # The sum over all n eigenvalues is the trace <I, x>,
                        # which must be bounded from below for this relation,
                        # just like (1 | x) >= rhs in the vector case. The
                        # relation used to be "<=" here, which was wrong.
                        P.add_constraint((I | x) >= rhs)
                    else:
                        # Same construction as for the upper bound, applied to
                        # -x and -rhs.
                        s = RealVariable('s')
                        Z = MatrixVariable('Z', n)
                        P.add_constraint(Z >> 0)
                        P.add_constraint(-x << Z + s * I)
                        P.add_constraint(-rhs >= (I | Z) + (k * s))
                else:
                    if k == 1:
                        P.add_constraint(x >= rhs)
                    elif k == n:
                        P.add_constraint((1 | x) >= rhs)
                    else:
                        # Same construction as for the upper bound, applied to
                        # -x and -rhs.
                        lbda = RealVariable('lambda')
                        mu = RealVariable('mu', x.shape, lower=0)
                        P.add_constraint(-x <= lbda + mu)
                        P.add_constraint(k * lbda + (1 | mu) <= -rhs)

            return P

    def __init__(self, theSum, relation, rhs):
        """Construct a :class:`SumExtremesConstraint`.

        :param ~picos.expressions.SumExtremes theSum:
            Left hand side expression.
        :param str relation:
            Constraint relation symbol.
        :param ~picos.expressions.AffineExpression rhs:
            Right hand side expression.
        """
        from ..expressions import AffineExpression, SumExtremes

        assert isinstance(theSum, SumExtremes)
        assert isinstance(rhs, AffineExpression)
        assert relation in self.LE + self.GE
        # The bound must be scalar.
        assert len(rhs) == 1

        self.theSum = theSum
        self.relation = relation
        self.rhs = rhs

        super(SumExtremesConstraint, self).__init__(theSum._typeStr)

    # Fields: number of entries of the argument x, the number k stored on the
    # sum, whether the sum is over eigenvalues, and whether x is complex.
    Subtype = namedtuple("Subtype", ("argdim", "k", "eigenvalues", "complex"))

    def _subtype(self):
        """Return the :attr:`Subtype` tuple describing this constraint."""
        return self.Subtype(len(self.theSum.x), self.theSum.k,
            self.theSum.eigenvalues, self.theSum.x.complex)

    @classmethod
    def _cost(cls, subtype):
        """Return ``n + 1`` for the given subtype.

        ``n`` is the square root of ``argdim`` for a sum over eigenvalues and
        ``argdim`` itself otherwise.
        """
        nm, _, eigenvalues, _ = subtype

        if eigenvalues:
            # argdim counts the entries of a square matrix.
            nFloat = nm**0.5
            n = int(nFloat)
            assert n == nFloat
        else:
            n = nm

        return n + 1

    def _expression_names(self):
        """Yield the names of the attributes that hold expressions."""
        yield "theSum"
        yield "rhs"

    def _str(self):
        """Return the constraint formatted with the relation's glyph."""
        if self.relation == self.LE:
            return glyphs.le(self.theSum.string, self.rhs.string)
        else:
            return glyphs.ge(self.theSum.string, self.rhs.string)

    def _get_slack(self):
        """Return the difference of both sides, larger side first."""
        if self.relation == self.LE:
            return self.rhs.safe_value - self.theSum.safe_value
        else:
            return self.theSum.safe_value - self.rhs.safe_value
# --------------------------------------
# NOTE(review): __all__ is computed by api_end() from the _API_START marker and
# the current module globals; presumably it lists the names defined since
# api_start() was called -- confirm in ..apidoc.
__all__ = api_end(_API_START, globals())