Coverage for picos/expressions/uncertain/pert_wasserstein.py: 79.69%
64 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-15 14:21 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2020 Maximilian Stahlberg
3#
4# This file is part of PICOS.
5#
6# PICOS is free software: you can redistribute it and/or modify it under the
7# terms of the GNU General Public License as published by the Free Software
8# Foundation, either version 3 of the License, or (at your option) any later
9# version.
10#
11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License along with
16# this program. If not, see <http://www.gnu.org/licenses/>.
17# ------------------------------------------------------------------------------
19"""Implements :class:`WassersteinAmbiguitySet`."""
21from collections import namedtuple
23import numpy
25from ... import glyphs
26from ...apidoc import api_end, api_start
27from ..data import cvx2np
28from ..exp_affine import AffineExpression
29from ..samples import Samples
30from .perturbation import Perturbation, PerturbationUniverse
# NOTE(review): presumably api_start records the module namespace at this point
# so that the matching api_end call at the end of the file can tell which names
# were defined in between -- confirm against the apidoc module.
32_API_START = api_start(globals())
33# -------------------------------
class WassersteinAmbiguitySet(PerturbationUniverse):
    r"""A Wasserstein ambiguity set centered at a discrete distribution.

    :Model of uncertainty:

    As a distributional ambiguity set, an instance :math:`\mathcal{P}` of this
    class

    1. represents a safety region for a partially known (ambiguous) probability
       distribution :math:`\Xi \in \mathcal{P}` and
    2. provides a random, ambiguously distributed perturbation parameter
       :math:`\xi \sim \Xi` that can be used to define worst-case-expectation
       expressions of the form

       .. math::

           \mathop{(\max\;\textit{or}\;\min)}_{\Xi \in \mathcal{P}}
           \mathbb{E}_\Xi[f(x, \xi)]

       for a selection of functions :math:`f` and a decision variable
       :math:`x`.

    :Definition:

    Formally, this class can describe discrepancy-based ambiguity sets of the
    form

    .. math::

        \mathcal{P} = \left\{
            \Xi \in \mathcal{M} ~\middle|~
            \operatorname{W}_p(\Xi, \Xi_\text{N}) \leq \epsilon
        \right\}

    where discrepancy from the discrete nominal distribution

    .. math::

        \Xi_\text{N} = \sum_{i = 1}^N w_i \delta_{\xi_{(i)}} \in \mathcal{M}

    is measured with respect to the Wasserstein distance of order
    :math:`p \geq 1`,

    .. math::

        \operatorname{W}_p(\Xi, \Xi') =
        {\left(
            \inf_{\Phi \in \Pi(\Xi, \Xi')}
            \int_{\mathbb{R}^n \times \mathbb{R}^n}
            \lVert \phi - \phi' \rVert^p \;
            \Phi(
                \mathop{}\!\mathrm{d} \phi
                \times
                \mathop{}\!\mathrm{d} \phi')
        \right)}^{\frac{1}{p}},

    where

    1. :math:`\mathcal{M}` is the set of all Borel probability measures on
       :math:`\mathbb{R}^n` for some :math:`n \in \mathbb{Z}_{\geq 1}`,
    2. :math:`\Pi(\Xi, \Xi')` denotes the set of all couplings of :math:`\Xi`
       and :math:`\Xi'`,
    3. :math:`\xi_{(i)} \in \mathbb{R}^n` for all :math:`i \in [N]` are the
       :math:`N \in \mathbb{Z}_{\geq 1}` *samples* comprising the support of
       :math:`\Xi_\text{N}`,
    4. :math:`w_i \in \mathbb{R}_{\geq 0}` are *weights* denoting the nominal
       probability mass at :math:`\xi_{(i)}` for all :math:`i \in [N]`,
    5. :math:`\delta_{\xi_{(i)}}` denotes the Dirac delta function with unit
       mass at :math:`\xi_{(i)}` for all :math:`i \in [N]` and where
    6. :math:`\epsilon \in \mathbb{R}_{\geq 0}` controls the radius of the
       ambiguity set.

    :Supported functions:

    For :math:`p = 1`:

    1. A convex piecewise linear function
       :math:`f(x, \xi) = \max_{i=1}^k a_i(x, \xi)` where :math:`a_i` is
       biaffine in :math:`x` and :math:`\xi` for all :math:`i \in [k]`. This
       can be written as ``picos.max([a_1, ..., a_k])`` in Python.
    2. A concave piecewise linear function
       :math:`f(x, \xi) = \min_{i=1}^k a_i(x, \xi)` where :math:`a_i` is
       biaffine in :math:`x` and :math:`\xi` for all :math:`i \in [k]`. This
       can be written as ``picos.min([a_1, ..., a_k])`` in Python.

    For :math:`p = 2`:

    1. A squared norm :math:`f(x, \xi) = \lVert A(x, \xi) \rVert_F^2` where
       :math:`A` is biaffine in :math:`x` and :math:`\xi`. This can be written
       as ``abs(A)**2`` in Python.
    """

    def __init__(self, parameter_name, p, eps, samples, weights=1):
        r"""Create a :class:`WassersteinAmbiguitySet`.

        :param str parameter_name:
            Name of the random parameter :math:`\xi`.

        :param float p:
            The Wasserstein type/order parameter :math:`p`.

        :param float eps:
            The Wasserstein ball radius :math:`\epsilon`.

        :param samples:
            The support of the discrete distribution :math:`\Xi_\text{N}` given
            as the *samples* :math:`\xi_{(i)}`. The original shape of the
            samples determines the shape of :math:`\xi`.
        :type samples:
            anything recognized by :class:`~.samples.Samples`

        :param weights:
            A vector denoting the nonnegative weight (e.g. frequency or
            probability) of each sample. Its length must match the number of
            samples provided. The argument will be normalized such that its
            entries sum to one. Entries of zero will be dropped alongside their
            associated sample. The default value of ``1`` denotes the empirical
            distribution on the samples.

        :raises ValueError:
            If :math:`p < 1`, if the radius is negative or not a number, or if
            the weight vector has a negative or non-finite entry or is zero
            everywhere.

        :raises NotImplementedError:
            If :math:`p` is not one of the supported orders ``1`` and ``2``.

        .. warning::

            Duplicate samples are not detected and can impact performance. If
            duplicate samples are likely, make sure to detect them and encode
            their frequency in the weight vector.
        """
        # Load p.
        self._p = float(p)

        if self._p < 1:
            raise ValueError("The Wasserstein parameter p must be >= 1.")

        supported_p = (1, 2)
        if self._p not in supported_p:
            raise NotImplementedError("Currently, Wasserstein DRO is only "
                "supported for p in {}.".format(set(supported_p)))

        # Load epsilon.
        self._eps = float(eps)

        # Written as a negated ">=" so that NaN, which compares false against
        # every number, is rejected instead of slipping through a "< 0" test.
        if not self._eps >= 0:
            raise ValueError("The Wasserstein ball radius must be nonnegative.")

        # Load the samples.
        self._samples = Samples(samples)

        # Load the normalized weights.
        w = AffineExpression.from_constant(weights, (len(self._samples), 1))
        w_np = numpy.ravel(cvx2np(w.value_as_matrix))

        # NaN or infinite entries pass both sign checks below and would turn
        # every normalized weight into NaN, so they are refused up front.
        if not numpy.all(numpy.isfinite(w_np)):
            raise ValueError("The weight vector must be finite everywhere.")

        if numpy.any(w_np < 0):
            raise ValueError(
                "The weight vector must be nonnegative everywhere.")

        if numpy.any(w_np == 0):
            if numpy.all(w_np == 0):
                raise ValueError("The weight vector must be nonzero.")

            # Drop zero-weight entries together with their samples.
            nonzero = numpy.where(w_np != 0)[0].tolist()
            w = w[nonzero]
            self._samples = self._samples.select(nonzero)

        # Normalize: (w | 1) is the sum of the remaining weights.
        self._weights = (w / (w | 1)).renamed("w")

        assert len(self._samples) == len(self._weights)

        # Create the perturbation parameter.
        self._parameter = Perturbation(
            self, parameter_name, self._samples.original_shape)

    @property
    def p(self):
        """The Wasserstein order :math:`p`."""
        return self._p

    @property
    def eps(self):
        r"""The Wasserstein ball radius :math:`\epsilon`."""
        return self._eps

    @property
    def samples(self):
        """The registered samples as a :class:`~.samples.Samples` object."""
        return self._samples

    @property
    def weights(self):
        """The normalized sample weights as a constant PICOS vector."""
        return self._weights

    Subtype = namedtuple("Subtype", ("sample_dim", "sample_num", "p"))

    def _subtype(self):
        return self.Subtype(self._samples.dim, self._samples.num, self._p)

    def __str__(self):
        return "WAS(p={}, eps={}, N={})".format(
            self._p, self._eps, self._samples.num)

    @classmethod
    def _get_type_string_base(cls):
        return "Wasserstein Ambiguity Set"

    def __repr__(self):
        return glyphs.repr2("{} {}".format(glyphs.shape(self._parameter.shape),
            self._get_type_string_base()), self.__str__())

    @property
    def distributional(self):
        """Implement for :class:`~.perturbation.PerturbationUniverse`."""
        return True

    @property
    def parameter(self):
        r"""The random perturbation parameter :math:`\xi`."""
        return self._parameter
252# --------------------------------------
# NOTE(review): __all__ is whatever api_end returns for the recorded start
# marker and the current module namespace; presumably the names defined since
# api_start was called -- confirm against the apidoc module.
253__all__ = api_end(_API_START, globals())