Coverage for picos/expressions/exp_specnorm.py: 83.62%

116 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-26 07:46 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2020 Guillaume Sagnol 

3# 

4# This file is part of PICOS. 

5# 

6# PICOS is free software: you can redistribute it and/or modify it under the 

7# terms of the GNU General Public License as published by the Free Software 

8# Foundation, either version 3 of the License, or (at your option) any later 

9# version. 

10# 

11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

13# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

14# 

15# You should have received a copy of the GNU General Public License along with 

16# this program. If not, see <http://www.gnu.org/licenses/>. 

17# ------------------------------------------------------------------------------ 

18 

19"""Implements :class:`SpectralNorm`.""" 

20 

21import operator 

22from collections import namedtuple 

23 

24import cvxopt 

25import numpy 

26 

27from .. import glyphs 

28from ..apidoc import api_end, api_start 

29from ..caching import cached_unary_operator 

30from ..constraints import AbsoluteValueConstraint, SpectralNormConstraint 

31from .data import convert_and_refine_arguments, convert_operands, cvx2np 

32from .exp_affine import AffineExpression, ComplexAffineExpression 

33from .exp_norm import Norm 

34from .expression import Expression, refine_operands, validate_prediction 

35 

# Snapshot of the module's globals taken before the public definitions below.
# The matching api_end(_API_START, globals()) call at the bottom of this file
# uses it to build __all__ (presumably from the names defined in between --
# see picos.apidoc to confirm).
_API_START = api_start(globals())
# -------------------------------

38 

39 

class SpectralNorm(Expression):
    r"""The spectral norm of a matrix.

    This class can represent the spectral norm of a matrix-affine expression
    (real- or complex valued). The spectral norm is convex, so we can form
    expressions of the form ``SpectralNorm(X) <= t`` which are typically
    reformulated as LMIs that can be handled by SDP solvers.

    :Definition:

    If the normed expression is a matrix :math:`X`, then its spectral norm is

    .. math::

        \|X\|_2 = \max \{ \|Xu\|_2 : \|u\|_2 \leq 1\}
                = \sqrt{\lambda_{\max}(XX^*)},

    where :math:`\lambda_{\max}(\cdot)` denotes the largest eigenvalue of
    a matrix, and :math:`X^*` denotes the adjoint matrix of :math:`X`
    (i.e., the transposed matrix :math:`X^T` if :math:`X` is real-valued).

    Special cases:

    - If :math:`X` is scalar, then :math:`\|X\|_2` reduces to the absolute
      value (or modulus) :math:`|X|`.
    - If :math:`X` is a (row or column) vector, then :math:`\|X\|_2`
      coincides with the Euclidean norm of :math:`X`.

    """

    @convert_and_refine_arguments("x")
    def __init__(self, x):
        """Construct a :class:`SpectralNorm`.

        :param x: The affine expression to take the norm of.
        :type x: ~picos.expressions.ComplexAffineExpression

        :raises TypeError: If ``x`` is not a (complex) affine expression.
        """
        # Validate x.
        if not isinstance(x, ComplexAffineExpression):
            raise TypeError("Can only form the spectral norm of an affine "
                "expression, not of {}.".format(type(x).__name__))

        # A ComplexAffineExpression that is not a (real) AffineExpression is
        # treated as complex-valued. The local is named is_complex so that it
        # does not shadow the builtin complex type.
        is_complex = not isinstance(x, AffineExpression)

        # Build the string representations; the displayed name depends on
        # whether the argument is a scalar, a vector, or a general matrix.
        if len(x) == 1:
            typeStr = "Modulus" if is_complex else "Absolute Value"
            symbStr = glyphs.abs(x.string)
        elif 1 in x.shape:
            typeStr = "Euclidean Norm"
            symbStr = glyphs.norm(x.string)
        else:
            typeStr = "Spectral Norm"
            symbStr = glyphs.spnorm(x.string)

        if is_complex:
            typeStr = "Complex " + typeStr

        self._x = x
        self._complex = is_complex
        Expression.__init__(self, typeStr, symbStr)

    # --------------------------------------------------------------------------
    # Abstract method implementations and method overridings, except _predict.
    # --------------------------------------------------------------------------

    @cached_unary_operator
    def _get_refined(self):
        """Return a simpler equivalent expression, if one exists.

        A constant argument yields a constant affine expression holding the
        norm's value. A scalar or vector argument yields a :class:`Norm`.
        Otherwise the expression itself is returned.
        """
        if self._x.constant:
            return AffineExpression.from_constant(self.value, 1, self.string)
        elif len(self._x) == 1 or (1 in self._x.shape):
            return Norm(self._x)
        else:
            return self

    Subtype = namedtuple("Subtype", ("argshape", "complex", "hermitian"))

    def _get_subtype(self):
        return self.Subtype(self._x.shape, self._complex, self._x.hermitian)

    def _get_value(self):
        """Evaluate the norm numerically as a 1x1 CVXOPT matrix."""
        value = self._x._get_value()
        value = cvx2np(value)
        # ord=2 on a 2-D array is the largest singular value (spectral norm).
        value = numpy.linalg.norm(value, 2)
        return cvxopt.matrix(value)

    def _get_mutables(self):
        return self._x._get_mutables()

    def _is_convex(self):
        return True

    def _is_concave(self):
        return False

    def _replace_mutables(self, mapping):
        return self.__class__(self._x._replace_mutables(mapping))

    def _freeze_mutables(self, freeze):
        return self.__class__(self._x._freeze_mutables(freeze))

    # --------------------------------------------------------------------------
    # Python special method implementations, except constraint-creating ones.
    # --------------------------------------------------------------------------

    @classmethod
    def _mul(cls, self, other, forward):
        """Multiply the norm with a constant, keeping it a norm if possible.

        Multiplying by a positive real constant ``c`` is absorbed into the
        argument, using ``c * ||X|| == ||c * X||``. Multiplying by zero gives
        a zero affine expression, and multiplying by one returns ``self``.
        All other cases are delegated to :class:`Expression`.

        :param bool forward: Whether ``self`` is the left operand.
        """
        if isinstance(other, AffineExpression) and other.constant:
            factor = other.safe_value

            if not factor:
                return AffineExpression.zero()
            elif factor == 1:
                return self
            elif factor > 0:
                if forward:
                    string = glyphs.clever_mul(self.string, other.string)
                else:
                    string = glyphs.clever_mul(other.string, self.string)

                # Homogeneity of the norm for positive factors.
                norm = cls(other*self._x)
                norm._typeStr = "Scaled " + norm._typeStr
                norm._symbStr = string

                return norm

        if forward:
            return Expression.__mul__(self, other)
        else:
            return Expression.__rmul__(self, other)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __mul__(self, other):
        return SpectralNorm._mul(self, other, True)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __rmul__(self, other):
        return SpectralNorm._mul(self, other, False)

    # --------------------------------------------------------------------------
    # Methods and properties that return modified copies.
    # --------------------------------------------------------------------------

    @property
    def x(self):
        """The (possibly complex) affine expression whose norm is taken."""
        return self._x

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        assert isinstance(subtype, cls.Subtype)

        arg_shape, arg_complex, arg_hermitian = subtype
        xLen = arg_shape[0] * arg_shape[1]

        if relation == operator.__le__:
            # Only upper bounds given by a scalar real affine expression are
            # supported.
            if issubclass(other.clstype, AffineExpression) \
                    and other.subtype.dim == 1:
                if xLen == 1:
                    return AbsoluteValueConstraint.make_type()
                elif 1 in arg_shape:
                    # Vector arguments are refined into a Norm by
                    # _get_refined, so this case should never be reached.
                    assert False, "Unexpected case (should have been refined)"
                else:
                    return SpectralNormConstraint.make_type(
                        arg_shape, arg_complex, arg_hermitian)
        elif relation == operator.__ge__:
            return NotImplemented  # Not concave.

        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __le__(self, other):
        """Create an upper-bound constraint ``self <= other``.

        A scalar argument yields an :class:`AbsoluteValueConstraint`, and a
        general matrix argument a :class:`SpectralNormConstraint`. Right hand
        sides that are not real affine expressions give ``NotImplemented``.
        """
        assert self.convex

        if isinstance(other, AffineExpression):
            if len(self._x) == 1:
                return AbsoluteValueConstraint(self._x, other)
            elif 1 in self._x.shape:
                # Vector arguments should have been refined into a Norm.
                assert False, "Unexpected case (should have been refined)"
            else:
                return SpectralNormConstraint(self, other)
        else:
            return NotImplemented

231 

232 

# --------------------------------------
# Public API of this module, built from the _API_START snapshot taken near the
# top of the file and the current globals (presumably the names defined in
# between -- see picos.apidoc.api_end to confirm).
__all__ = api_end(_API_START, globals())