Coverage for picos/constraints/con_kldiv.py: 98.39%

62 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-26 07:46 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2018-2019 Maximilian Stahlberg 

3# 

4# This file is part of PICOS. 

5# 

6# PICOS is free software: you can redistribute it and/or modify it under the 

7# terms of the GNU General Public License as published by the Free Software 

8# Foundation, either version 3 of the License, or (at your option) any later 

9# version. 

10# 

11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

13# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

14# 

15# You should have received a copy of the GNU General Public License along with 

16# this program. If not, see <http://www.gnu.org/licenses/>. 

17# ------------------------------------------------------------------------------ 

18 

19"""Implementation of :class:`KullbackLeiblerConstraint`.""" 

20 

21from collections import namedtuple 

22 

23from .. import glyphs 

24from ..apidoc import api_end, api_start 

25from ..caching import cached_property 

26from .constraint import Constraint, ConstraintConversion 

27 

28_API_START = api_start(globals()) 

29# ------------------------------- 

30 

31 

class KullbackLeiblerConstraint(Constraint):
    """Upper bound on a Kullback-Leibler divergence.

    The bounded expression is a negative or relative entropy; both cases are
    represented by :class:`~picos.expressions.NegativeEntropy`.
    """

    class ExpConeConversion(ConstraintConversion):
        """Kullback-Leibler to exponential cone constraint conversion."""

        @classmethod
        def predict(cls, subtype, options):
            """Implement :meth:`~.constraint.ConstraintConversion.predict`.

            Announces what :meth:`convert` adds for an argument of dimension
            ``subtype.argdim``: one auxiliary real variable of that dimension,
            one scalar affine inequality and one exponential cone constraint
            per entry of the argument.
            """
            from ..expressions import RealVariable
            from . import AffineConstraint, ExpConeConstraint

            dim = subtype.argdim

            aux_var_type = RealVariable.make_var_type(dim=dim, bnd=0)
            yield ("var", aux_var_type, 1)

            bound_con_type = AffineConstraint.make_type(dim=1, eq=False)
            yield ("con", bound_con_type, 1)

            cone_con_type = ExpConeConstraint.make_type()
            yield ("con", cone_con_type, dim)

        @classmethod
        def convert(cls, con, options):
            """Implement :meth:`~.constraint.ConstraintConversion.convert`.

            Builds and returns a fresh problem that holds an auxiliary variable
            ``__u``, a lower bound of ``-upperBound`` on the sum of its entries
            and, for every index, an exponential cone membership of the stacked
            triple (denominator entry, numerator entry, auxiliary entry).
            """
            from ..expressions import ExponentialCone
            from ..modeling import Problem

            numer = con.numerator
            denom = con.denominator
            size = con.divergence.n
            bound = con.upperBound

            problem = Problem()

            aux = problem.add_variable("__u", size)

            # (aux | 1) is the sum over all entries of the auxiliary variable.
            aux_sum = aux | 1
            problem.add_constraint(aux_sum >= -bound)

            for index in range(size):
                # Vertically stack the three scalars into one cone member.
                stacked = denom[index] // numer[index] // aux[index]
                problem.add_constraint(stacked << ExponentialCone())

            return problem

    def __init__(self, divergence, upperBound):
        """Construct a :class:`KullbackLeiblerConstraint`.

        :param ~picos.expressions.NegativeEntropy divergence:
            Constrained expression.
        :param ~picos.expressions.AffineExpression upperBound:
            Upper bound on the expression; must have exactly one entry.
        """
        from ..expressions import AffineExpression, NegativeEntropy

        assert isinstance(divergence, NegativeEntropy)
        assert isinstance(upperBound, AffineExpression)
        assert len(upperBound) == 1

        self.divergence = divergence
        self.upperBound = upperBound

        # The constraint's type string is taken over from the divergence.
        super().__init__(divergence._typeStr)

    @property
    def numerator(self):
        """The :math:`x` of the divergence."""
        return self.divergence.x

    @cached_property
    def denominator(self):
        """The :math:`y` of the divergence, or :math:`1`."""
        from ..expressions import AffineExpression

        y = self.divergence.y
        if y is not None:
            return y

        # No explicit denominator: use a constant one shaped like x.
        return AffineExpression.from_constant(1, self.divergence.x.shape)

    Subtype = namedtuple("Subtype", ("argdim",))

    def _subtype(self):
        return self.Subtype(argdim=len(self.numerator))

    @classmethod
    def _cost(cls, subtype):
        # NOTE: The argument dimension counts twice due to the denominator.
        return 2 * subtype.argdim + 1

    def _expression_names(self):
        yield from ("divergence", "upperBound")

    def _str(self):
        lhs = self.divergence.string
        rhs = self.upperBound.string
        return glyphs.le(lhs, rhs)

    def _get_slack(self):
        bound_value = self.upperBound.safe_value
        divergence_value = self.divergence.safe_value
        return bound_value - divergence_value

128 

129 

130# -------------------------------------- 

131__all__ = api_end(_API_START, globals())