Coverage for picos/expressions/exp_wsum.py: 81.87%
171 statements
coverage.py v6.5.0, created at 2023-03-26 07:46 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2021 Maximilian Stahlberg
3#
4# This file is part of PICOS.
5#
6# PICOS is free software: you can redistribute it and/or modify it under the
7# terms of the GNU General Public License as published by the Free Software
8# Foundation, either version 3 of the License, or (at your option) any later
9# version.
10#
11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License along with
16# this program. If not, see <http://www.gnu.org/licenses/>.
17# ------------------------------------------------------------------------------
19"""Implements the :class:`WeightedSum` fallback class."""
21import operator
22from collections import namedtuple
23from functools import reduce
25import cvxopt
27from .. import glyphs
28from ..apidoc import api_end, api_start
29from ..caching import cached_property, cached_selfinverse_unary_operator
30from ..constraints import Constraint, WeightedSumConstraint
31from .data import convert_operands, load_dense_data
32from .exp_affine import AffineExpression, Constant
33from .expression import Expression, refine_operands, validate_prediction
# Marker for the start of this module's public API; it is passed to api_end
# at the bottom of the module to compute __all__. (Presumably a snapshot of
# the globals defined so far — NOTE(review): confirm against picos.apidoc.)
_API_START = api_start(globals())
# -------------------------------
class WeightedSum(Expression):
    """A convex or concave weighted sum of scalar expressions."""

    # --------------------------------------------------------------------------
    # Initialization and properties.
    # --------------------------------------------------------------------------

    def __init__(self, expressions, weights=1, opstring=None):
        """Construct a weighted sum of expressions.

        Summands that are not PICOS expressions are loaded as constants, and
        summands that are themselves weighted sums are unpacked (their weights
        scaled by the outer weight) so that the result is never nested.

        :param expressions:
            A collection of scalar expressions.

        :param weights:
            A constant weight vector.

        :param str opstring:
            Used by PICOS internally when this class is tried as a last fallback
            to represent the result of an otherwise unsupported product or sum.

        :raises TypeError:
            If ``expressions`` is itself a single expression or if not all
            summands are scalar. When ``opstring`` is given, any error raised
            during construction is re-raised as a :exc:`TypeError` that
            mentions ``opstring``.

        :raises ValueError:
            If no summands are given.

        :raises NotImplementedError:
            If any summand is uncertain.
        """
        try:
            # Avoid iterating over affine expressions.
            if isinstance(expressions, Expression):
                raise TypeError("{} is not designed to represent the sum over "
                    "(the elements of) a single expression. Use picos.sum to "
                    "select the correct class automatically."
                    .format(self.__class__.__name__))

            # Load constant data and refine expressions.
            expressions = tuple(
                x.refined if isinstance(x, Expression) else Constant(x)
                for x in expressions)

            if not expressions:
                raise ValueError("Need at least one expression.")

            # Require that every expression is scalar.
            if not all(x.scalar for x in expressions):
                raise TypeError("Not all summands are scalar.")

            # Load weights as a CVXOPT dense column vector.
            weights = load_dense_data(weights, (len(expressions), 1), "d")[0]

            # Never create a nested WeightedSum.
            # NOTE: This ensures that WeightedSumConstraintReformulation needs
            #       to run just once to get rid of all WeightedSumConstraint.
            if any(isinstance(x, WeightedSum) for x in expressions):
                ux, uw = [], []  # Unpacked expressions/weights.

                for x, w in zip(expressions, weights):
                    if isinstance(x, WeightedSum):
                        # Distribute the outer weight over the inner weights.
                        ux.extend(x._expressions)
                        uw.extend(w * x._weights)
                    else:
                        ux.append(x)
                        uw.append(w)

                assert not any(isinstance(x, WeightedSum) for x in ux)

                expressions = tuple(ux)
                weights = load_dense_data(uw, (len(ux), 1), "d")[0]

            # Determine convexity of expressions: a summand keeps the sum
            # convex if it is convex with a nonnegative weight or concave with
            # a nonpositive weight; concavity is the mirrored condition.
            convex = all((x.convex and w >= 0) or (x.concave and w <= 0)
                for x, w in zip(expressions, weights))
            concave = all((x.concave and w >= 0) or (x.convex and w <= 0)
                for x, w in zip(expressions, weights))

            # Don't handle uncertain expressions.
            # TODO: Consider handling sums with one uncertain summand.
            if any(x.uncertain for x in expressions):
                raise NotImplementedError(
                    "{} does not handle uncertain summands at this point."
                    .format(self.__class__.__name__))

            self._expressions = expressions
            self._weights = weights
            self._convex = convex
            self._concave = concave

            typeStrWords = []

            if convex and concave:
                typeStrWords.append("Affine")  # Manually crafted.
            elif convex:
                typeStrWords.append("Convex")
            elif concave:
                typeStrWords.append("Concave")

            if not all(w == 1 for w in weights):
                typeStrWords.append("Weighted")

            typeStrWords.append("Sum")

            typeStr = " ".join(typeStrWords)
            symbStr = reduce(glyphs.clever_add, (
                glyphs.clever_mul(glyphs.scalar(w), x.string)
                for w, x in zip(weights, expressions)))

            Expression.__init__(self, typeStr, symbStr)
        except Exception as error:
            # Outside of operator fallbacks, propagate the error unchanged.
            if not opstring:
                raise

            raise TypeError("Cannot represent {} as a weighted sum: {}"
                .format(opstring, error)) from None

    @property
    def expressions(self):
        """The expressions being summed, without their coefficients."""
        return self._expressions

    @cached_property
    def weights(self):
        """The coefficient vector as a PICOS column vector."""
        return Constant("w", self._weights)

    # --------------------------------------------------------------------------
    # Abstract method implementations for Expression, except _predict.
    # --------------------------------------------------------------------------

    # TODO: Merge expressions that can be merged.
    def _get_refined(self):
        """Return the simplest equivalent expression for this sum."""
        # Every weight is zero: the sum is identically zero.
        if all(w == 0 for w in self._weights):
            return AffineExpression.zero()
        elif all(x.constant for x in self._expressions):
            return Constant(self.string, self.safe_value, (1, 1))
        elif len(self._expressions) == 1 and self._weights[0] == 1:
            return self._expressions[0]
        elif any(w == 0 for w in self._weights):
            # Drop zero-weight summands. At least one summand survives since
            # not all weights are zero. Refine the reduced sum as well so that,
            # e.g., a single remaining unit-weight summand is unwrapped.
            kept = [(x, w) for x, w in zip(self._expressions, self._weights)
                if w != 0]
            keptExpressions, keptWeights = zip(*kept)
            return self.__class__(keptExpressions, keptWeights).refined
        else:
            return self

    Subtype = namedtuple("Subtype", (
        "convex", "concave", "types", "nonneg_weights"))

    def _get_subtype(self):
        # Derive the weight signs directly from the CVXOPT weight vector so
        # that nonneg_weights is a tuple of plain booleans for any number of
        # summands, including a single one.
        return self.Subtype(self.convex, self.concave,
            tuple(x.type for x in self._expressions),
            tuple(w >= 0 for w in self._weights))

    def _get_value(self):
        # A 1×n row of summand values times the n×1 weight column yields the
        # 1×1 value of the sum.
        values = cvxopt.matrix([x.safe_value for x in self._expressions],
            (1, len(self._expressions)))
        return values * self._weights

    def _get_mutables(self):
        # Union of the mutables of all summands (there is always at least one).
        return reduce(
            frozenset.union, (x._get_mutables() for x in self._expressions))

    def _is_convex(self):
        return self._convex

    def _is_concave(self):
        return self._concave

    def _replace_mutables(self, mapping):
        return self.__class__(
            (x._replace_mutables(mapping) for x in self._expressions),
            self._weights)

    def _freeze_mutables(self, freeze):
        return self.__class__(
            (x._freeze_mutables(freeze) for x in self._expressions),
            self._weights)

    # --------------------------------------------------------------------------
    # Python special method implementations, except constraint-creating ones.
    # NOTE: WeightedSum is used by Expression as a fallback class, so all
    #       operations are concluded here (return result or raise exception).
    # --------------------------------------------------------------------------

    @cached_selfinverse_unary_operator
    def __neg__(self):
        return self.__class__(self._expressions, -self._weights)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __add__(self, other):
        # Append the right hand side as a new summand with weight one.
        opstring = "{} plus {}".format(repr(self), repr(other))
        return self.__class__(self._expressions + (other,),
            cvxopt.matrix([self._weights, 1]), opstring)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __radd__(self, other):
        # Prepend the left hand side as a new summand with weight one.
        opstring = "{} plus {}".format(repr(other), repr(self))
        return self.__class__((other,) + self._expressions,
            cvxopt.matrix([1, self._weights]), opstring)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __sub__(self, other):
        # Append the subtrahend as a new summand with weight minus one.
        opstring = "{} minus {}".format(repr(self), repr(other))
        return self.__class__(self._expressions + (other,),
            cvxopt.matrix([self._weights, -1]), opstring)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __rsub__(self, other):
        # Compute other - self: keep other with weight one, negate our weights.
        opstring = "{} minus {}".format(repr(other), repr(self))
        return self.__class__((other,) + self._expressions,
            cvxopt.matrix([1, -self._weights]), opstring)

    def _mul(self, other, forward):
        """Multiply the sum with a constant affine expression.

        :param other:
            The other factor. Only constant affine expressions are supported;
            for anything else, :data:`NotImplemented` is returned.

        :param bool forward:
            Whether this sum is the left factor. This only affects the order of
            the factors in the symbolic string of the result.
        """
        if isinstance(other, AffineExpression) and other.constant:
            value = other.safe_value

            if value == 0:
                return Constant(0)
            elif value == 1:
                return self
            else:
                p = self.__class__(self._expressions, value*self._weights)

                # Show the product as written rather than the expanded sum.
                if forward:
                    p._symbStr = glyphs.clever_mul(self.string, other.string)
                else:
                    p._symbStr = glyphs.clever_mul(other.string, self.string)

                return p
        else:
            return NotImplemented

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __mul__(self, other):
        return self._mul(other, True)

    @convert_operands(scalarRHS=True)
    @refine_operands()
    def __rmul__(self, other):
        return self._mul(other, False)

    # --------------------------------------------------------------------------
    # Constraint-creating operators, and _predict.
    # --------------------------------------------------------------------------

    @classmethod
    def _predict(cls, subtype, relation, other):
        """Predict the constraint type created by comparing with ``other``.

        Returns the :class:`WeightedSumConstraint` type that ``relation`` would
        produce, or :data:`NotImplemented` if the comparison is unsupported:
        upper bounds require a convex sum, lower bounds a concave one, and the
        other side must be a scalar affine expression. ``other`` is a type
        descriptor exposing ``clstype`` and ``subtype``.
        """
        assert isinstance(subtype, cls.Subtype)

        if relation == operator.__le__:
            if not subtype.convex:
                return NotImplemented

            if not issubclass(other.clstype, AffineExpression) \
                    or other.subtype.dim != 1:
                return NotImplemented

            return WeightedSumConstraint.make_type(
                lhs_types=subtype.types, relation=Constraint.LE, rhs_type=other,
                nonneg_weights=subtype.nonneg_weights)
        elif relation == operator.__ge__:
            if not subtype.concave:
                return NotImplemented

            if not issubclass(other.clstype, AffineExpression) \
                    or other.subtype.dim != 1:
                return NotImplemented

            return WeightedSumConstraint.make_type(
                lhs_types=subtype.types, relation=Constraint.GE, rhs_type=other,
                nonneg_weights=subtype.nonneg_weights)

        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __le__(self, other):
        if not self.convex:
            raise TypeError("Cannot upper-bound the nonconvex expression {}."
                .format(self.string))

        if isinstance(other, AffineExpression):
            return WeightedSumConstraint(self, Constraint.LE, other)

        return NotImplemented

    @convert_operands(scalarRHS=True)
    @validate_prediction
    @refine_operands()
    def __ge__(self, other):
        if not self.concave:
            raise TypeError("Cannot lower-bound the nonconcave expression {}."
                .format(self.string))

        if isinstance(other, AffineExpression):
            return WeightedSumConstraint(self, Constraint.GE, other)

        return NotImplemented
# --------------------------------------
# Compute the module's exported names from the _API_START marker assigned
# before the public definitions.
__all__ = api_end(_API_START, globals())