Coverage for picos/reforms/reform_constraint.py: 92.86%

98 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-26 07:46 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2019 Maximilian Stahlberg 

3# 

4# This file is part of PICOS. 

5# 

6# PICOS is free software: you can redistribute it and/or modify it under the 

7# terms of the GNU General Public License as published by the Free Software 

8# Foundation, either version 3 of the License, or (at your option) any later 

9# version. 

10# 

11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

13# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

14# 

15# You should have received a copy of the GNU General Public License along with 

16# this program. If not, see <http://www.gnu.org/licenses/>. 

17# ------------------------------------------------------------------------------ 

18 

19"""Reformulations that concern a particular type of constraint. 

20 

21The reformulations' logic is not found here but defined within the constraint 

22classes in the form of a :class:`constraint conversion class 

23<picos.constraints.constraint.ConstraintConversion>`. 

24""" 

25 

26import inspect 

27 

28from .. import constraints 

29from ..constraints import Constraint, ConstraintConversion 

30from ..modeling.footprint import Footprint 

31from .reformulation import Reformulation 

32 

33# No call to apidoc.api_start: Module defines __all__. 

34# ---------------------------------------------------- 

35 

36 

def reformulation_init(self, theObject):
    """Implement :meth:`~.reformulation.Reformulation.__init__`.

    After the base class initialization, the constraint type handled by the
    generated reformulation class is stored on the instance. The conversion's
    ``convert`` and ``dual`` callables are also stored there.
    """
    Reformulation.__init__(self, theObject)

    reformulationType = self.__class__
    conversionType = reformulationType.CONVERSION_TYPE

    # The constraint class that this reformulation replaces.
    self.constraintType = reformulationType.CONSTRAINT_TYPE

    # Turns a single constraint into a temporary problem (see _reform_single).
    self.makeTmpProblem = conversionType.convert

    # Produces the dual value of an original constraint (see backward).
    self.makeDualValue = conversionType.dual

44 

45 

def reformulation_supports(cls, footprint):
    """Implement :meth:`~.reformulation.Reformulation.supports`.

    :returns: Whether ``footprint`` contains the key
        ``("con", cls.CONSTRAINT_TYPE)``, that is whether there are constraints
        of the type that this reformulation converts.
    """
    key = ("con", cls.CONSTRAINT_TYPE)
    return key in footprint

49 

50 

def reformulation_predict(cls, footprint):
    """Implement :meth:`~.reformulation.Reformulation.predict`.

    The converted constraint type is mapped to ``Footprint.NONE``, presumably
    clearing it from the footprint. For every subtype of that constraint
    type, the conversion's predicted additions are recorded with their count
    multiplied by the number of constraints of that subtype.
    """
    key = ("con", cls.CONSTRAINT_TYPE)
    conversion = cls.CONVERSION_TYPE

    updates = [("con", cls.CONSTRAINT_TYPE, Footprint.NONE)]

    for subtypeTuple, count in footprint[key].items:
        # Subtypes are expected to be wrapped in a one-element tuple.
        assert len(subtypeTuple) == 1
        subtype = subtypeTuple[0]

        predicted = conversion.predict(subtype, footprint.options)
        for addition in predicted:
            # The last entry of each prediction is a per-constraint count.
            assert isinstance(addition[-1], int)
            scaledCount = addition[-1] * count
            updates.append(addition[:-1] + (scaledCount,))

    return footprint.updated(updates)

64 

65 

def reformulation_reform_single(self, constraint, options):
    """Convert a single constraint.

    A temporary problem is built from ``constraint`` via the conversion. The
    constraint is then taken out of the output problem, if it is there, and
    the temporary problem's constraints are added to the output problem in
    its place. The temporary problem's variables (by name) and the added
    constraints are recorded in ``self.auxVars`` and ``self.auxCons``, both
    keyed by ``constraint``.

    :param constraint: An instance of ``self.constraintType``.
    :param options: Options handed to the conversion.
    """
    assert isinstance(constraint, self.constraintType)

    tmpProblem = self.makeTmpProblem(constraint, options)

    # Start fresh bookkeeping of the auxiliary variables and constraints that
    # replace this constraint.
    auxVarsByName = {}
    auxConList = []
    self.auxVars[constraint] = auxVarsByName
    self.auxCons[constraint] = auxConList

    # When forwarding, the output still holds the original constraint and it
    # must go. When updating, it was never passed to the output.
    if constraint.id in self.output.constraints:
        self.output.remove_constraint(constraint.id)

    # Record the auxiliary variables so that reformulation_backward can strip
    # their values from the solution.
    auxVarsByName.update(tmpProblem.variables.items())

    # HACK: This relies on Problem.constraints being an OrderedDict. The
    #       conversion must identify the constraints it added to the temporary
    #       problem without holding any reference to them.
    addedCons = self.output.add_list_of_constraints(
        tmpProblem.constraints.values())
    auxConList.extend(addedCons)

95 

def reformulation_forward(self):
    """Implement :meth:`~.reformulation.Reformulation.forward`.

    The output problem becomes a clone of the input problem, made with
    ``copyOptions=False``. Every input constraint of ``self.constraintType``
    is then reformulated within it.
    """
    self.output = self.input.clone(copyOptions=False)

    # Bookkeeping of auxiliary objects, keyed by the original constraint.
    self.auxVars = {}
    self.auxCons = {}

    # TODO: Give Problem quick iterators over constraints of certain type.
    relevant = (con for con in self.input.constraints.values()
                if isinstance(con, self.constraintType))
    for constraint in relevant:
        self._reform_single(constraint, self.input.options)

107 

108 

def reformulation_update(self):
    """Implement :meth:`~.reformulation.Reformulation.update`.

    The method proceeds in four steps:

    - Objective, variable and option changes are passed through unmodified.
    - Constraint changes are passed through for all other constraint types.
    - Newly added constraints of ``self.constraintType`` are reformulated.
    - The auxiliary constraints recorded for removed ones are deleted from
      the output problem.
    """
    # Objective and variables pass through as they are.
    self._pass_updated_objective()
    self._pass_updated_vars()

    # Constraints of the handled type are held back and processed below.
    added, removed = self._pass_updated_cons(ignore=self.constraintType)

    # Option changes pass through as they are.
    self._pass_updated_options()

    for constraint in added:
        self._reform_single(constraint, self.input.options)

    # Undo the reformulation of relevant constraints that were removed.
    for constraint in removed:
        assert constraint in self.auxVars and constraint in self.auxCons

        replacements = self.auxCons[constraint]
        for replacement in replacements:
            assert replacement.id in self.output.constraints
            self.output.remove_constraint(replacement.id)

        self.auxCons.pop(constraint)
        self.auxVars.pop(constraint)

137 

138 

def reformulation_backward(self, solution):
    """Implement :meth:`~.reformulation.Reformulation.backward`.

    Each input constraint of ``self.constraintType`` is processed in turn:

    - Its dual value is computed via ``self.makeDualValue`` from the primal
      values of its auxiliary variables and the duals of its auxiliary
      constraints.
    - If the conversion raises :exc:`NotImplementedError`, the dual is set to
      :obj:`None` instead.
    - The auxiliary entries are then removed from ``solution``.

    :returns: The same ``solution`` object, modified in place.
    """
    # TODO: Give Problem quick iterators over constraints of certain type.
    for constraint in self.input.constraints.values():
        if not isinstance(constraint, self.constraintType):
            continue

        # NOTE(review): self.auxVars[constraint] maps auxiliary variable
        #   *names* to variables, so primals are looked up and removed by name
        #   here. Confirm that solution.primals is keyed the same way.
        try:
            auxPrimals = {name: solution.primals.get(name)
                          for name in self.auxVars[constraint]}
            auxDuals = [solution.duals.get(auxCon)
                        for auxCon in self.auxCons[constraint]]
            dualValue = self.makeDualValue(
                auxPrimals, auxDuals, self.input.options)
        except NotImplementedError:
            dualValue = None
        solution.duals[constraint] = dualValue

        # Strip auxiliary primal and dual values from the solution.
        cleanup = ((solution.primals, self.auxVars[constraint]),
                   (solution.duals, self.auxCons[constraint]))
        for values, auxKeys in cleanup:
            for auxKey in auxKeys:
                if auxKey in values:
                    values.pop(auxKey)

    return solution

162 

163 

def make_constraint_reformulation(constraint, conversion):
    """Produce a :class:`Reformulation` from a :class:`ConstraintConversion`.

    A helper that creates a :class:`Reformulation` type (subclass) from a
    :class:`Constraint` and :class:`constraint.ConstraintConversion` types.

    For a constraint type ``<X>Constraint`` and a conversion ``<Y>Conversion``
    the generated class is named ``<X>To<Y>Reformulation``. If the conversion
    is named just ``Conversion``, the class is named ``<X>Reformulation``.
    """
    constraintSuffix = "Constraint"
    conversionSuffix = "Conversion"

    assert constraint.__name__.endswith(constraintSuffix), \
        "Constraint types must have a name ending in 'Constraint'."

    assert conversion.__name__.endswith(conversionSuffix), \
        "Constraint conversions must have a name ending in 'Conversion'."

    constraintName = constraint.__name__[:-len(constraintSuffix)]
    conversionName = conversion.__name__[:-len(conversionSuffix)]

    if conversionName:
        target = "To" + conversionName
    else:
        target = ""
    name = constraintName + target + "Reformulation"

    docstring = "Reformulation created from :class:`{1} <{0}.{1}>`." \
        .format(conversion.__module__, conversion.__qualname__)

    # TODO: This would need an additional prediction that yields all class
    #       types of constraints that can be converted. Is that needed?
    # NOTE: SumExponentialsConstraint.LSEConversion returns the constraint as
    #       is for certain unsupported subtypes.
    # assert constraint not in conversion.adds_constraint_types(), \
    #     "Constraint conversions may not add the very type being converted."

    # The key order below is kept stable as it becomes the class __dict__ order.
    body = {
        # Class constants.
        "__module__": make_constraint_reformulation.__module__,  # This module.
        "CONSTRAINT_TYPE": constraint,
        "CONVERSION_TYPE": conversion,

        # Class methods.
        "supports": classmethod(reformulation_supports),
        "predict": classmethod(reformulation_predict),

        # Instance methods, plus the class docstring.
        "__init__": reformulation_init,
        "__doc__": docstring,
        "_reform_single": reformulation_reform_single,
        "forward": reformulation_forward,
        "update": reformulation_update,
        "backward": reformulation_backward
    }

    # TODO: Check if anything like the following is still necessary.
    # # HACK: See QuadConstraint.ConicConversion.predict.
    # if constraint is constraints.QuadConstraint \
    #     and conversion is constraints.QuadConstraint.ConicConversion:
    #     body["_verify_prediction"] = lambda self: None

    return type(name, (Reformulation,), body)

218 

219 

# Allow __init__ to import exactly the generated reformulations using asterisk.
__all__ = []

# For every constraint conversion, generate a problem reformulation.
NUM_REFORMS = 0
CONSTRAINT_TO_REFORMS = {}
for constraint in constraints.__dict__.values():
    # Only constraint classes exported by the constraints package qualify.
    if not (inspect.isclass(constraint)
            and issubclass(constraint, Constraint)):
        continue

    CONSTRAINT_TO_REFORMS[constraint] = []

    # Conversions are the ConstraintConversion subclasses found directly in
    # the constraint class' own namespace.
    for conversion in constraint.__dict__.values():
        if not (inspect.isclass(conversion)
                and issubclass(conversion, ConstraintConversion)):
            continue

        reformulation = make_constraint_reformulation(constraint, conversion)

        NUM_REFORMS += 1
        CONSTRAINT_TO_REFORMS[constraint].append(reformulation)

        # Export the reformulation as if it was defined at module level.
        globals()[reformulation.__name__] = reformulation
        __all__.append(reformulation.__name__)

# Make the order of __all__ deterministic.
__all__.sort()

# FIXME: Restore a topological sorting.
# TODO: As above, this would need a separate predictor.
TOPOSORTED_REFORMS = [globals()[name] for name in __all__]

257 

258 

259# -------------------------------------------------- 

260# No call to apidoc.api_end: Module defines __all__.