Coverage for picos/expressions/set_simplex.py: 89.06%

64 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-26 07:46 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2019 Maximilian Stahlberg 

3# Based on the original picos.expressions module by Guillaume Sagnol. 

4# 

5# This file is part of PICOS. 

6# 

7# PICOS is free software: you can redistribute it and/or modify it under the 

8# terms of the GNU General Public License as published by the Free Software 

9# Foundation, either version 3 of the License, or (at your option) any later 

10# version. 

11# 

12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

14# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

15# 

16# You should have received a copy of the GNU General Public License along with 

17# this program. If not, see <http://www.gnu.org/licenses/>. 

18# ------------------------------------------------------------------------------ 

19 

20"""Implements :class:`Simplex`.""" 

21 

22import operator 

23from collections import namedtuple 

24 

25from .. import glyphs 

26from ..apidoc import api_end, api_start 

27from ..constraints import SimplexConstraint 

28from .data import convert_and_refine_arguments 

29from .exp_affine import AffineExpression, Constant 

30from .set import Set 

31 

# Marker taken before any public definitions; it is handed back to api_end()
# at the bottom of this module, which produces __all__.
# NOTE(review): the exact semantics of api_start/api_end live in ..apidoc and
# are not visible here -- confirm before relying on them.
_API_START = api_start(globals())
# -------------------------------

34 

35 

class Simplex(Set):
    r"""A (truncated, symmetrized) real simplex.

    :Definition:

    Let :math:`r \in \mathbb{R}_{\geq 0}` be the specified radius and
    :math:`n \in \mathbb{Z}_{\geq 1}` be an arbitrary dimensionality.

    1.  Without truncation and symmetrization, this is the nonnegative simplex

        .. math::

            \{x \in \mathbb{R}^n_{\geq 0} \mid \sum_{i = 1}^n x_i \leq r\}.

        For :math:`r = 1`, this is the standard (unit) :math:`n`-simplex.

    2.  With truncation but without symmetrization, this is the nonnegative
        simplex intersected with the :math:`\infty`-norm unit ball

        .. math::

            \{
                x \in \mathbb{R}^n_{\geq 0}
            \mid
                \sum_{i = 1}^n x_i \leq r \land x \leq 1
            \}.

        For :math:`r \leq 1`, this equals case (1).

    3.  With symmetrization but without truncation, this is the :math:`1`-norm
        ball of radius :math:`r`

        .. math::

            \{x \in \mathbb{R}^n \mid \sum_{i = 1}^n |x_i| \leq r\}.

    4.  With both symmetrization and truncation, this is the :math:`1`-norm
        ball of radius :math:`r` intersected with the :math:`\infty`-norm unit
        ball, i.e. the convex polytope

        .. math::

            \{
                x \in \mathbb{R}^n
            \mid
                \sum_{i = 1}^n |x_i| \leq r \land -1 \leq x \leq 1
            \}.

        For :math:`r \leq 1`, this equals case (3).
    """

    @convert_and_refine_arguments("radius")
    def __init__(self, radius=Constant(1), truncated=False, symmetrized=False):
        """Construct a :class:`Simplex`.

        :param radius: The radius of the simplex.
        :type radius:
            float or ~picos.expressions.AffineExpression

        :param bool truncated:
            Whether to intersect the set with the unit infinity-norm ball.
            This flag is reset to :obj:`False` when the radius is constant and
            at most one, as the truncation then has no effect.

        :param bool symmetrized:
            Whether the simplex is mirrored onto all orthants, which yields a
            1-norm ball.

        :raises TypeError:
            If the radius, as seen after the decorator has processed it, is not
            a scalar :class:`~picos.expressions.AffineExpression`.
        """
        if not isinstance(radius, AffineExpression):
            raise TypeError("A simplex' radius must be given as a real affine "
                "expression, not as {}.".format(type(radius).__name__))
        elif not radius.scalar:
            raise TypeError("A simplex' radius must be scalar, not of shape {}."
                .format(glyphs.shape(radius.shape)))

        # With a constant radius of at most one, the sum bound already implies
        # the box bound, so the truncated variants coincide with the
        # untruncated ones (see cases 2 and 4 of the class docstring).
        if radius.constant and radius.value <= 1:
            truncated = False

        # Placeholder name for the element in the symbolic set description.
        # NOTE(review): presumably chosen so that it does not clash with names
        # occurring in the radius' string -- confirm in glyphs.free_var_name.
        var = glyphs.free_var_name(radius.string)
        unit = "Unit " if radius.is1 else ""

        # Select the type name and symbolic string for each of the four cases
        # described in the class docstring.
        if not truncated and not symmetrized:
            typeStr = "{}Simplex".format(unit)
            symbStr = glyphs.set(glyphs.sep(glyphs.ge(var, 0),
                glyphs.le(glyphs.sum(var), radius.string)))
        elif truncated and not symmetrized:
            typeStr = "Box-Truncated {}Simplex".format(unit)
            symbStr = glyphs.set(glyphs.sep(glyphs.le(0, glyphs.le(var, 1)),
                glyphs.le(glyphs.sum(var), radius.string)))
        elif not truncated and symmetrized:
            typeStr = "{}1-norm Ball".format(unit)
            symbStr = glyphs.set(glyphs.sep(var,
                glyphs.le(glyphs.sum(glyphs.abs(var)), radius.string)))
        else:  # truncated and symmetrized
            typeStr = "Box-Truncated {}1-norm Ball".format(unit)
            symbStr = glyphs.set(glyphs.sep(glyphs.le(-1, glyphs.le(var, 1)),
                glyphs.le(glyphs.sum(glyphs.abs(var)), radius.string)))

        self._radius = radius
        self._truncated = truncated
        self._symmetrized = symmetrized

        Set.__init__(self, typeStr, symbStr)

    @property
    def radius(self):
        """The radius of the simplex."""
        return self._radius

    @property
    def truncated(self):
        r"""Whether this is intersected with the unit :math:`\infty`-ball."""
        return self._truncated

    @property
    def symmetrized(self):
        """Whether the simplex is mirrored onto all orthants."""
        return self._symmetrized

    # The radius is the only expression stored on the set, so the set's
    # mutables are exactly those of the radius.
    def _get_mutables(self):
        return self._radius._get_mutables()

    # Rebuild a set of the same class from the remapped radius, carrying over
    # the stored truncation and symmetrization flags.
    def _replace_mutables(self, mapping):
        return self.__class__(self._radius._replace_mutables(mapping),
            self._truncated, self._symmetrized)

    # Subtype descriptor consumed by _predict: just the two boolean flags.
    Subtype = namedtuple("Subtype", ("truncated", "symmetrized"))

    def _get_subtype(self):
        return self.Subtype(self._truncated, self._symmetrized)

    # Only the relation ``operator.__rshift__`` with an affine expression type
    # is predicted; it yields the matching SimplexConstraint type. Every other
    # combination is answered with NotImplemented.
    @classmethod
    def _predict(cls, subtype, relation, other):
        assert isinstance(subtype, cls.Subtype)

        if relation == operator.__rshift__:
            if issubclass(other.clstype, AffineExpression):
                return SimplexConstraint.make_type(
                    argdim=other.subtype.dim,
                    truncated=subtype.truncated,
                    symmetrized=subtype.symmetrized)

        return NotImplemented

    # Runtime counterpart of the case handled in _predict: an affine element
    # produces a SimplexConstraint, anything else is left to other handlers.
    def _rshift_implementation(self, element):
        if isinstance(element, AffineExpression):
            return SimplexConstraint(self, element)
        else:
            return NotImplemented

173 

174 

# --------------------------------------
# Build the module's export list from the marker recorded in _API_START near
# the top of the file and the current module namespace.
# NOTE(review): presumably this exports the names defined since api_start()
# was called -- confirm in ..apidoc.
__all__ = api_end(_API_START, globals())