Coverage for picos/constraints/con_powtrace.py: 93.40%

212 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-12 07:53 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2012-2019 Guillaume Sagnol 

3# Copyright (C) 2018-2019 Maximilian Stahlberg 

4# 

5# This file is part of PICOS. 

6# 

7# PICOS is free software: you can redistribute it and/or modify it under the 

8# terms of the GNU General Public License as published by the Free Software 

9# Foundation, either version 3 of the License, or (at your option) any later 

10# version. 

11# 

12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

14# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

15# 

16# You should have received a copy of the GNU General Public License along with 

17# this program. If not, see <http://www.gnu.org/licenses/>. 

18# ------------------------------------------------------------------------------ 

19 

20"""Implementation of :class:`PowerTraceConstraint`.""" 

21 

22import math 

23from collections import namedtuple 

24 

25import cvxopt as cvx 

26 

27from picos.expressions.variables import HermitianVariable, SymmetricVariable 

28 

29from .. import glyphs 

30from ..apidoc import api_end, api_start 

31from .constraint import Constraint, ConstraintConversion 

32 

# Marker for the start of this module's public API. The value is handed to
# api_end() together with globals() at the bottom of the file to build __all__
# (see picos.apidoc for the exact semantics).
_API_START = api_start(globals())
# -------------------------------

35 

36 

class PowerTraceConstraint(Constraint):
    """Bound on the trace over the :math:`p`-th power of a matrix.

    For scalar expressions, this is simply a bound on their :math:`p`-th power.
    """

    class Conversion(ConstraintConversion):
        """Conversion of a bound on the trace of a :math:`p`-th matrix power.

        The bound is rewritten as a binary tree of rotated second order cone
        constraints (scalar case) or of 2x2 block linear matrix inequalities
        (matrix case), following
        `this paper <http://nbn-resolving.de/urn:nbn:de:0297-zib-17511>`_.
        """

        @classmethod
        def _count_number_tree_node_types(cls, x):
            """Count number of conversion tree nodes.

            Consider a binary tree with x[i] leaves of type i, arranged from
            left to right, with sum(x) a power of two. A node of the tree is of
            type i if its two children are of type i; otherwise, a new type is
            created for this node. This function counts the number of
            additional types we need to create while growing the tree.

            :param list(int) x: Number of leaves of each type, left to right.
            :returns: The number of newly created node types.
            """
            x = [xi for xi in x if xi != 0]
            sum_x = sum(x)

            # We have reached the tree root. Stop the recursion.
            if sum_x == 1:
                return 0

            # Make sure sum(x) is a power of two. An exact integer test is used
            # since a floating point logarithm is inexact for large values.
            assert sum_x > 0 and (sum_x & (sum_x - 1)) == 0

            new_x = []
            new_t = 0
            s = 0       # Number of leaves processed so far.
            offset = 0  # 1 if this run's first leaf was paired with the
                        # previous run's last leaf into a new node type.

            # Compute the vector new_x of types at next level.
            for x_i in x:
                s += x_i

                if s % 2 == 0:
                    # The run ends on a pair boundary.
                    if x_i - offset >= 2:
                        new_x.append((x_i - offset) // 2)

                    offset = 0
                else:
                    # The run's last leaf pairs with the next run's first leaf,
                    # which yields a node of a new type.
                    if x_i - offset >= 2:
                        new_x.extend([(x_i - offset) // 2, 1])
                    elif x_i - offset == 1:
                        new_x.append(1)
                    elif x_i - offset == 0:
                        assert False, "Unexpected case."

                    offset = 1
                    new_t += 1

            assert 2*sum(new_x) == sum_x

            return new_t + cls._count_number_tree_node_types(new_x)

        @staticmethod
        def _np2(n):
            """Compute the smallest power of two that is an upper bound.

            Integer arithmetic is used so that exact powers of two are not
            rounded up by floating point error, as happens with
            ``math.log(2**29, 2)``.

            :param n: A positive number.
            :raises ValueError: If ``n`` is not positive.
            """
            n = int(math.ceil(n))
            if n < 1:
                raise ValueError(
                    "Expected a positive number, got {}.".format(n))
            return 1 << (n - 1).bit_length()

        @staticmethod
        def _reduce_operands(P, lis, n, Var, v):
            """Merge a power-of-two list of operands until two remain.

            Each pass pops operands off the end of ``lis`` two at a time. Two
            identical objects collapse into one without a new variable.
            Otherwise a fresh auxiliary variable ``v0`` is created, appended
            to ``v``, and tied to the pair by ``(v1 & v2 & v0) << rsoc()``
            (``n == 1``) or by ``block([[v1, v0], [v0, v2]]) >> 0``.

            :param P: Problem receiving the new constraints.
            :param list lis: Operands; consumed by this function.
            :param int n: Matrix dimension, 1 for the scalar case.
            :param Var: Matrix variable class used when ``n > 1``.
            :param list v: Auxiliary variables so far; extended in place.
            :returns: The final list of two operands.
            """
            from ..expressions import RealVariable
            from ..expressions.algebra import block, rsoc

            while len(lis) > 2:
                merged = []
                while lis:
                    v1 = lis.pop()
                    v2 = lis.pop()

                    if v1 is v2:
                        merged.append(v2)
                        continue

                    # Auxiliary variables are numbered consecutively, so the
                    # next free index equals the number created so far.
                    name = '__v[' + str(len(v)) + ']'
                    if n == 1:
                        v0 = RealVariable(name)
                        P.add_constraint((v1 & v2 & v0) << rsoc())
                    else:
                        v0 = Var(name, (n, n))
                        P.add_constraint(block([[v1, v0], [v0, v2]]) >> 0)

                    merged.append(v0)
                    v.append(v0)
                lis = merged

            return lis

        @classmethod
        def predict(cls, subtype, options):
            """Implement :meth:`~.constraint.ConstraintConversion.predict`."""
            from ..expressions import (HermitianVariable, RealVariable,
                SymmetricVariable)
            from . import (AffineConstraint, ComplexLMIConstraint,
                LMIConstraint, RSOCConstraint)

            # Named is_complex to avoid shadowing the builtin complex.
            n, num, den, hasM, is_complex = subtype

            # Number of tree leaves of each type. These mirror the operand
            # lists that convert builds for the respective exponent range.
            if num > den > 0:
                # p > 1 (upper bound).
                x = [den, cls._np2(num) - num, num - den]
            elif num / den < 0:
                # p < 0 (upper bound).
                num = abs(num)
                den = abs(den)
                x = [den, num, cls._np2(num + den) - num - den]
            elif 0 < num < den:
                # 0 < p < 1 (lower bound).
                x = [num, cls._np2(den) - den, den - num]
            else:
                assert False, "Unexpected exponent."

            N = cls._count_number_tree_node_types(x)

            if n == 1:
                yield ("var", RealVariable.make_var_type(dim=1, bnd=0), N - 1)
                yield ("con", RSOCConstraint.make_type(argdim=1), N)
                if hasM:
                    yield ("var", RealVariable.make_var_type(dim=1, bnd=0), 1)
                    yield ("con",
                        AffineConstraint.make_type(dim=1, eq=False), 1)
            else:
                if is_complex:
                    yield ("var",
                        HermitianVariable.make_var_type(dim=n**2, bnd=0), N)
                    yield ("con",
                        ComplexLMIConstraint.make_type(diag=2*n), N)
                else:
                    yield ("var", SymmetricVariable.make_var_type(
                        dim=(n * (n + 1)) // 2, bnd=0), N)
                    yield ("con", LMIConstraint.make_type(diag=2*n), N)
                yield ("con", AffineConstraint.make_type(dim=1, eq=False), 1)

        @classmethod
        def convert(cls, con, options):
            """Implement :meth:`~.constraint.ConstraintConversion.convert`."""
            from ..expressions import (HermitianVariable, RealVariable,
                SymmetricVariable)
            from ..expressions.algebra import I, block, rsoc
            from ..modeling import Problem

            x = con.power.x
            n = con.power.n
            num = con.power.num
            den = con.power.den
            rhs = con.rhs
            m = con.power.m

            Var = HermitianVariable if x.complex else SymmetricVariable

            P = Problem()

            # v[0], when present, is the auxiliary variable at the tree root.
            # It is only needed in the scalar case when m is given.
            if n == 1:
                v = [] if m is None else [RealVariable('__v[0]')]
            else:
                v = [Var('__v[0]', (n, n))]

            if con.relation == Constraint.LE and num > den:
                # Upper bound with p > 1.
                pown = cls._np2(num)
                top = rhs if n == 1 else v[0]
                lis = [top]*den + [x]*(pown - num) + [I(n)]*(num - den)
                lis = cls._reduce_operands(P, lis, n, Var, v)

                if n == 1:
                    P.add_constraint((lis[0] & lis[1] & x) << rsoc())
                else:
                    P.add_constraint(block([[lis[0], x], [x, lis[1]]]) >> 0)
                    P.add_constraint(v[0].tr <= rhs)
            elif con.relation == Constraint.LE and num <= den:
                # Upper bound with p < 0.
                num = abs(num)
                den = abs(den)

                pown = cls._np2(num + den)
                top = rhs if n == 1 else v[0]
                lis = [top]*den + [x]*num + [I(n)]*(pown - num - den)
                lis = cls._reduce_operands(P, lis, n, Var, v)

                if n == 1:
                    P.add_constraint((lis[0] & lis[1] & 1) << rsoc())
                else:
                    P.add_constraint(
                        block([[lis[0], I(n)], [I(n), lis[1]]]) >> 0)
                    P.add_constraint(v[0].tr <= rhs)
            elif con.relation == Constraint.GE:
                # Lower bound with 0 < p < 1.
                pown = cls._np2(den)

                # The padding leaves must be the quantity placed at the tree
                # root: rhs for a plain scalar bound, v[0] otherwise. Padding
                # with rhs while the root is v[0] (scalar case with m) would
                # only enforce v0**pown <= x**num * rhs**(pown - den), which is
                # a strict relaxation of m * x**p >= rhs.
                mid = rhs if n == 1 and m is None else v[0]
                lis = [x]*num + [mid]*(pown - den) + [I(n)]*(den - num)
                lis = cls._reduce_operands(P, lis, n, Var, v)

                if n == 1:
                    if m is None:
                        P.add_constraint((lis[0] & lis[1] & rhs) << rsoc())
                    else:
                        P.add_constraint((lis[0] & lis[1] & v[0]) << rsoc())
                        P.add_constraint((m * v[0]) >= rhs)
                else:
                    P.add_constraint(
                        block([[lis[0], v[0]], [v[0], lis[1]]]) >> 0)
                    if m is None:
                        P.add_constraint(v[0].tr >= rhs)
                    else:
                        P.add_constraint((m | v[0]) >= rhs)
            else:
                assert False, "Dijkstra-IF fallthrough."

            return P

    def __init__(self, power, relation, rhs):
        """Construct a :class:`PowerTraceConstraint`.

        :param ~picos.expressions.PowerTrace power:
            Left hand side expression.
        :param str relation:
            Constraint relation symbol.
        :param ~picos.expressions.AffineExpression rhs:
            Right hand side expression.
        """
        from ..expressions import AffineExpression, PowerTrace

        assert isinstance(power, PowerTrace)
        assert relation in self.LE + self.GE
        assert isinstance(rhs, AffineExpression)
        assert len(rhs) == 1

        p = power.p

        assert p != 0 and p != 1, \
            "The PowerTraceConstraint should not be created for p = 0 and " \
            "p = 1 as there are more direct ways to represent such powers."

        if relation == self.LE:
            assert p <= 0 or p >= 1, \
                "Upper bounding p-th power needs p s.t. the power is convex."
        else:
            assert p >= 0 and p <= 1, \
                "Lower bounding p-th power needs p s.t. the power is concave."

        self.power = power
        self.relation = relation
        self.rhs = rhs

        super(PowerTraceConstraint, self).__init__(power._typeStr)

    # HACK: Support Constraint's LHS/RHS interface.
    # TODO: Add a unified interface for such constraints?
    lhs = property(lambda self: self.power)

    def is_trace(self):
        """Whether the bound concerns a trace as opposed to a scalar."""
        return self.power.n > 1

    Subtype = namedtuple("Subtype", ("diag", "num", "den", "hasM", "complex"))

    def _subtype(self):
        return self.Subtype(*self.power.subtype)

    @classmethod
    def _cost(cls, subtype):
        n = subtype.diag
        if subtype.complex:
            return n**2 + 1
        else:
            return n*(n + 1)//2 + 1

    def _expression_names(self):
        yield "power"
        yield "rhs"

    def _str(self):
        if self.relation == self.LE:
            return glyphs.le(self.power.string, self.rhs.string)
        else:
            return glyphs.ge(self.power.string, self.rhs.string)

    def _get_slack(self):
        if self.relation == self.LE:
            return self.rhs.safe_value - self.power.safe_value
        else:
            return self.power.safe_value - self.rhs.safe_value

372 

373 

# --------------------------------------
# Public API of this module, computed by api_end() from the _API_START marker
# set near the top of the file and the current module globals (see
# picos.apidoc for the exact semantics).
__all__ = api_end(_API_START, globals())