Coverage for picos/constraints/con_kldiv.py: 98.39%
62 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-15 14:21 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-15 14:21 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2018-2019 Maximilian Stahlberg
3#
4# This file is part of PICOS.
5#
6# PICOS is free software: you can redistribute it and/or modify it under the
7# terms of the GNU General Public License as published by the Free Software
8# Foundation, either version 3 of the License, or (at your option) any later
9# version.
10#
11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License along with
16# this program. If not, see <http://www.gnu.org/licenses/>.
17# ------------------------------------------------------------------------------
"""Implementation of :class:`KullbackLeiblerConstraint`.

The constraint places an upper bound on a negative or relative entropy, both
represented by :class:`~picos.expressions.NegativeEntropy`, and can be
converted into a set of exponential cone constraints.
"""
21from collections import namedtuple
23from .. import glyphs
24from ..apidoc import api_end, api_start
25from ..caching import cached_property
26from .constraint import Constraint, ConstraintConversion
# Start of this module's public API section. The token stored here is handed
# to api_end() at the bottom of the module, and its result becomes __all__;
# presumably the names defined in between are what get exported
# (NOTE(review): confirm against ..apidoc).
_API_START = api_start(globals())
# -------------------------------
class KullbackLeiblerConstraint(Constraint):
    r"""Upper bound on a Kullback-Leibler divergence.

    Bounds either a negative entropy or a relative entropy from above by a
    scalar affine expression. Both kinds of divergence are represented by
    :class:`~picos.expressions.NegativeEntropy`. When the divergence has no
    :math:`y`, the constant :math:`1` is used as its denominator.
    """

    class ExpConeConversion(ConstraintConversion):
        r"""Rewrite a Kullback-Leibler bound with exponential cones.

        An auxiliary vector :math:`u` as long as the divergence's argument is
        introduced. It is subject to :math:`\sum_i u_i \geq -b`, and each
        triple :math:`(y_i, x_i, u_i)` is placed in an exponential cone.
        """

        @classmethod
        def predict(cls, subtype, options):
            """Announce what :meth:`convert` creates for ``subtype``.

            Implements :meth:`~.constraint.ConstraintConversion.predict`:
            one real variable of dimension ``argdim``, a single scalar affine
            inequality, and ``argdim`` exponential cone constraints.
            """
            from ..expressions import RealVariable
            from . import AffineConstraint, ExpConeConstraint

            length = subtype.argdim

            # The auxiliary vector u that convert() adds to the problem.
            yield ("var", RealVariable.make_var_type(dim=length, bnd=0), 1)
            # The scalar inequality on the sum of u.
            yield ("con", AffineConstraint.make_type(dim=1, eq=False), 1)
            # One exponential cone membership per entry of the argument.
            yield ("con", ExpConeConstraint.make_type(), length)

        @classmethod
        def convert(cls, con, options):
            """Build a problem that models ``con`` with exponential cones.

            Implements :meth:`~.constraint.ConstraintConversion.convert`.
            """
            from ..expressions import ExponentialCone
            from ..modeling import Problem

            num = con.numerator
            den = con.denominator
            length = con.divergence.n
            bound = con.upperBound

            problem = Problem()

            aux = problem.add_variable("__u", length)
            # Scalar product of u with ones, i.e. the sum of u's entries.
            total = aux | 1
            problem.add_constraint(total >= -bound)

            # Entry-wise cone memberships; together with the sum constraint
            # above they encode the bound on the divergence.
            for idx in range(length):
                triple = den[idx] // num[idx] // aux[idx]
                problem.add_constraint(triple << ExponentialCone())

            return problem

    def __init__(self, divergence, upperBound):
        """Create the constraint ``divergence <= upperBound``.

        :param ~picos.expressions.NegativeEntropy divergence:
            The negative or relative entropy being bounded.
        :param ~picos.expressions.AffineExpression upperBound:
            A scalar affine expression that bounds it from above.
        """
        from ..expressions import AffineExpression, NegativeEntropy

        assert isinstance(divergence, NegativeEntropy)
        assert isinstance(upperBound, AffineExpression)
        assert len(upperBound) == 1

        self.divergence = divergence
        self.upperBound = upperBound

        super().__init__(divergence._typeStr)

    @property
    def numerator(self):
        """The argument :math:`x` of the divergence."""
        return self.divergence.x

    @cached_property
    def denominator(self):
        """The :math:`y` of the divergence, or the constant :math:`1`.

        If the divergence has no :math:`y` (a plain negative entropy), a
        constant :math:`1` with the shape of :math:`x` is returned instead.
        """
        from ..expressions import AffineExpression

        y = self.divergence.y
        if y is not None:
            return y
        return AffineExpression.from_constant(1, self.divergence.x.shape)

    # Subtype of the constraint: the dimension of the divergence's argument.
    Subtype = namedtuple("Subtype", ("argdim",))

    def _subtype(self):
        return self.Subtype(argdim=len(self.numerator))

    @classmethod
    def _cost(cls, subtype):
        # The argument dimension counts twice (numerator and denominator),
        # plus one (presumably for the scalar upper bound).
        return subtype.argdim * 2 + 1

    def _expression_names(self):
        yield from ("divergence", "upperBound")

    def _str(self):
        lhs = self.divergence.string
        rhs = self.upperBound.string
        return glyphs.le(lhs, rhs)

    def _get_slack(self):
        # Nonnegative exactly when the divergence does not exceed the bound.
        bound_value = self.upperBound.safe_value
        divergence_value = self.divergence.safe_value
        return bound_value - divergence_value
# --------------------------------------
# End of this module's public API section opened by api_start(); the result,
# computed from the _API_START token and the current globals, becomes __all__
# (presumably the names defined since _API_START — confirm against ..apidoc).
__all__ = api_end(_API_START, globals())