Coverage for picos/constraints/con_powtrace.py: 93.46%
214 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-15 14:21 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-15 14:21 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2012-2019 Guillaume Sagnol
3# Copyright (C) 2018-2019 Maximilian Stahlberg
4#
5# This file is part of PICOS.
6#
7# PICOS is free software: you can redistribute it and/or modify it under the
8# terms of the GNU General Public License as published by the Free Software
9# Foundation, either version 3 of the License, or (at your option) any later
10# version.
11#
12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
14# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License along with
17# this program. If not, see <http://www.gnu.org/licenses/>.
18# ------------------------------------------------------------------------------
20"""Implementation of :class:`PowerTraceConstraint`."""
22import math
23from collections import namedtuple
25import cvxopt as cvx
27from .. import glyphs
28from ..apidoc import api_end, api_start
29from .constraint import Constraint, ConstraintConversion
# Marker taken before the module's definitions; it is handed back to
# api_end() at the bottom of the file to build __all__.
# NOTE(review): presumably api_start() snapshots the namespace passed via
# globals() -- confirm against ..apidoc.
_API_START = api_start(globals())
# -------------------------------
class PowerTraceConstraint(Constraint):
    """Bound on the trace over the :math:`p`-th power of a matrix.

    For scalar expressions, this is simply a bound on their :math:`p`-th power.
    """

    class Conversion(ConstraintConversion):
        """Conversion of a bound on the trace of a :math:`p`-th power.

        The conversion is based on
        `this paper <http://nbn-resolving.de/urn:nbn:de:0297-zib-17511>`_.
        """

        @classmethod
        def _count_number_tree_node_types(cls, x):
            """Count number of conversion tree nodes.

            Consider a binary tree with x[i] leaves of type i, arranged from
            left to right, with sum(x) a power of two. A node of the tree is of
            type i if its 2 parents are of type i; otherwise, a new type is
            created for this node. This function counts the number of additional
            types we need to create while growing the tree.

            :param x:
                Sequence of non-negative integer leaf counts, one per type.
                Zero entries are ignored.
            :returns:
                The number of additional node types, as an :class:`int`.
            :raises ValueError:
                If the leaf counts do not sum to a positive number.
            """
            x = [xi for xi in x if xi != 0]
            sum_x = sum(x)

            # We have reached the tree root. Stop the recursion.
            if sum_x == 1:
                return 0

            # A non-positive sum can never reach the root; fail loudly instead
            # of recursing forever.
            if sum_x < 1:
                raise ValueError(
                    "The number of tree leaves must be positive.")

            # Make sure sum_x is a power of two. This is tested on the integer
            # itself as a floating point logarithm is not exact for large
            # arguments.
            assert sum_x & (sum_x - 1) == 0, \
                "The number of tree leaves must be a power of two."

            new_x = []
            new_t = 0
            s = 0
            offset = 0

            # Compute the vector new_x of types at next level.
            for x_i in x:
                s += x_i

                if s % 2 == 0:
                    # The run of type i ends on a pair boundary.
                    if x_i - offset >= 2:
                        new_x.append((x_i - offset) // 2)

                    offset = 0
                else:
                    # The last leaf of this run is paired with the first leaf
                    # of the next run, which creates a node of a new type.
                    if x_i - offset >= 2:
                        new_x.extend([(x_i - offset) // 2, 1])
                    elif x_i - offset == 1:
                        new_x.append(1)
                    elif x_i - offset == 0:
                        assert False, "Unexpected case."

                    offset = 1
                    new_t += 1

            assert 2*sum(new_x) == sum_x

            return new_t + cls._count_number_tree_node_types(new_x)

        @staticmethod
        def _np2(n):
            """Compute the smallest power of two that is an upper bound.

            :param n:
                A positive number.
            :returns:
                The smallest power of two that is not smaller than ``n``.
            :raises ValueError:
                If ``n`` is not positive.
            """
            # Positive integers are handled with exact integer arithmetic as
            # math.log(n, 2) is subject to rounding errors that can push the
            # ceiling one exponent too high.
            if isinstance(n, int) and n > 0:
                return 1 << (n - 1).bit_length()

            return 2**int(math.ceil(math.log(n, 2)))

        @classmethod
        def predict(cls, subtype, options):
            """Implement :meth:`~.constraint.ConstraintConversion.predict`."""
            from ..expressions import (HermitianVariable, RealVariable,
                                       SymmetricVariable)
            from . import (AffineConstraint, ComplexLMIConstraint,
                           RSOCConstraint, LMIConstraint)

            n, num, den, hasM, is_complex = subtype

            # Leaf counts of the conversion tree for each range of exponents;
            # they mirror the lists of leaves built by convert.
            if num > den > 0:
                x = [den, cls._np2(num) - num, num - den]
            elif num / den < 0:
                num = abs(num)
                den = abs(den)
                x = [den, num, cls._np2(num + den) - num - den]
            elif 0 < num < den:
                x = [num, cls._np2(den) - den, den - num]
            else:
                assert False, "Unexpected exponent."

            N = cls._count_number_tree_node_types(x)

            if n == 1:
                yield ("var", RealVariable.make_var_type(dim=1, bnd=0), N - 1)
                yield ("con", RSOCConstraint.make_type(argdim=1), N)
                if hasM:
                    yield ("var", RealVariable.make_var_type(dim=1, bnd=0), 1)
                    yield ("con",
                           AffineConstraint.make_type(dim=1, eq=False), 1)
            else:
                if is_complex:
                    yield ("var",
                           HermitianVariable.make_var_type(dim=n**2, bnd=0), N)
                    yield ("con",
                           ComplexLMIConstraint.make_type(diag=2*n), N)
                else:
                    yield ("var", SymmetricVariable.make_var_type(
                        dim=(n * (n + 1)) // 2, bnd=0), N)
                    yield ("con", LMIConstraint.make_type(diag=2*n), N)
                yield ("con", AffineConstraint.make_type(dim=1, eq=False), 1)

        @staticmethod
        def _reduce_to_pair(P, lis, v, n, vtype):
            """Pair up the leaves in ``lis`` until only two nodes remain.

            Nodes are taken pairwise from the end of the list. Two identical
            nodes are merged into one; for two distinct nodes an auxiliary
            variable is added to ``P`` and linked to them by a rotated second
            order cone constraint (``n == 1``) or by a block matrix constraint
            (otherwise).

            :param P:
                The problem that receives auxiliary variables and constraints.
            :param list lis:
                The nodes of the current tree level; its length is expected to
                be a power of two. The list is consumed.
            :param list v:
                All auxiliary variables created so far. New variables are
                appended and named ``__v[i]`` with ``i`` their index in ``v``.
            :param int n:
                Number of rows of the matrix whose power is bounded.
            :param str vtype:
                Variable type used for matrix valued auxiliary variables.
            :returns:
                A list with the (at most two) remaining nodes.
            """
            from ..expressions.algebra import block, rsoc

            while len(lis) > 2:
                newlis = []
                while lis:
                    v1 = lis.pop()
                    v2 = lis.pop()

                    if v1 is v2:
                        newlis.append(v2)
                        continue

                    # The running variable counter always equals len(v).
                    name = '__v[' + str(len(v)) + ']'

                    if n == 1:
                        v0 = P.add_variable(name, 1)
                        P.add_constraint((v1 & v2 & v0) << rsoc())
                    else:
                        v0 = P.add_variable(name, (n, n), vtype)
                        P.add_constraint(
                            block([[v1, v0], [v0, v2]]) >> 0)

                    newlis.append(v0)
                    v.append(v0)
                lis = newlis

            return lis

        @classmethod
        def convert(cls, con, options):
            """Implement :meth:`~.constraint.ConstraintConversion.convert`."""
            from ..expressions import Constant
            from ..expressions.algebra import block, rsoc
            from ..modeling import Problem

            x = con.power.x
            n = con.power.n
            num = con.power.num
            den = con.power.den
            rhs = con.rhs
            m = con.power.m

            vtype = "hermitian" if x.complex else "symmetric"

            P = Problem()

            # v collects the auxiliary variables; v[0] only exists in the
            # matrix case and in the scalar case with a coefficient m.
            if n == 1:
                idt = Constant('1', 1)
                if m is None:
                    v = []
                else:
                    v = [P.add_variable('__v[0]', 1)]
            else:
                idt = Constant('I', cvx.spdiag([1.] * n))
                v = [P.add_variable('__v[0]', (n, n), vtype)]

            # In the scalar case the right hand side itself is a leaf of the
            # tree while the matrix case uses v[0], whose trace is bounded
            # separately below.
            bound = rhs if n == 1 else v[0]

            if con.relation == Constraint.LE and num > den:
                pown = cls._np2(num)

                lis = [bound]*den + [x]*(pown - num) + [idt]*(num - den)
                lis = cls._reduce_to_pair(P, lis, v, n, vtype)

                if n == 1:
                    P.add_constraint((lis[0] & lis[1] & x) << rsoc())
                else:
                    P.add_constraint(block([[lis[0], x], [x, lis[1]]]) >> 0)
                    P.add_constraint((idt | v[0]) <= rhs)
            elif con.relation == Constraint.LE and num <= den:
                num = abs(num)
                den = abs(den)

                pown = cls._np2(num + den)

                lis = [bound] * den + [x] * num + [idt] * (pown - num - den)
                lis = cls._reduce_to_pair(P, lis, v, n, vtype)

                if n == 1:
                    P.add_constraint((lis[0] & lis[1] & 1) << rsoc())
                else:
                    P.add_constraint(
                        block([[lis[0], idt], [idt, lis[1]]]) >> 0)
                    P.add_constraint((idt | v[0]) <= rhs)
            elif con.relation == Constraint.GE:
                pown = cls._np2(den)

                lis = [x]*num + [bound]*(pown - den) + [idt]*(den - num)
                lis = cls._reduce_to_pair(P, lis, v, n, vtype)

                if n == 1:
                    if m is None:
                        P.add_constraint((lis[0] & lis[1] & rhs) << rsoc())
                    else:
                        P.add_constraint((lis[0] & lis[1] & v[0]) << rsoc())
                        P.add_constraint((m * v[0]) > rhs)
                else:
                    P.add_constraint(
                        block([[lis[0], v[0]], [v[0], lis[1]]]) >> 0)
                    if m is None:
                        P.add_constraint((idt | v[0]) > rhs)
                    else:
                        P.add_constraint((m | v[0]) > rhs)
            else:
                assert False, "Dijkstra-IF fallthrough."

            return P

    def __init__(self, power, relation, rhs):
        """Construct a :class:`PowerTraceConstraint`.

        :param ~picos.expressions.PowerTrace power:
            Left hand side expression.
        :param str relation:
            Constraint relation symbol.
        :param ~picos.expressions.AffineExpression rhs:
            Right hand side expression.
        """
        from ..expressions import AffineExpression, PowerTrace

        assert isinstance(power, PowerTrace)
        # Compare against the two symbols individually, consistent with the
        # equality tests used for the relation elsewhere in this class.
        assert relation in (self.LE, self.GE)
        assert isinstance(rhs, AffineExpression)
        assert len(rhs) == 1

        p = power.p

        assert p != 0 and p != 1, \
            "The PowerTraceConstraint should not be created for p = 0 and " \
            "p = 1 as there are more direct ways to represent such powers."

        if relation == self.LE:
            assert p <= 0 or p >= 1, \
                "Upper bounding p-th power needs p s.t. the power is convex."
        else:
            assert p >= 0 and p <= 1, \
                "Lower bounding p-th power needs p s.t. the power is concave."

        self.power = power
        self.relation = relation
        self.rhs = rhs

        super(PowerTraceConstraint, self).__init__(power._typeStr)

    # HACK: Support Constraint's LHS/RHS interface.
    # TODO: Add a unified interface for such constraints?
    lhs = property(lambda self: self.power)

    def is_trace(self):
        """Whether the bound concerns a trace as opposed to a scalar."""
        return self.power.n > 1

    Subtype = namedtuple("Subtype", ("diag", "num", "den", "hasM", "complex"))

    def _subtype(self):
        return self.Subtype(*self.power.subtype)

    @classmethod
    def _cost(cls, subtype):
        n = subtype.diag
        if subtype.complex:
            return n**2 + 1
        else:
            return n*(n + 1)//2 + 1

    def _expression_names(self):
        yield "power"
        yield "rhs"

    def _str(self):
        if self.relation == self.LE:
            return glyphs.le(self.power.string, self.rhs.string)
        else:
            return glyphs.ge(self.power.string, self.rhs.string)

    def _get_slack(self):
        # The slack is oriented such that it is non-negative when the bound
        # holds for the current values.
        if self.relation == self.LE:
            return self.rhs.safe_value - self.power.safe_value
        else:
            return self.power.safe_value - self.rhs.safe_value
# --------------------------------------
# Build the module's public name list from the marker taken at the top of the
# file and the current namespace.
# NOTE(review): presumably this exports the names defined between api_start()
# and api_end() -- confirm against ..apidoc.
__all__ = api_end(_API_START, globals())