Coverage for picos/constraints/con_powtrace.py: 93.40%
212 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-12 07:53 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2012-2019 Guillaume Sagnol
3# Copyright (C) 2018-2019 Maximilian Stahlberg
4#
5# This file is part of PICOS.
6#
7# PICOS is free software: you can redistribute it and/or modify it under the
8# terms of the GNU General Public License as published by the Free Software
9# Foundation, either version 3 of the License, or (at your option) any later
10# version.
11#
12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
14# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License along with
17# this program. If not, see <http://www.gnu.org/licenses/>.
18# ------------------------------------------------------------------------------
20"""Implementation of :class:`PowerTraceConstraint`."""
22import math
23from collections import namedtuple
25import cvxopt as cvx
27from picos.expressions.variables import HermitianVariable, SymmetricVariable
29from .. import glyphs
30from ..apidoc import api_end, api_start
31from .constraint import Constraint, ConstraintConversion
# Start of this module's public API section. The returned marker is passed to
# api_end at the bottom of the module, whose result becomes __all__; presumably
# the names defined in between are the exported ones (see picos.apidoc).
_API_START = api_start(globals())
# -------------------------------
class PowerTraceConstraint(Constraint):
    """Bound on the trace over the :math:`p`-th power of a matrix.

    For scalar expressions, this is simply a bound on their :math:`p`-th power.
    """

    class Conversion(ConstraintConversion):
        """Conversion of a bound on the trace of a :math:`p`-th matrix power.

        The bound is reformulated as a binary tree of rotated second order
        cone constraints (scalar case) or 2-by-2 block LMIs (matrix case),
        based on
        `this paper <http://nbn-resolving.de/urn:nbn:de:0297-zib-17511>`_.
        """

        @staticmethod
        def _is_power_of_two(k):
            """Whether the integral number ``k`` is a power of two (1 counts).

            This is an exact integer test. Comparing a float logarithm with
            its truncation can misjudge large powers of two due to rounding.
            """
            k_int = int(k)
            return k_int == k and k_int >= 1 and not k_int & (k_int - 1)

        @classmethod
        def _count_number_tree_node_types(cls, x):
            """Count number of conversion tree nodes.

            Consider a binary tree with x[i] leaves of type i, arranged from
            left to right, with sum(x) a power of two. A node of the tree is of
            type i if its 2 parents are of type i; otherwise, a new type is
            created for this node. This function counts the number of additional
            types we need to create while growing the tree.

            :param list x: Number of leaves per type, from left to right. Zero
                entries are ignored.
            :returns int: The number of additional node types.
            """
            x = [xi for xi in x if xi != 0]
            sum_x = sum(x)

            # We have reached the tree root. Stop the recursion.
            if sum_x == 1:
                return 0

            # Make sure sum(x) is a power of two (exact integer test).
            assert cls._is_power_of_two(sum_x), \
                "The number of tree leaves must be a power of two."

            new_x = []
            new_t = 0
            s = 0
            offset = 0

            # Compute the vector new_x of types at next level. An odd running
            # sum means the last leaf of this run is paired with the first
            # leaf of the next run, which creates a new (mixed) type.
            for x_i in x:
                s += x_i

                if s % 2 == 0:
                    if x_i - offset >= 2:
                        new_x.append((x_i - offset) // 2)

                    offset = 0
                else:
                    if x_i - offset >= 2:
                        new_x.extend([(x_i - offset) // 2, 1])
                    elif x_i - offset == 1:
                        new_x.append(1)
                    elif x_i - offset == 0:
                        # Raise instead of 'assert False' so the guard also
                        # holds when Python runs with -O.
                        raise AssertionError("Unexpected case.")

                    offset = 1
                    new_t += 1

            assert 2*sum(new_x) == sum_x

            return new_t + cls._count_number_tree_node_types(new_x)

        @staticmethod
        def _np2(n):
            """Compute the smallest power of two that is an upper bound.

            For ``n >= 1`` the result is computed with exact integer
            arithmetic. The formula ``2**ceil(log(n, 2))`` goes through a
            floating point logarithm, which is subject to rounding and can
            overshoot for large exact powers of two.

            :param n: The number to bound. The conversion presumably only
                passes positive integers.
            :returns: The smallest power of two that is at least ``n``.
            """
            if n < 1:
                # Not reached by the conversion; keep the original formula.
                return 2**int(math.ceil(math.log(n, 2)))

            return 1 << (math.ceil(n) - 1).bit_length()

        @staticmethod
        def _reduce_tree(P, lis, n, Var, v, varcnt):
            """Merge the tree nodes pairwise until only two nodes remain.

            Each level pops pairs ``(v1, v2)`` from the end of ``lis``. A pair
            of identical objects is merged into that same object. Any other
            pair is replaced by a fresh variable ``v0``, which is constrained
            to the cone ``(v1 & v2 & v0) << rsoc()`` if ``n == 1``, or to
            ``block([[v1, v0], [v0, v2]]) >> 0`` otherwise.

            :param P: The problem that receives the new constraints.
            :param list lis: The tree leaves. Its length must be a power of
                two, as pairs are popped. The list may be emptied.
            :param int n: The matrix diagonal size (1 for scalars).
            :param Var: The variable class used for matrix nodes (n > 1).
            :param list v: Auxiliary variables; new nodes are appended.
            :param int varcnt: Index used to name the next auxiliary variable.
            :returns list: The two remaining top-level nodes.
            """
            from ..expressions import RealVariable
            from ..expressions.algebra import block, rsoc

            while len(lis) > 2:
                newlis = []

                while lis:
                    v1 = lis.pop()
                    v2 = lis.pop()

                    if v1 is v2:
                        newlis.append(v2)
                        continue

                    name = '__v[' + str(varcnt) + ']'

                    if n == 1:
                        v0 = RealVariable(name)
                        P.add_constraint((v1 & v2 & v0) << rsoc())
                    else:
                        v0 = Var(name, (n, n))
                        P.add_constraint(block([[v1, v0], [v0, v2]]) >> 0)

                    varcnt += 1
                    newlis.append(v0)
                    v.append(v0)

                lis = newlis

            return lis

        @classmethod
        def predict(cls, subtype, options):
            """Implement :meth:`~.constraint.ConstraintConversion.predict`.

            Yields the variable and constraint types, with counts, that
            :meth:`convert` creates for a constraint of the given subtype.
            """
            from ..expressions import (HermitianVariable, RealVariable,
                                       SymmetricVariable)
            from . import (AffineConstraint, ComplexLMIConstraint,
                           LMIConstraint, RSOCConstraint)

            n, num, den, hasM, is_complex = subtype

            # Leaf counts per type, matching the leaf lists built by convert.
            if num > den > 0:
                x = [den, cls._np2(num) - num, num - den]
            elif num / den < 0:
                num = abs(num)
                den = abs(den)
                x = [den, num, cls._np2(num + den) - num - den]
            elif 0 < num < den:
                x = [num, cls._np2(den) - den, den - num]
            else:
                raise AssertionError("Unexpected exponent.")

            N = cls._count_number_tree_node_types(x)

            if n == 1:
                # The last merge step reuses existing nodes, hence N - 1.
                yield ("var", RealVariable.make_var_type(dim=1, bnd=0), N - 1)
                yield ("con", RSOCConstraint.make_type(argdim=1), N)

                if hasM:
                    yield ("var", RealVariable.make_var_type(dim=1, bnd=0), 1)
                    yield ("con",
                        AffineConstraint.make_type(dim=1, eq=False), 1)
            else:
                if is_complex:
                    yield ("var",
                        HermitianVariable.make_var_type(dim=n**2, bnd=0), N)
                    yield ("con",
                        ComplexLMIConstraint.make_type(diag=2*n), N)
                else:
                    yield ("var", SymmetricVariable.make_var_type(
                        dim=(n * (n + 1)) // 2, bnd=0), N)
                    yield ("con", LMIConstraint.make_type(diag=2*n), N)

                yield ("con", AffineConstraint.make_type(dim=1, eq=False), 1)

        @classmethod
        def convert(cls, con, options):
            """Implement :meth:`~.constraint.ConstraintConversion.convert`.

            :returns: A problem whose constraints together represent ``con``.
            """
            from ..expressions import (HermitianVariable, RealVariable,
                                       SymmetricVariable)
            from ..expressions.algebra import I, block, rsoc
            from ..modeling import Problem

            x = con.power.x
            n = con.power.n
            num = con.power.num
            den = con.power.den
            rhs = con.rhs
            m = con.power.m

            Var = HermitianVariable if x.complex else SymmetricVariable

            P = Problem()

            # v[0], if present, is the auxiliary variable that stands in for
            # the power inside the tree. It is then related to rhs by a final
            # affine constraint.
            if n == 1:
                if m is None:
                    varcnt = 0
                    v = []
                else:
                    varcnt = 1
                    v = [RealVariable('__v[0]')]
            else:
                varcnt = 1
                v = [Var('__v[0]', (n, n))]

            # NOTE(review): m is not used by the two LE branches below.
            # Presumably PowerTrace only admits m with a lower bound; confirm.
            if con.relation == Constraint.LE and num > den:
                # Upper bound on a power with p = num/den > 1.
                pown = cls._np2(num)
                bound = rhs if n == 1 else v[0]

                lis = [bound]*den + [x]*(pown - num) + [I(n)]*(num - den)
                lis = cls._reduce_tree(P, lis, n, Var, v, varcnt)

                if n == 1:
                    P.add_constraint((lis[0] & lis[1] & x) << rsoc())
                else:
                    P.add_constraint(block([[lis[0], x], [x, lis[1]]]) >> 0)
                    P.add_constraint(v[0].tr <= rhs)
            elif con.relation == Constraint.LE and num <= den:
                # Upper bound on a power with a negative exponent.
                num = abs(num)
                den = abs(den)

                pown = cls._np2(num + den)
                bound = rhs if n == 1 else v[0]

                lis = [bound] * den + [x] * num + [I(n)] * (pown - num - den)
                lis = cls._reduce_tree(P, lis, n, Var, v, varcnt)

                if n == 1:
                    P.add_constraint((lis[0] & lis[1] & 1) << rsoc())
                else:
                    P.add_constraint(
                        block([[lis[0], I(n)], [I(n), lis[1]]]) >> 0)
                    P.add_constraint(v[0].tr <= rhs)
            elif con.relation == Constraint.GE:
                # Lower bound on a power with 0 < p = num/den < 1.
                pown = cls._np2(den)

                # The padding leaves must be the same quantity that the tree
                # root bounds. That is rhs for a scalar without m, and v[0]
                # otherwise. Padding with rhs while v[0] is the root would
                # make the scalar case with m encode
                # rhs <= m**(pown/den) * x**p rather than rhs <= m * x**p
                # whenever den is not a power of two.
                bound = rhs if (n == 1 and m is None) else v[0]

                lis = [x]*num + [bound]*(pown - den) + [I(n)]*(den - num)
                lis = cls._reduce_tree(P, lis, n, Var, v, varcnt)

                if n == 1:
                    if m is None:
                        P.add_constraint((lis[0] & lis[1] & rhs) << rsoc())
                    else:
                        P.add_constraint((lis[0] & lis[1] & v[0]) << rsoc())
                        P.add_constraint((m * v[0]) >= rhs)
                else:
                    P.add_constraint(
                        block([[lis[0], v[0]], [v[0], lis[1]]]) >> 0)

                    if m is None:
                        P.add_constraint(v[0].tr >= rhs)
                    else:
                        P.add_constraint((m | v[0]) >= rhs)
            else:
                # Raise instead of 'assert False' so that running with -O
                # cannot silently return a problem without the constraint.
                raise AssertionError("Dijkstra-IF fallthrough.")

            return P

    def __init__(self, power, relation, rhs):
        """Construct a :class:`PowerTraceConstraint`.

        :param ~picos.expressions.PowerTrace power:
            Left hand side expression.
        :param str relation:
            Constraint relation symbol.
        :param ~picos.expressions.AffineExpression rhs:
            Right hand side expression.
        """
        from ..expressions import AffineExpression, PowerTrace

        assert isinstance(power, PowerTrace)
        assert relation in self.LE + self.GE
        assert isinstance(rhs, AffineExpression)
        assert len(rhs) == 1

        p = power.p

        assert p != 0 and p != 1, \
            "The PowerTraceConstraint should not be created for p = 0 and " \
            "p = 1 as there are more direct ways to represent such powers."

        if relation == self.LE:
            assert p <= 0 or p >= 1, \
                "Upper bounding p-th power needs p s.t. the power is convex."
        else:
            assert 0 <= p <= 1, \
                "Lower bounding p-th power needs p s.t. the power is concave."

        self.power = power
        self.relation = relation
        self.rhs = rhs

        super(PowerTraceConstraint, self).__init__(power._typeStr)

    # HACK: Support Constraint's LHS/RHS interface.
    # TODO: Add a unified interface for such constraints?
    lhs = property(lambda self: self.power)

    def is_trace(self):
        """Whether the bound concerns a trace as opposed to a scalar."""
        return self.power.n > 1

    Subtype = namedtuple("Subtype", ("diag", "num", "den", "hasM", "complex"))

    def _subtype(self):
        """Wrap the subtype reported by the power in a :attr:`Subtype`."""
        return self.Subtype(*self.power.subtype)

    @classmethod
    def _cost(cls, subtype):
        """Cost estimate: n**2 + 1 (complex) or n*(n + 1)//2 + 1 (real)."""
        n = subtype.diag
        if subtype.complex:
            return n**2 + 1
        else:
            return n*(n + 1)//2 + 1

    def _expression_names(self):
        """Yield the names of the attributes holding the expressions."""
        yield "power"
        yield "rhs"

    def _str(self):
        """Return a string representation of the constraint."""
        if self.relation == self.LE:
            return glyphs.le(self.power.string, self.rhs.string)
        else:
            return glyphs.ge(self.power.string, self.rhs.string)

    def _get_slack(self):
        """Return the slack; it is non-negative when the bound holds."""
        if self.relation == self.LE:
            return self.rhs.safe_value - self.power.safe_value
        else:
            return self.power.safe_value - self.rhs.safe_value
# --------------------------------------
# End of the public API section. The exported names are computed by api_end
# from the _API_START marker and the current module namespace; presumably
# these are the names defined since api_start was called (see picos.apidoc).
__all__ = api_end(_API_START, globals())