Coverage for picos/reforms/reform_constraint.py: 92.86%
98 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-15 14:21 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-15 14:21 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2019 Maximilian Stahlberg
3#
4# This file is part of PICOS.
5#
6# PICOS is free software: you can redistribute it and/or modify it under the
7# terms of the GNU General Public License as published by the Free Software
8# Foundation, either version 3 of the License, or (at your option) any later
9# version.
10#
11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License along with
16# this program. If not, see <http://www.gnu.org/licenses/>.
17# ------------------------------------------------------------------------------
"""Reformulations that concern a particular type of constraint.

The reformulations' logic is not found here but defined within the constraint
classes in the form of a :class:`constraint conversion class
<picos.constraints.constraint.ConstraintConversion>`.
"""
26import inspect
28from .. import constraints
29from ..constraints import Constraint, ConstraintConversion
30from ..modeling.footprint import Footprint
31from .reformulation import Reformulation
33# No call to apidoc.api_start: Module defines __all__.
34# ----------------------------------------------------
def reformulation_init(self, theObject):
    """Implement :meth:`~.reformulation.Reformulation.__init__`.

    After delegating to the base class initializer, cache three class-level
    settings on the instance: the constraint type being handled and the
    ``convert`` and ``dual`` attributes of the associated conversion class.

    :param theObject: Passed through unchanged to
        ``Reformulation.__init__``.
    """
    Reformulation.__init__(self, theObject)

    # Go through self.__class__ so that the values of the concrete, generated
    # reformulation class are used.
    self.constraintType = self.__class__.CONSTRAINT_TYPE
    self.makeTmpProblem = self.__class__.CONVERSION_TYPE.convert
    self.makeDualValue = self.__class__.CONVERSION_TYPE.dual
def reformulation_supports(cls, footprint):
    """Implement :meth:`~.reformulation.Reformulation.supports`.

    :param footprint: Container that is tested for membership of the pair
        ``("con", cls.CONSTRAINT_TYPE)``.
    :returns: The result of that membership test.
    """
    return ("con", cls.CONSTRAINT_TYPE) in footprint
def reformulation_predict(cls, footprint):
    """Implement :meth:`~.reformulation.Reformulation.predict`.

    Collect a list of footprint updates and apply them in one go via
    ``footprint.updated``.

    :param footprint: The footprint to base the prediction on. It is indexed
        with ``("con", cls.CONSTRAINT_TYPE)`` and its ``options`` attribute is
        handed to the conversion's ``predict``.
    :returns: Whatever ``footprint.updated`` returns for the collected
        updates.
    """
    # The first update records the handled constraint type together with
    # Footprint.NONE.
    updates = [("con", cls.CONSTRAINT_TYPE, Footprint.NONE)]

    # NOTE(review): ``items`` is accessed as an attribute and not called. The
    # assertion below expects every key to be a tuple of length one, so this
    # is presumably a property of the footprint's subtree type rather than
    # ``dict.items`` -- confirm against the footprint implementation.
    for subtype, count in footprint[("con", cls.CONSTRAINT_TYPE)].items:
        assert len(subtype) == 1
        subtype = subtype[0]

        for addition in cls.CONVERSION_TYPE.predict(subtype, footprint.options):
            assert isinstance(addition[-1], int)
            # Scale the trailing integer of the predicted record by count and
            # keep the leading entries as they are.
            updates.append(addition[:-1] + (addition[-1]*count,))

    return footprint.updated(updates)
def reformulation_reform_single(self, constraint, options):
    """Convert a single constraint.

    Build a temporary problem for the constraint, register its variables and
    constraints as auxiliary objects of the constraint and add the temporary
    problem's constraints to the output problem.

    :param constraint: The constraint to convert; must be an instance of
        ``self.constraintType``.
    :param options: Passed on to the conversion together with the constraint.
    """
    assert isinstance(constraint, self.constraintType)

    # Create a temporary problem from the constraint to be replaced.
    tmpProblem = self.makeTmpProblem(constraint, options)

    # Keep track of auxiliary vars/cons replacing the constraint.
    self.auxVars[constraint] = {}
    self.auxCons[constraint] = []

    # If the constraint to be transformed is part of the output problem, remove
    # it. This is the case when forwarding but not when updating.
    if constraint.id in self.output.constraints:
        self.output.remove_constraint(constraint.id)

    # Remember auxiliary variables added so that their value can be removed from
    # the solution in reformulation_backward.
    for tmpVarName, tmpVar in tmpProblem.variables.items():
        self.auxVars[constraint][tmpVarName] = tmpVar

    # Add auxiliary constraints to the output problem.
    # HACK: This only works while Problem.constraints is an OrderedDict as
    #       ConstraintConversion needs a way to identify the constraints added
    #       to the temporary problem without having any pointer to them.
    auxCons = self.output.add_list_of_constraints(
        tmpProblem.constraints.values())
    self.auxCons[constraint].extend(auxCons)
def reformulation_forward(self):
    """Implement :meth:`~.reformulation.Reformulation.forward`.

    Start the output problem as a clone of the input problem (requested with
    ``copyOptions=False``), reset the bookkeeping of auxiliary objects and
    convert every input constraint that is an instance of
    ``self.constraintType``.
    """
    self.output = self.input.clone(copyOptions=False)

    # Both map a converted constraint to the auxiliary objects replacing it.
    self.auxVars = {}
    self.auxCons = {}

    # TODO: Give Problem quick iterators over constraints of certain type.
    for constraint in self.input.constraints.values():
        if isinstance(constraint, self.constraintType):
            self._reform_single(constraint, self.input.options)
def reformulation_update(self):
    """Implement :meth:`~.reformulation.Reformulation.update`.

    Pass on changes to objective, variables, unhandled constraints and
    options, then convert newly added constraints of the handled type and
    drop the auxiliary objects of removed ones.
    """
    # Pass changes in the objective function.
    self._pass_updated_objective()

    # Pass all variables as they are.
    self._pass_updated_vars()

    # Pass all unhandled constraints as they are.
    added, removed = self._pass_updated_cons(ignore=self.constraintType)

    # Pass all option changes.
    self._pass_updated_options()

    # Reformulate new relevant constraints.
    for constraint in added:
        self._reform_single(constraint, self.input.options)

    # Remove auxiliary objects added for relevant removed constraints.
    for constraint in removed:
        assert constraint in self.auxVars and constraint in self.auxCons

        for auxCon in self.auxCons[constraint]:
            assert auxCon.id in self.output.constraints
            self.output.remove_constraint(auxCon.id)

        # Forget the bookkeeping entries of the removed constraint.
        self.auxCons.pop(constraint)
        self.auxVars.pop(constraint)
def reformulation_backward(self, solution):
    """Implement :meth:`~.reformulation.Reformulation.backward`.

    For every input constraint of the handled type, store a dual value for it
    in ``solution.duals`` and strip the entries of its auxiliary variables
    and constraints from ``solution.primals`` and ``solution.duals``.

    :param solution: Object with ``primals`` and ``duals`` mappings; it is
        modified in place.
    :returns: The same ``solution`` object.
    """
    # TODO: Give Problem quick iterators over constraints of certain type.
    for constraint in self.input.constraints.values():
        if isinstance(constraint, self.constraintType):
            try:
                # The conversion receives the auxiliary primal values, the
                # auxiliary dual values and the input problem's options.
                solution.duals[constraint] = self.makeDualValue(
                    {v: solution.primals.get(v)
                        for v in self.auxVars[constraint]},
                    [solution.duals.get(c) for c in self.auxCons[constraint]],
                    self.input.options)
            except NotImplementedError:
                # No dual value can be produced for this constraint.
                solution.duals[constraint] = None

            for auxVar in self.auxVars[constraint]:
                if auxVar in solution.primals:
                    solution.primals.pop(auxVar)

            for auxCon in self.auxCons[constraint]:
                if auxCon in solution.duals:
                    solution.duals.pop(auxCon)

    return solution
def make_constraint_reformulation(constraint, conversion):
    """Produce a :class:`Reformulation` from a :class:`ConstraintConversion`.

    A helper that creates a :class:`Reformulation` type (subclass) from a
    :class:`Constraint` and :class:`constraint.ConstraintConversion` types.

    The generated class is named ``<X>To<Y>Reformulation`` where ``<X>`` and
    ``<Y>`` are the names of the constraint and conversion classes without
    their ``Constraint`` and ``Conversion`` suffixes; the ``To<Y>`` part is
    omitted when ``<Y>`` is empty.

    :param constraint: Class whose name ends in ``Constraint``.
    :param conversion: Class whose name ends in ``Conversion``.
    :returns: The newly created subclass of :class:`Reformulation`.
    """
    assert constraint.__name__.endswith("Constraint"), \
        "Constraint types must have a name ending in 'Constraint'."

    assert conversion.__name__.endswith("Conversion"), \
        "Constraint conversions must have a name ending in 'Conversion'."

    # Strip the mandatory suffixes to obtain the name components.
    constraintName = constraint.__name__[:-len("Constraint")]
    conversionName = conversion.__name__[:-len("Conversion")]

    name = "{}{}Reformulation".format(constraintName,
        "To{}".format(conversionName) if conversionName else "")

    docstring = "Reformulation created from :class:`{1} <{0}.{1}>`." \
        .format(conversion.__module__, conversion.__qualname__)

    # TODO: This would need an additional prediction that yields all class
    #       types of constraints that can be converted. Is that needed?
    # NOTE: SumExponentialsConstraint.LSEConversion makes use of returning the
    #       constraint as is for certain unsupported subtypes.
    # assert constraint not in conversion.adds_constraint_types(), \
    #     "Constraint conversions may not add the very type being converted."

    body = {
        # Class constants.
        "__module__": make_constraint_reformulation.__module__,  # Defined here.
        "CONSTRAINT_TYPE": constraint,
        "CONVERSION_TYPE": conversion,

        # Class methods.
        "supports": classmethod(reformulation_supports),
        "predict": classmethod(reformulation_predict),

        # Instance methods.
        "__init__": reformulation_init,
        "__doc__": docstring,
        "_reform_single": reformulation_reform_single,
        "forward": reformulation_forward,
        "update": reformulation_update,
        "backward": reformulation_backward
    }

    # TODO: Check if anything like the following is still necessary.
    # # HACK: See QuadConstraint.ConicConversion.predict.
    # if constraint is constraints.QuadConstraint \
    # and conversion is constraints.QuadConstraint.ConicConversion:
    #     body["_verify_prediction"] = lambda self: None

    # Create the subclass dynamically from its name, base and namespace.
    return type(name, (Reformulation,), body)
# Allow __init__ to import exactly the generated reformulations using asterisk.
__all__ = []

# For every constraint conversion, generate a problem reformulation.
# NUM_REFORMS counts the generated classes and CONSTRAINT_TO_REFORMS maps every
# constraint class found to the list of reformulations generated for it.
NUM_REFORMS = 0
CONSTRAINT_TO_REFORMS = {}
for constraint in constraints.__dict__.values():
    # Only classes deriving from Constraint are of interest.
    if not inspect.isclass(constraint):
        continue

    if not issubclass(constraint, Constraint):
        continue

    CONSTRAINT_TO_REFORMS[constraint] = []

    # Conversions are looked up among the attributes defined directly on the
    # constraint class.
    for conversion in constraint.__dict__.values():
        if not inspect.isclass(conversion):
            continue

        if not issubclass(conversion, ConstraintConversion):
            continue

        reformulation = make_constraint_reformulation(constraint, conversion)

        NUM_REFORMS += 1
        CONSTRAINT_TO_REFORMS[constraint].append(reformulation)

        # Export the reformulations as if they were defined at module level.
        globals()[reformulation.__name__] = reformulation
        __all__.append(reformulation.__name__)

# Make the order of __all__ deterministic.
__all__ = sorted(__all__)

# FIXME: Restore a topological sorting.
# TODO: As above, this would need a separate predictor.
TOPOSORTED_REFORMS = [globals()[name] for name in __all__]
259# --------------------------------------------------
# No call to apidoc.api_end: Module defines __all__.