Coverage for picos/constraints/con_detrootn.py: 96.49%

57 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-26 07:46 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2012-2017 Guillaume Sagnol 

3# Copyright (C) 2018-2019 Maximilian Stahlberg 

4# 

5# This file is part of PICOS. 

6# 

7# PICOS is free software: you can redistribute it and/or modify it under the 

8# terms of the GNU General Public License as published by the Free Software 

9# Foundation, either version 3 of the License, or (at your option) any later 

10# version. 

11# 

12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

14# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

15# 

16# You should have received a copy of the GNU General Public License along with 

17# this program. If not, see <http://www.gnu.org/licenses/>. 

18# ------------------------------------------------------------------------------ 

19 

20"""Implementation of :class:`DetRootNConstraint`.""" 

21 

22from collections import namedtuple 

23 

24from .. import glyphs 

25from ..apidoc import api_end, api_start 

26from .constraint import Constraint, ConstraintConversion 

27 

# Marker for the start of this module's public API. It is later handed to
# api_end() together with globals() to compute __all__; presumably it records
# which global names already exist at this point -- confirm in picos.apidoc.
_API_START = api_start(globals())
# -------------------------------

30 

31 

class DetRootNConstraint(Constraint):
    r"""Lower bound on the :math:`n`-th root of a matrix determinant.

    Models :math:`\det(X)^{1/n} \geq t`, where :math:`X` is the
    :math:`n \times n` argument of a :class:`~picos.expressions.DetRootN`
    expression and :math:`t` is a scalar affine expression.
    """

    class Conversion(ConstraintConversion):
        r""":math:`n`-th root of a matrix determinant constraint conversion.

        Introduces a lower triangular matrix variable :math:`L` with main
        diagonal :math:`d` and states the constraint equivalently as the
        linear matrix inequality
        :math:`\begin{bmatrix} X & L \\ L^T & \operatorname{diag}(d)
        \end{bmatrix} \succeq 0` together with a lower bound on the geometric
        mean of :math:`d`.
        """

        @classmethod
        def predict(cls, subtype, options):
            """Implement :meth:`~.constraint.ConstraintConversion.predict`.

            Yields, in this order, the auxiliary lower triangular variable,
            the LMI over the :math:`2n \\times 2n` block matrix (complex if
            the argument is complex) and the geometric mean constraint on the
            :math:`n` diagonal entries.
            """
            from ..expressions import LowerTriangularVariable
            from . import (ComplexLMIConstraint, GeometricMeanConstraint,
                           LMIConstraint)

            size = subtype.diag

            # Number of entries on or below the diagonal of an n-by-n matrix.
            num_entries = size * (size + 1) // 2
            var_type = LowerTriangularVariable.make_var_type(
                dim=num_entries, bnd=0)
            yield ("var", var_type, 1)

            # The block matrix [[X, L], [L^T, D]] has side length 2n.
            lmi_class = (
                ComplexLMIConstraint if subtype.complex else LMIConstraint)
            yield ("con", lmi_class.make_type(diag=2 * size), 1)

            yield ("con", GeometricMeanConstraint.make_type(argdim=size), 1)

        @classmethod
        def convert(cls, con, options):
            """Implement :meth:`~.constraint.ConstraintConversion.convert`.

            :returns: A :class:`~picos.Problem` holding the LMI and the
                geometric mean constraint that replace ``con``.
            """
            from ..modeling import Problem
            from ..expressions import GeometricMean, LowerTriangularVariable
            from ..expressions.algebra import block

            size = con.detRootN.n

            aux = Problem()
            tri = LowerTriangularVariable("__L", size)
            diagonal = tri.maindiag
            diag_matrix = diagonal.diag

            # NOTE(review): tri is always a *real* LowerTriangularVariable and
            # is paired with tri.T, even when X is complex. The LMI then only
            # certifies X >= S for a real symmetric S with det(S) = prod(d),
            # which can reject feasible points: for X = [[2, 1j], [-1j, 2]]
            # (det 3) the largest attainable prod(d) is 1. A complex L with a
            # real diagonal, paired with L.H, looks intended -- please confirm.
            aux.add_constraint(
                block([[con.detRootN.x, tri], [tri.T, diag_matrix]]) >> 0)
            aux.add_constraint(GeometricMean(diagonal) >= con.lowerBound)

            return aux

    def __init__(self, detRootN, lowerBound):
        r"""Construct a :class:`DetRootNConstraint`.

        :param ~picos.expressions.DetRootN detRootN:
            The expression :math:`\det(X)^{1/n}` that is bounded from below.
        :param ~picos.expressions.AffineExpression lowerBound:
            Affine lower bound; must consist of exactly one element.
        """
        from ..expressions import AffineExpression, DetRootN

        # Sanity checks on argument types and the scalar shape of the bound.
        assert isinstance(detRootN, DetRootN)
        assert isinstance(lowerBound, AffineExpression)
        assert len(lowerBound) == 1

        self.detRootN = detRootN
        self.lowerBound = lowerBound

        super().__init__(detRootN._typeStr)

    # diag: the n of the DetRootN expression; complex: whether its argument
    # matrix is complex (both filled in by _subtype below).
    Subtype = namedtuple("Subtype", "diag complex")

    def _subtype(self):
        argument = self.detRootN
        return self.Subtype(diag=argument.n, complex=argument.x.complex)

    @classmethod
    def _cost(cls, subtype):
        size = subtype.diag

        if not subtype.complex:
            # n(n+1)/2 + 1; presumably the symmetric matrix entries plus the
            # scalar bound.
            return size * (size + 1) // 2 + 1

        # n^2 + 1 for a complex argument.
        return size**2 + 1

    def _expression_names(self):
        yield from ("detRootN", "lowerBound")

    def _str(self):
        lhs = self.detRootN.string
        rhs = self.lowerBound.string
        return glyphs.ge(lhs, rhs)

    def _get_slack(self):
        # Nonnegative exactly when the constraint holds for current values.
        value = self.detRootN.safe_value
        bound = self.lowerBound.safe_value
        return value - bound

117 

118 

# --------------------------------------
# Compute the exported names from the marker taken by api_start() near the top
# of the module and the current globals(); presumably this yields the names
# defined between the two markers -- confirm in picos.apidoc.
__all__ = api_end(_API_START, globals())