Coverage for picos/constraints/con_powtrace.py: 93.46%

214 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-26 07:46 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2012-2019 Guillaume Sagnol 

3# Copyright (C) 2018-2019 Maximilian Stahlberg 

4# 

5# This file is part of PICOS. 

6# 

7# PICOS is free software: you can redistribute it and/or modify it under the 

8# terms of the GNU General Public License as published by the Free Software 

9# Foundation, either version 3 of the License, or (at your option) any later 

10# version. 

11# 

12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

14# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

15# 

16# You should have received a copy of the GNU General Public License along with 

17# this program. If not, see <http://www.gnu.org/licenses/>. 

18# ------------------------------------------------------------------------------ 

19 

20"""Implementation of :class:`PowerTraceConstraint`.""" 

21 

22import math 

23from collections import namedtuple 

24 

25import cvxopt as cvx 

26 

27from .. import glyphs 

28from ..apidoc import api_end, api_start 

29from .constraint import Constraint, ConstraintConversion 

30 

31_API_START = api_start(globals()) 

32# ------------------------------- 

33 

34 

35class PowerTraceConstraint(Constraint): 

36 """Bound on the trace over the :math:`p`-th power of a matrix. 

37 

38 For scalar expressions, this is simply a bound on their :math:`p`-th power. 

39 """ 

40 

41 class Conversion(ConstraintConversion): 

42 """Bound on the :math:`p`-th power of a trace constraint conversion. 

43 

44 The conversion is based on 

45 `this paper <http://nbn-resolving.de/urn:nbn:de:0297-zib-17511>`_. 

46 """ 

47 

48 @classmethod 

49 def _count_number_tree_node_types(cls, x): 

50 """Count number of conversion tree nodes. 

51 

52 Consider a binary tree with x[i] leaves of type i, arranged from 

53 left to right, with sum(x) a power of two. A node of the tree is of 

54 type i if its 2 parents are of type i; otherwise, a new type is 

55 created for this node. This function counts the number of additional 

56 types we need to create while growing the tree. 

57 """ 

58 x = [xi for xi in x if xi != 0] 

59 sum_x = sum(x) 

60 

61 # We have reached the tree root. Stop the recursion. 

62 if sum_x == 1: 

63 return 0 

64 

65 # Make sure x is a power of two. 

66 _log2_sum_x = math.log(sum_x, 2) 

67 assert _log2_sum_x == int(_log2_sum_x) 

68 

69 new_x = [] 

70 new_t = 0 

71 s = 0 

72 offset = 0 

73 

74 # Compute the vector new_x of types at next level. 

75 for x_i in x: 

76 s += x_i 

77 

78 if s % 2 == 0: 

79 if x_i - offset >= 2: 

80 new_x.append((x_i - offset) // 2) 

81 

82 offset = 0 

83 else: 

84 if x_i - offset >= 2: 

85 new_x.extend([(x_i - offset) // 2, 1]) 

86 elif x_i - offset == 1: 

87 new_x.append(1) 

88 elif x_i - offset == 0: 

89 assert False, "Unexpected case." 

90 

91 offset = 1 

92 new_t += 1 

93 

94 assert 2*sum(new_x) == sum_x 

95 

96 return new_t + cls._count_number_tree_node_types(new_x) 

97 

98 @staticmethod 

99 def _np2(n): 

100 """Compute the smallest power of two that is an upper bound.""" 

101 return 2**int(math.ceil(math.log(n, 2))) 

102 

103 @classmethod 

104 def predict(cls, subtype, options): 

105 """Implement :meth:`~.constraint.ConstraintConversion.predict`.""" 

106 from ..expressions import (HermitianVariable, RealVariable, 

107 SymmetricVariable) 

108 from . import (AffineConstraint, ComplexLMIConstraint, 

109 RSOCConstraint, LMIConstraint) 

110 

111 n, num, den, hasM, complex = subtype 

112 

113 if num > den > 0: 

114 x = [den, cls._np2(num) - num, num - den] 

115 elif num / den < 0: 

116 num = abs(num) 

117 den = abs(den) 

118 x = [den, num, cls._np2(num + den) - num - den] 

119 elif 0 < num < den: 

120 x = [num, cls._np2(den) - den, den - num] 

121 else: 

122 assert False, "Unexpected exponent." 

123 

124 N = cls._count_number_tree_node_types(x) 

125 

126 if n == 1: 

127 yield ("var", RealVariable.make_var_type(dim=1, bnd=0), N - 1) 

128 yield ("con", RSOCConstraint.make_type(argdim=1), N) 

129 if hasM: 

130 yield ("var", RealVariable.make_var_type(dim=1, bnd=0), 1) 

131 yield ("con", 

132 AffineConstraint.make_type(dim=1, eq=False), 1) 

133 else: 

134 if complex: 

135 yield ("var", 

136 HermitianVariable.make_var_type(dim=n**2, bnd=0), N) 

137 yield ("con", 

138 ComplexLMIConstraint.make_type(diag=2*n), N) 

139 else: 

140 yield ("var", SymmetricVariable.make_var_type( 

141 dim=(n * (n + 1)) // 2, bnd=0), N) 

142 yield ("con", LMIConstraint.make_type(diag=2*n), N) 

143 yield ("con", AffineConstraint.make_type(dim=1, eq=False), 1) 

144 

145 @classmethod 

146 def convert(cls, con, options): 

147 """Implement :meth:`~.constraint.ConstraintConversion.convert`.""" 

148 from ..expressions import Constant 

149 from ..expressions.algebra import block, rsoc 

150 from ..modeling import Problem 

151 

152 x = con.power.x 

153 n = con.power.n 

154 num = con.power.num 

155 den = con.power.den 

156 rhs = con.rhs 

157 m = con.power.m 

158 

159 vtype = "hermitian" if x.complex else "symmetric" 

160 

161 P = Problem() 

162 

163 if n == 1: 

164 idt = Constant('1', 1) 

165 if m is None: 

166 varcnt = 0 

167 v = [] 

168 else: 

169 varcnt = 1 

170 v = [P.add_variable('__v[0]', 1)] 

171 else: 

172 idt = Constant('I', cvx.spdiag([1.] * n)) 

173 varcnt = 1 

174 v = [P.add_variable('__v[0]', (n, n), vtype)] 

175 

176 if con.relation == Constraint.LE and num > den: 

177 pown = cls._np2(num) 

178 

179 if n == 1: 

180 lis = [rhs]*den + [x]*(pown - num) + [idt]*(num - den) 

181 else: 

182 lis = [v[0]]*den + [x]*(pown - num) + [idt]*(num - den) 

183 

184 while len(lis) > 2: 

185 newlis = [] 

186 while lis: 

187 v1 = lis.pop() 

188 v2 = lis.pop() 

189 

190 if v1 is v2: 

191 newlis.append(v2) 

192 else: 

193 if n == 1: 

194 v0 = P.add_variable( 

195 '__v[' + str(varcnt) + ']', 1) 

196 P.add_constraint((v1 & v2 & v0) << rsoc()) 

197 else: 

198 v0 = P.add_variable( 

199 '__v[' + str(varcnt) + ']', (n, n), vtype) 

200 P.add_constraint( 

201 block([[v1, v0], [v0, v2]]) >> 0) 

202 

203 varcnt += 1 

204 newlis.append(v0) 

205 v.append(v0) 

206 lis = newlis 

207 

208 if n == 1: 

209 P.add_constraint((lis[0] & lis[1] & x) << rsoc()) 

210 else: 

211 P.add_constraint(block([[lis[0], x], [x, lis[1]]]) >> 0) 

212 P.add_constraint((idt | v[0]) <= rhs) 

213 elif con.relation == Constraint.LE and num <= den: 

214 num = abs(num) 

215 den = abs(den) 

216 

217 pown = cls._np2(num + den) 

218 

219 if n == 1: 

220 lis = [rhs] * den + [x] * num + [idt] * (pown - num - den) 

221 else: 

222 lis = [v[0]] * den + [x] * num + [idt] * (pown - num - den) 

223 

224 while len(lis) > 2: 

225 newlis = [] 

226 while lis: 

227 v1 = lis.pop() 

228 v2 = lis.pop() 

229 

230 if v1 is v2: 

231 newlis.append(v2) 

232 else: 

233 if n == 1: 

234 v0 = P.add_variable( 

235 '__v[' + str(varcnt) + ']', 1) 

236 P.add_constraint((v1 & v2 & v0) << rsoc()) 

237 else: 

238 v0 = P.add_variable( 

239 '__v[' + str(varcnt) + ']', (n, n), vtype) 

240 P.add_constraint( 

241 block([[v1, v0], [v0, v2]]) >> 0) 

242 

243 varcnt += 1 

244 newlis.append(v0) 

245 v.append(v0) 

246 lis = newlis 

247 

248 if n == 1: 

249 P.add_constraint((lis[0] & lis[1] & 1) << rsoc()) 

250 else: 

251 P.add_constraint(block([[lis[0], idt], [idt, lis[1]]]) >> 0) 

252 P.add_constraint((idt | v[0]) <= rhs) 

253 elif con.relation == Constraint.GE: 

254 pown = cls._np2(den) 

255 

256 if n == 1: 

257 lis = [x]*num + [rhs]*(pown - den) + [idt]*(den - num) 

258 

259 else: 

260 lis = [x]*num + [v[0]]*(pown - den) + [idt]*(den - num) 

261 

262 while len(lis) > 2: 

263 newlis = [] 

264 while lis: 

265 v1 = lis.pop() 

266 v2 = lis.pop() 

267 

268 if v1 is v2: 

269 newlis.append(v2) 

270 else: 

271 if n == 1: 

272 v0 = P.add_variable( 

273 '__v[' + str(varcnt) + ']', 1) 

274 P.add_constraint((v1 & v2 & v0) << rsoc()) 

275 else: 

276 v0 = P.add_variable( 

277 '__v[' + str(varcnt) + ']', (n, n), vtype) 

278 P.add_constraint( 

279 block([[v1, v0], [v0, v2]]) >> 0) 

280 

281 varcnt += 1 

282 newlis.append(v0) 

283 v.append(v0) 

284 lis = newlis 

285 

286 if n == 1: 

287 if m is None: 

288 P.add_constraint((lis[0] & lis[1] & rhs) << rsoc()) 

289 else: 

290 P.add_constraint((lis[0] & lis[1] & v[0]) << rsoc()) 

291 P.add_constraint((m * v[0]) > rhs) 

292 else: 

293 P.add_constraint( 

294 block([[lis[0], v[0]], [v[0], lis[1]]]) >> 0) 

295 if m is None: 

296 P.add_constraint((idt | v[0]) > rhs) 

297 else: 

298 P.add_constraint((m | v[0]) > rhs) 

299 else: 

300 assert False, "Dijkstra-IF fallthrough." 

301 

302 return P 

303 

304 def __init__(self, power, relation, rhs): 

305 """Construct a :class:`PowerTraceConstraint`. 

306 

307 :param ~picos.expressions.PowerTrace ower: 

308 Left hand side expression. 

309 :param str relation: 

310 Constraint relation symbol. 

311 :param ~picos.expressions.AffineExpression rhs: 

312 Right hand side expression. 

313 """ 

314 from ..expressions import AffineExpression, PowerTrace 

315 

316 assert isinstance(power, PowerTrace) 

317 assert relation in self.LE + self.GE 

318 assert isinstance(rhs, AffineExpression) 

319 assert len(rhs) == 1 

320 

321 p = power.p 

322 

323 assert p != 0 and p != 1, \ 

324 "The PowerTraceConstraint should not be created for p = 0 and " \ 

325 "p = 1 as there are more direct ways to represent such powers." 

326 

327 if relation == self.LE: 

328 assert p <= 0 or p >= 1, \ 

329 "Upper bounding p-th power needs p s.t. the power is convex." 

330 else: 

331 assert p >= 0 and p <= 1, \ 

332 "Lower bounding p-th power needs p s.t. the power is concave." 

333 

334 self.power = power 

335 self.relation = relation 

336 self.rhs = rhs 

337 

338 super(PowerTraceConstraint, self).__init__(power._typeStr) 

339 

340 # HACK: Support Constraint's LHS/RHS interface. 

341 # TODO: Add a unified interface for such constraints? 

342 lhs = property(lambda self: self.power) 

343 

344 def is_trace(self): 

345 """Whether the bound concerns a trace as opposed to a scalar.""" 

346 return self.power.n > 1 

347 

348 Subtype = namedtuple("Subtype", ("diag", "num", "den", "hasM", "complex")) 

349 

350 def _subtype(self): 

351 return self.Subtype(*self.power.subtype) 

352 

353 @classmethod 

354 def _cost(cls, subtype): 

355 n = subtype.diag 

356 if subtype.complex: 

357 return n**2 + 1 

358 else: 

359 return n*(n + 1)//2 + 1 

360 

361 def _expression_names(self): 

362 yield "power" 

363 yield "rhs" 

364 

365 def _str(self): 

366 if self.relation == self.LE: 

367 return glyphs.le(self.power.string, self.rhs.string) 

368 else: 

369 return glyphs.ge(self.power.string, self.rhs.string) 

370 

371 def _get_slack(self): 

372 if self.relation == self.LE: 

373 return self.rhs.safe_value - self.power.safe_value 

374 else: 

375 return self.power.safe_value - self.rhs.safe_value 

376 

377 

378# -------------------------------------- 

379__all__ = api_end(_API_START, globals())