Coverage for picos/expressions/cone_product.py: 87.69%
65 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-15 14:21 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2020 Maximilian Stahlberg
3#
4# This file is part of PICOS.
5#
6# PICOS is free software: you can redistribute it and/or modify it under the
7# terms of the GNU General Public License as published by the Free Software
8# Foundation, either version 3 of the License, or (at your option) any later
9# version.
10#
11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License along with
16# this program. If not, see <http://www.gnu.org/licenses/>.
17# ------------------------------------------------------------------------------
19"""Implements a Cartesian product cone."""
21import operator
22from collections import namedtuple
24from .. import glyphs
25from ..apidoc import api_end, api_start
26from ..caching import cached_property
27from ..constraints import ProductConeConstraint
28from .cone import Cone
29from .exp_affine import AffineExpression
# NOTE(review): presumably snapshots the module namespace here so that the
# api_end() call at the bottom of the file can export only the names that are
# defined in between; confirm against picos.apidoc.
_API_START = api_start(globals())
# -------------------------------
class ProductCone(Cone):
    """A real Cartesian product cone.

    Nested product cones passed to the constructor are flattened, so the
    member cones exposed via :attr:`cones` are never product cones
    themselves.
    """

    @classmethod
    def _unpack(cls, nested_cones):
        """Flatten any nested :class:`ProductCone` instances.

        :param nested_cones: An iterable of cones that may contain product
            cones.
        :returns: A list of the member cones in their original order, with
            every product cone replaced by its own (recursively flattened)
            member cones.
        """
        flattened_cones = []
        for inner_cone in nested_cones:
            if isinstance(inner_cone, ProductCone):
                # Recurse: a product cone's members may be products as well.
                flattened_cones.extend(cls._unpack(inner_cone.cones))
            else:
                flattened_cones.append(inner_cone)
        return flattened_cones

    def __init__(self, *cones):
        """Construct a :class:`ProductCone`.

        :param list(picos.expressions.Cone) cones:
            A sequence of cones to build the product cone from. May include
            other product cones that will be "unpacked" first.

        :raises TypeError:
            If no cones are given, if any argument is not a
            :class:`~picos.expressions.Cone`, or if any cone's dimension is
            falsy (e.g. unset).
        :raises NotImplementedError:
            If any of the cones depends on mutables.
        """
        if not cones or not all(isinstance(cone, Cone) for cone in cones):
            raise TypeError("Must initialize product cones with a nonempty "
                            "sequence of cone instances.")

        if not all(cone.dim for cone in cones):
            raise TypeError("Product cones must be built from cones with a "
                            "fixed dimensionality.")

        # _get_mutables and _replace_mutables rely on this restriction.
        if any(cone.mutables for cone in cones):
            raise NotImplementedError("Product cones may not include cones "
                                      "whose definition depends on mutables.")

        # Unpack nested product cones.
        cones = self._unpack(cones)

        # The product cone's dimension is the sum of its members' dimensions.
        dim = sum(cone.dim for cone in cones)

        Cone.__init__(self, dim, "Product Cone", glyphs.prod(glyphs.sep(
            "Ci", glyphs.element("i", glyphs.interval(len(cones))))))

        self._cones = tuple(cones)

    @property
    def cones(self):
        """The cones that make up the product cone as a tuple."""
        return self._cones

    def _get_mutables(self):
        # Always empty: __init__ rejects cones that depend on mutables.
        return frozenset()

    def _replace_mutables(self, mapping=None):
        # There is nothing to replace (see the NotImplementedError in
        # __init__), so the cone itself is returned. The mapping argument is
        # accepted and ignored so that callers passing a replacement mapping
        # (as done for other sets and expressions) do not fail with a
        # TypeError; the default keeps argument-less calls working.
        return self

    # Type descriptor of a product cone: its total dimension and the types of
    # its member cones.
    Subtype = namedtuple("Subtype", ("dim", "cones"))

    def _get_subtype(self):
        # NOTE: Storing dim is redundant but simplifies _predict.
        return self.Subtype(
            self.dim, cones=tuple(cone.type for cone in self._cones))

    @property
    def refined(self):
        """Override :attr:`~.set.Set.refined`.

        A product of a single cone is refined to that cone itself; otherwise
        the product cone is returned unchanged.
        """
        if len(self._cones) == 1:
            return self._cones[0]
        else:
            return self

    @classmethod
    def _predict(cls, subtype, relation, other):
        # Predict the type of the result of applying ``relation`` to a
        # product cone of the given subtype and to ``other``. Here ``other``
        # is a type descriptor (with ``clstype`` and ``subtype``), not an
        # instance.
        assert isinstance(subtype, cls.Subtype)

        if relation == operator.__rshift__:
            if issubclass(other.clstype, AffineExpression) \
                    and subtype.dim == other.subtype.dim:
                # Mirror _rshift_implementation: a single-cone product
                # behaves like that cone.
                if len(subtype.cones) == 1:
                    return subtype.cones[0].predict(operator.__rshift__, other)

                return ProductConeConstraint.make_type(
                    dim=subtype.dim, cones=subtype.cones)

        return Cone._predict_base(cls, subtype, relation, other)

    def _rshift_implementation(self, element):
        # Presumably invoked by the Cone base class for ``self >> element``.
        # Affine expressions yield a conic membership constraint; everything
        # else is delegated to the base implementation.
        if isinstance(element, AffineExpression):
            self._check_dimension(element)

            # HACK: Mimic refinement: Do not produce a ProductConeConstraint
            #       to represent a basic conic inequality.
            # TODO: Add a common base class for Expression and Set that allows
            #       proper refinement also for instances of the latter.
            if len(self.cones) == 1:
                return self.cones[0] >> element

            return ProductConeConstraint(element, self)

        # Handle scenario uncertainty for all cones.
        return Cone._rshift_base(self, element)

    @cached_property
    def dual_cone(self):
        """Implement :attr:`.cone.Cone.dual_cone`.

        The dual of a product cone is the product of its members' duals.
        """
        return self.__class__(*(cone.dual_cone for cone in self._cones))
# --------------------------------------
# Public API of this module. NOTE(review): presumably api_end() collects the
# names defined after the _API_START snapshot; confirm against picos.apidoc.
__all__ = api_end(_API_START, globals())