Coverage for picos/expressions/exp_wsum.py: 81.87%

171 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-26 07:46 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2021 Maximilian Stahlberg 

3# 

4# This file is part of PICOS. 

5# 

6# PICOS is free software: you can redistribute it and/or modify it under the 

7# terms of the GNU General Public License as published by the Free Software 

8# Foundation, either version 3 of the License, or (at your option) any later 

9# version. 

10# 

11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

13# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

14# 

15# You should have received a copy of the GNU General Public License along with 

16# this program. If not, see <http://www.gnu.org/licenses/>. 

17# ------------------------------------------------------------------------------ 

18 

19"""Implements the :class:`WeightedSum` fallback class.""" 

20 

21import operator 

22from collections import namedtuple 

23from functools import reduce 

24 

25import cvxopt 

26 

27from .. import glyphs 

28from ..apidoc import api_end, api_start 

29from ..caching import cached_property, cached_selfinverse_unary_operator 

30from ..constraints import Constraint, WeightedSumConstraint 

31from .data import convert_operands, load_dense_data 

32from .exp_affine import AffineExpression, Constant 

33from .expression import Expression, refine_operands, validate_prediction 

34 

# NOTE(review): api_start/api_end come from ..apidoc and are presumably used to
# derive __all__ from the names defined between the two calls, so this line
# should stay after the imports and before the first public definition —
# confirm in picos.apidoc.
_API_START = api_start(globals())
# -------------------------------

37 

38 

class WeightedSum(Expression):
    """A convex or concave weighted sum of scalar expressions."""

    # --------------------------------------------------------------------------
    # Initialization and properties.
    # --------------------------------------------------------------------------

    def __init__(self, expressions, weights=1, opstring=None):
        """Construct a weighted sum of expressions.

        :param expressions:
            A collection of scalar expressions. Entries that are not PICOS
            expressions are wrapped as :class:`Constant`.

        :param weights:
            A constant weight vector with one entry per expression.

        :param str opstring:
            Used by PICOS internally when this class is tried as a last fallback
            to represent the result of an otherwise unsupported product or sum.
            If given, any error raised during construction is re-raised as a
            :class:`TypeError` whose message mentions this string.

        :raises TypeError:
            If ``expressions`` is a single expression or contains a summand
            that is not scalar.
        :raises ValueError:
            If ``expressions`` is empty.
        :raises NotImplementedError:
            If any summand is uncertain.
        """
        try:
            # Avoid iterating over affine expressions.
            if isinstance(expressions, Expression):
                raise TypeError("{} is not designed to represent the sum over "
                    "(the elements of) a single expression. Use picos.sum to "
                    "select the correct class automatically."
                    .format(self.__class__.__name__))

            # Load constant data and refine expressions.
            expressions = tuple(
                x.refined if isinstance(x, Expression) else Constant(x)
                for x in expressions)

            if not expressions:
                raise ValueError("Need at least one expression.")

            # Require that every expression is scalar.
            if not all(x.scalar for x in expressions):
                raise TypeError("Not all summands are scalar.")

            # Load weights as a CVXOPT dense column vector.
            weights = load_dense_data(weights, (len(expressions), 1), "d")[0]

            # Never create a nested WeightedSum.
            # NOTE: This ensures that WeightedSumConstraintReformulation needs
            #       to run just once to get rid of all WeightedSumConstraint.
            if any(isinstance(x, WeightedSum) for x in expressions):
                ux, uw = [], []  # Unpacked expressions/weights.

                for x, w in zip(expressions, weights):
                    if isinstance(x, WeightedSum):
                        # Distribute the outer weight over the inner weights.
                        ux.extend(x._expressions)
                        uw.extend(w * x._weights)
                    else:
                        ux.append(x)
                        uw.append(w)

                assert not any(isinstance(x, WeightedSum) for x in ux)

                expressions = tuple(ux)
                weights = load_dense_data(uw, (len(ux), 1), "d")[0]

            # Determine convexity: a summand contributes convexly if it is
            # convex with a nonnegative weight or concave with a nonpositive
            # weight, and the other way round for concavity.
            convex = all((x.convex and w >= 0) or (x.concave and w <= 0)
                for x, w in zip(expressions, weights))
            concave = all((x.concave and w >= 0) or (x.convex and w <= 0)
                for x, w in zip(expressions, weights))

            # Don't handle uncertain expressions.
            # TODO: Consider handling sums with one uncertain summand.
            if any(x.uncertain for x in expressions):
                raise NotImplementedError(
                    "{} does not handle uncertain summands at this point."
                    .format(self.__class__.__name__))

            self._expressions = expressions
            self._weights = weights
            self._convex = convex
            self._concave = concave

            typeStrWords = []

            if convex and concave:
                typeStrWords.append("Affine")  # Manually crafted.
            elif convex:
                typeStrWords.append("Convex")
            elif concave:
                typeStrWords.append("Concave")

            if not all(w == 1 for w in weights):
                typeStrWords.append("Weighted")

            typeStrWords.append("Sum")

            typeStr = " ".join(typeStrWords)
            symbStr = reduce(glyphs.clever_add, (
                glyphs.clever_mul(glyphs.scalar(w), x.string)
                for w, x in zip(weights, expressions)))

            Expression.__init__(self, typeStr, symbStr)
        except Exception as error:
            if not opstring:
                raise

            raise TypeError("Cannot represent {} as a weighted sum: {}"
                .format(opstring, error)) from None

    @property
    def expressions(self):
        """The expressions being summed, without their coefficients."""
        return self._expressions

    @cached_property
    def weights(self):
        """The coefficient vector as a PICOS column vector."""
        return Constant("w", self._weights)

    # --------------------------------------------------------------------------
    # Abstract method implementations for Expression, except _predict.
    # --------------------------------------------------------------------------

    # TODO: Merge expressions that can be merged.
    def _get_refined(self):
        """Return a simpler equivalent expression, or the sum itself."""
        # Every weight is zero: the sum vanishes.
        if not any(self._weights):
            return AffineExpression.zero()

        # Only constant summands: the sum is a constant.
        if all(x.constant for x in self._expressions):
            return Constant(self.string, self.safe_value, (1, 1))

        # A single unit-weighted summand is just that summand.
        if len(self._expressions) == 1 and self._weights[0] == 1:
            return self._expressions[0]

        # Drop the zero-weighted summands. At least one summand remains
        # because the all-zero case was handled above. The reduced sum is
        # refined in turn so that a single remaining unit-weighted summand,
        # or only constant remaining summands, are simplified as well. This
        # cannot recurse further since the reduced sum has no zero weights.
        if not all(self._weights):
            expressions, weights = zip(*(
                (x, w) for x, w in zip(self._expressions, self._weights) if w))
            return self.__class__(expressions, weights).refined

        return self

    Subtype = namedtuple("Subtype", (
        "convex", "concave", "types", "nonneg_weights"))

    def _get_subtype(self):
        # NOTE: The sign test runs on the raw CVXOPT weights rather than on
        #       the NumPy value of self.weights, so the result does not depend
        #       on how a (possibly 1x1) weight vector is converted to NumPy.
        return self.Subtype(
            self.convex, self.concave,
            tuple(x.type for x in self._expressions),
            tuple(w >= 0 for w in self._weights))

    def _get_value(self):
        # Row vector of summand values times the weight column vector.
        values = cvxopt.matrix([x.safe_value for x in self._expressions],
            (1, len(self._expressions)))
        return values * self._weights

    def _get_mutables(self):
        # Union of the mutables of all summands (there is at least one).
        return reduce(
            frozenset.union, (x._get_mutables() for x in self._expressions))

    def _is_convex(self):
        return self._convex

    def _is_concave(self):
        return self._concave

    def _replace_mutables(self, mapping):
        return self.__class__(
            (x._replace_mutables(mapping) for x in self._expressions),
            self._weights)

    def _freeze_mutables(self, freeze):
        return self.__class__(
            (x._freeze_mutables(freeze) for x in self._expressions),
            self._weights)

    # --------------------------------------------------------------------------
    # Python special method implementations, except constraint-creating ones.
    # NOTE: WeightedSum is used by Expression as a fallback class, so all
    #       operations are concluded here (return result or raise exception).
    # --------------------------------------------------------------------------

    @cached_selfinverse_unary_operator
    def __neg__(self):
        return self.__class__(self._expressions, -self._weights)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __add__(self, other):
        # Append the operand as a new unit-weighted summand.
        opstring = "{} plus {}".format(repr(self), repr(other))
        return self.__class__(self._expressions + (other,),
            cvxopt.matrix([self._weights, 1]), opstring)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __radd__(self, other):
        # Prepend the operand as a new unit-weighted summand.
        opstring = "{} plus {}".format(repr(other), repr(self))
        return self.__class__((other,) + self._expressions,
            cvxopt.matrix([1, self._weights]), opstring)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __sub__(self, other):
        # Append the operand with a weight of -1.
        opstring = "{} minus {}".format(repr(self), repr(other))
        return self.__class__(self._expressions + (other,),
            cvxopt.matrix([self._weights, -1]), opstring)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __rsub__(self, other):
        # Prepend the operand and negate all existing weights.
        opstring = "{} minus {}".format(repr(other), repr(self))
        return self.__class__((other,) + self._expressions,
            cvxopt.matrix([1, -self._weights]), opstring)

    def _mul(self, other, forward):
        """Scale the sum by a constant affine scalar.

        :param other: The other (refined, converted) operand.
        :param bool forward: Whether ``self`` is the left factor; this only
            affects the order in the symbolic string.
        :returns: The scaled expression, or :data:`NotImplemented` if
            ``other`` is not a constant affine expression.
        """
        if not (isinstance(other, AffineExpression) and other.constant):
            return NotImplemented

        value = other.safe_value

        if value == 0:
            return Constant(0)
        elif value == 1:
            return self

        p = self.__class__(self._expressions, value*self._weights)

        # Show the product symbolically rather than the distributed weights.
        if forward:
            p._symbStr = glyphs.clever_mul(self.string, other.string)
        else:
            p._symbStr = glyphs.clever_mul(other.string, self.string)

        return p

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __mul__(self, other):
        return self._mul(other, True)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __rmul__(self, other):
        return self._mul(other, False)

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        assert isinstance(subtype, cls.Subtype)

        if relation == operator.__le__:
            # Only a convex sum can be upper-bounded, and only by a scalar
            # affine expression.
            if not subtype.convex:
                return NotImplemented

            if not issubclass(other.clstype, AffineExpression) \
                    or other.subtype.dim != 1:
                return NotImplemented

            return WeightedSumConstraint.make_type(
                lhs_types=subtype.types, relation=Constraint.LE, rhs_type=other,
                nonneg_weights=subtype.nonneg_weights)
        elif relation == operator.__ge__:
            # Only a concave sum can be lower-bounded, and only by a scalar
            # affine expression.
            if not subtype.concave:
                return NotImplemented

            if not issubclass(other.clstype, AffineExpression) \
                    or other.subtype.dim != 1:
                return NotImplemented

            return WeightedSumConstraint.make_type(
                lhs_types=subtype.types, relation=Constraint.GE, rhs_type=other,
                nonneg_weights=subtype.nonneg_weights)

        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __le__(self, other):
        if not self.convex:
            raise TypeError("Cannot upper-bound the nonconvex expression {}."
                .format(self.string))

        if isinstance(other, AffineExpression):
            return WeightedSumConstraint(self, Constraint.LE, other)

        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __ge__(self, other):
        if not self.concave:
            raise TypeError("Cannot lower-bound the nonconcave expression {}."
                .format(self.string))

        if isinstance(other, AffineExpression):
            return WeightedSumConstraint(self, Constraint.GE, other)

        return NotImplemented

333 

334 

# --------------------------------------
# NOTE(review): api_end presumably collects the public names defined since
# _API_START was taken, so this should stay after all public definitions —
# confirm in picos.apidoc.
__all__ = api_end(_API_START, globals())