Coverage for picos/expressions/expression.py: 76.18%

487 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-12 07:53 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2019 Maximilian Stahlberg 

3# Based on the original picos.expressions module by Guillaume Sagnol. 

4# 

5# This file is part of PICOS. 

6# 

7# PICOS is free software: you can redistribute it and/or modify it under the 

8# terms of the GNU General Public License as published by the Free Software 

9# Foundation, either version 3 of the License, or (at your option) any later 

10# version. 

11# 

12# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

13# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

14# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

15# 

16# You should have received a copy of the GNU General Public License along with 

17# this program. If not, see <http://www.gnu.org/licenses/>. 

18# ------------------------------------------------------------------------------ 

19 

20"""Backend for expression type implementations.""" 

21 

22import functools 

23import operator 

24import threading 

25import warnings 

26from abc import abstractmethod 

27from contextlib import contextmanager 

28 

29from .. import glyphs 

30from ..apidoc import api_end, api_start 

31from ..caching import cached_property 

32from ..constraints import ConstraintType 

33from ..containers import DetailedType 

34from ..legacy import deprecated 

35from ..valuable import NotValued, Valuable 

36from .data import convert_operands 

37 

# Hand the module namespace to apidoc.api_start. Presumably this marks where
# the documented public API of this module begins (paired with api_end, which
# is imported above) — TODO confirm against picos.apidoc.
_API_START = api_start(globals())
# -------------------------------

40 

41 

def validate_prediction(the_operator):
    """Validate that the constraint outcome matches the predicted outcome.

    Decorator for binary operator implementations. Before the decorated
    operator runs, the detailed types of both operands are used with
    :meth:`ExpressionType.predict` to predict the outcome; afterwards the
    actual outcome is checked against that prediction.

    The decorated function's ``__name__`` must name a function of the
    :mod:`operator` module (such as ``__le__``) and both operands must already
    be :class:`Expression` or :class:`~.set.Set` instances, so this decorator
    must be applied below ``convert_operands``.

    If a successful outcome was predicted but the operation raises, only a
    :class:`RuntimeWarning` is issued (a false positive of the prediction)
    before the exception is re-raised. Any other mismatch between prediction
    and outcome raises :class:`AssertionError`.
    """
    @functools.wraps(the_operator)
    def wrapper(lhs, rhs, *args, **kwargs):
        from .set import Set

        def what():
            return "({}).{}({})".format(
                lhs._symbStr, the_operator.__name__, rhs._symbStr)

        assert isinstance(lhs, (Expression, Set)) \
            and isinstance(rhs, (Expression, Set)), \
            "validate_prediction must occur below convert_operands."

        lhs_type = lhs.type
        rhs_type = rhs.type

        try:
            abstract_operator = getattr(operator, the_operator.__name__)
        except AttributeError as error:
            raise AssertionError("validate_prediction may only decorate "
                "standard operator implementations.") from error

        # Encoding of the prediction: None means that no prediction was made,
        # NotImplemented means that the operation is predicted to fail, and
        # anything else is the predicted constraint type.
        try:
            predictedType = lhs_type.predict(abstract_operator, rhs_type)
        except NotImplementedError:
            predictedType = None  # No prediction was made.
        except PredictedFailure:
            predictedType = NotImplemented  # Prediction is "not possible".

        predicted_success = predictedType is not None \
            and predictedType is not NotImplemented

        try:
            outcome = the_operator(lhs, rhs, *args, **kwargs)
        except Exception as error:
            # Case where the prediction is positive and the outcome is negative.
            if predicted_success:
                warnings.warn(
                    "Outcome for {} was predicted {} but the operation raised "
                    "an error: \"{}\". This is a noncritical error (false "
                    "positive) in PICOS' constraint outcome prediction."
                    .format(what(), predictedType, error),
                    category=RuntimeWarning, stacklevel=3)
            raise

        # Case where the prediction is negative and the outcome is positive.
        if predictedType is NotImplemented and outcome is not NotImplemented:
            raise AssertionError(
                "The operation {} was predicted to fail but it produced "
                "an output of {}.".format(what(), outcome.type))

        # Case where no prediction was made. NOTE: This must be an identity
        # test; predictedType may be NotImplemented here, and evaluating that
        # in a boolean context is deprecated since Python 3.9 and a TypeError
        # as of Python 3.14.
        if predictedType is None:
            return outcome

        # Case where the outcome is try-to-reverse-the-operation. This also
        # covers an operation that was predicted to fail and did so.
        if outcome is NotImplemented:
            return outcome

        # Case where the prediction and the outcome are positive but differ.
        outcomeType = outcome.type
        if not predictedType.equals(outcomeType):
            raise AssertionError("Outcome for {} was predicted {} but is {}."
                .format(what(), predictedType, outcomeType))

        return outcome
    return wrapper

109 

110 

def refine_operands(stop_at_affine=False):
    """Make a binary operator work on :attr:`~Expression.refined` operands.

    If the left hand side operand (``self``) refines to an instance of a
    different type, the decorated method is not called at all. Instead, the
    method of the same name on the refined instance is called, receiving the
    original right hand side operand (which that method may refine in turn).
    Otherwise the decorated method is called with both operands refined.

    Constraint creating binary operator methods are meant to use this
    decorator, so that degenerate expressions (e.g. a complex affine
    expression whose imaginary part is zero) may exist and be computed with
    cheaply, but are never used in constraints. Note that
    :attr:`Expression.type` also refers to the refined version of an
    expression.

    The decorator must be applied below ``convert_operands``.

    :param bool stop_at_affine: Do not refine any affine expressions, in
        particular do not refine complex affine expressions to real ones.
    """
    def decorator(the_operator):
        @functools.wraps(the_operator)
        def wrapper(lhs, rhs, *args, **kwargs):
            from .exp_affine import ComplexAffineExpression
            from .set import Set

            assert isinstance(lhs, (Expression, Set)) \
                and isinstance(rhs, (Expression, Set)), \
                "refine_operands must occur below convert_operands."

            def refine(operand):
                # Optionally leave (complex) affine operands untouched.
                if stop_at_affine \
                        and isinstance(operand, ComplexAffineExpression):
                    return operand
                return operand.refined

            fine_lhs = refine(lhs)

            if type(fine_lhs) is type(lhs):
                return the_operator(fine_lhs, refine(rhs), *args, **kwargs)

            # 'self' changed its type: defer to the refined type's operator.
            name = the_operator.__name__

            assert hasattr(fine_lhs, name), \
                "refine_operand transformed 'self' to another type that " \
                "does not define an operator with the same name as the " \
                "decorated one."

            delegate = getattr(fine_lhs, name)
            return delegate(rhs, *args, **kwargs)
        return wrapper
    return decorator

162 

163 

# TODO: Once PICOS requires Python >= 3.7, use a ContextVar instead.
class _Refinement(threading.local):
    """Switch consulted by :attr:`Expression.refined`."""

    # Class-level default; assignments on the instance are per-thread since
    # this class derives from threading.local.
    allowed = True


_REFINEMENT = _Refinement()


@contextmanager
def no_refinement():
    """Context manager that disables the effect of :attr:`Expression.refined`.

    This can be necessary to ensure that the outcome of a constraint conversion
    is as predicted, in particular when PICOS uses overridden comparison
    operators for constraint creation internally.

    The switch is per-thread. On exit the previous state is restored, so the
    context manager may be nested.
    """
    previous = _REFINEMENT.allowed
    _REFINEMENT.allowed = False

    try:
        yield
    finally:
        # Restore rather than unconditionally re-enable, so that leaving a
        # nested context does not switch refinement back on while an outer
        # no_refinement context is still active.
        _REFINEMENT.allowed = previous

186 

187 

class PredictedFailure(TypeError):
    """Signals that comparing two expressions will not form a constraint.

    Raised by :meth:`ExpressionType.predict` when neither operand's type can
    predict a constraint for the comparison.
    """

192 

193 

class ExpressionType(DetailedType):
    """The detailed type of an expression for predicting constraint outcomes.

    This is sufficient to predict the detailed type of any constraint that can
    be created by comparing with another expression.
    """

    @staticmethod
    def _relation_str(relation):
        """Return the infix symbol of a relation, or ``"??"`` if unknown."""
        if relation is operator.__eq__:
            return "=="
        elif relation is operator.__le__:
            return "<="
        elif relation is operator.__ge__:
            return ">="
        elif relation is operator.__lshift__:
            return "<<"
        elif relation is operator.__rshift__:
            return ">>"
        else:
            return "??"

    @staticmethod
    def _swap_relation(relation):
        """Return the relation to use with swapped operands, or :obj:`None`.

        That is, ``a R b`` is rewritten as ``b S a`` where ``S`` is returned.
        :obj:`None` is returned for relations without a known counterpart.
        """
        if relation is operator.__eq__:
            return operator.__eq__
        elif relation is operator.__le__:
            return operator.__ge__
        elif relation is operator.__ge__:
            return operator.__le__
        elif relation is operator.__lshift__:
            return operator.__rshift__
        elif relation is operator.__rshift__:
            return operator.__lshift__
        else:
            return None

    def predict(self, relation, other):
        """Predict the constraint outcome of comparing expressions.

        The class part of this type is asked first. If it cannot make a
        prediction, the class part of ``other`` is asked with the operands and
        the relation swapped.

        :param relation:
            An object from the :mod:`operator` namespace representing the
            operation being predicted.

        :param other:
            Another expression type representing the right hand side operand.
        :type other:
            ~picos.expressions.expression.ExpressionType

        :returns:
            The predicted :class:`ConstraintType`.

        :raises TypeError:
            If ``other`` is not an :class:`ExpressionType`.

        :raises PredictedFailure:
            If the comparison is predicted not to form a constraint.

        :Example:

        >>> import operator, picos
        >>> a = picos.RealVariable("x") + 1
        >>> b = picos.RealVariable("y") + 2
        >>> (a <= b).type == a.type.predict(operator.__le__, b.type)
        True
        """
        if not isinstance(other, ExpressionType):
            raise TypeError("The 'other' argument must be another {} instance."
                .format(self.__class__.__name__))

        # Perform the forward prediction.
        result = self.clstype._predict(self.subtype, relation, other)

        # Fall back to the backward prediction. Only do so for relations with
        # a known swapped counterpart: _predict expects a function from the
        # operator module, not None.
        if result is NotImplemented:
            reverse = self._swap_relation(relation)
            if reverse is not None:
                result = other.clstype._predict(other.subtype, reverse, self)

        # If both fail, the prediction is "not possible".
        if result is NotImplemented:
            raise PredictedFailure(
                "The statement {} {} {} is predicted to error."
                .format(self, self._relation_str(relation), other))

        assert isinstance(result, ConstraintType)
        return result

271 

272 

273class Expression(Valuable): 

274 """Abstract base class for mathematical expressions, including mutables. 

275 

276 For mutables, this is the secondary base class, with 

277 :class:`~.mutable.Mutable` or a subclass thereof being the primary one. 

278 """ 

279 

280 def __init__(self, typeStr, symbStr): 

281 """Perform basic initialization for :class:`Expression` instances. 

282 

283 :param str typeStr: Short string denoting the expression type. 

284 :param str symbStr: Algebraic string description of the expression. 

285 """ 

286 self._typeStr = typeStr 

287 """A string describing the expression type.""" 

288 

289 self._symbStr = symbStr 

290 """A symbolic string representation of the expression. It is always used 

291 by __descr__, and it is equivalent to the value returned by __str__ when 

292 the expression is not fully valued.""" 

293 

294 @property 

295 def string(self): 

296 """Symbolic string representation of the expression. 

297 

298 Use this over Python's :class:`str` if you want to output the symbolic 

299 representation even when the expression is valued. 

300 """ 

301 return self._symbStr 

302 

303 # -------------------------------------------------------------------------- 

304 # Abstract method implementations for the Valuable base class. 

305 # NOTE: _get_value and possibly _set_value are implemented by subclasses. 

306 # -------------------------------------------------------------------------- 

307 

308 def _get_valuable_string(self): 

309 return "expression {}".format(self.string) 

310 

311 # -------------------------------------------------------------------------- 

312 # Abstract and default-implementation methods. 

313 # -------------------------------------------------------------------------- 

314 

315 def _get_refined(self): 

316 """See :attr:`refined`.""" 

317 return self 

318 

319 def _get_clstype(self): 

320 """Return the Python class part of the expression's detailed type.""" 

321 return self.__class__ 

322 

323 @property 

324 @abstractmethod 

325 def Subtype(self): 

326 """The class of which :attr:`subtype` returns an instance. 

327 

328 Instances must be hashable. By convention a 

329 :func:`namedtuple <collections.namedtuple>` class. 

330 

331 .. warning:: 

332 This should be declared in the class body as e.g. 

333 `Subtype = namedtuple(…)` and not as a property so that it's static. 

334 """ 

335 pass 

336 

337 @abstractmethod 

338 def _get_subtype(self): 

339 """See :attr:`subtype`.""" 

340 pass 

341 

342 @classmethod 

343 @abstractmethod 

344 def _predict(cls, subtype, relation, other): 

345 """Predict the constraint outcome of a comparison. 

346 

347 :param object subtype: An object returned by the :meth:`_get_subtype` 

348 instance method of :class:`cls`. 

349 :param method-wrapper relation: A function from the :mod:`operator` 

350 namespace, such as :func:`operator.__le__`. See 

351 :class:`ExpressionType` for what operators are defined. 

352 :param ExpressionType other: The detailed type of another expression. 

353 :returns: Either the :obj:`NotImplemented` token or a 

354 :class:`ConstraintType` object such that an instance of :class:`cls` 

355 with the given subtype, when compared with another expression with 

356 the given expression type, returns a constraint with that constraint 

357 type. 

358 """ 

359 pass 

360 

361 def _get_shape(self): 

362 """Return the algebraic shape of the expression.""" 

363 return (1, 1) 

364 

365 @abstractmethod 

366 def _get_mutables(self): 

367 """Return the set of mutables that are involved in the expression.""" 

368 pass 

369 

370 @abstractmethod 

371 def _is_convex(self): 

372 """Whether the expression is convex in its :attr:`variables`. 

373 

374 Method implementations may assume that the expression is refined. Thus, 

375 degenerate cases affected by refinement do not need to be considered. 

376 

377 For uncertain expressions, this assumes the perturbation as constant. 

378 """ 

379 pass 

380 

381 @abstractmethod 

382 def _is_concave(self): 

383 """Whether the expression is concave in its :attr:`variables`. 

384 

385 Method implementations may assume that the expression is refined. Thus, 

386 degenerate cases affected by refinement do not need to be considered. 

387 

388 For uncertain expressions, this assumes the perturbation as constant. 

389 """ 

390 pass 

391 

392 @abstractmethod 

393 def _replace_mutables(self, mapping): 

394 """Return a copy of the expression concerning different mutables. 

395 

396 This is the fast internal-use counterpart to :meth:`replace_mutables`. 

397 

398 The returned expression should be of the same type as ``self`` (no 

399 refinement) so that it can be substituted in composite expressions. 

400 

401 :param dict mapping: 

402 A mutable replacement map. The caller must ensure the following 

403 properties: 

404 

405 1. This must be a complete map from existing mutables to the same 

406 mutable, another mutable, or a real-valued affine expression 

407 (completeness). 

408 2. The shape and vectorization format of each replacement must match 

409 the existing mutable. Replacing with affine expressions is only 

410 allowed when the existing mutable uses the trivial 

411 :class:`~vectorizations.FullVectorization` (soudness). 

412 3. Mutables that appear in a replacement may be the same as the 

413 mutable being replaced but may otherwise not appear in the 

414 expression (freshness). 

415 4. Mutables may appear at most once anywhere in the image of the map 

416 (uniqueness). 

417 

418 If any property is not fulfilled, the implementation does not need 

419 to raise a proper exception but may fail arbitrarily. 

420 """ 

421 pass 

422 

423 @abstractmethod 

424 def _freeze_mutables(self, subset): 

425 """Return a copy with some mutables frozen to their current value. 

426 

427 This is the fast internal-use counterpart to :meth:`frozen`. 

428 

429 The returned expression should be of the same type as ``self`` (no 

430 refinement) so that it can be substituted in composite expressions. 

431 

432 :param dict subset: 

433 An iterable of valued :class:`mutables <.mutable.Mutable>` that 

434 should be frozen. May include mutables that are not present in the 

435 expression, but may not include mutables without a value. 

436 """ 

437 pass 

438 

439 # -------------------------------------------------------------------------- 

440 # An interface to the abstract and default-implementation methods above. 

441 # -------------------------------------------------------------------------- 

442 

443 @property 

444 def refined(self): 

445 """A refined version of the expression. 

446 

447 The refined expression can be an instance of a different 

448 :class:`Expression` subclass than the original expression, if that type 

449 is better suited for the mathematical object in question. 

450 

451 The refined expression is automatically used instead of the original one 

452 whenever a constraint is created, and in some other places. 

453 

454 The idea behind refined expressions is that operations that produce new 

455 expressions can be executed quickly without checking for exceptionnel 

456 cases. For instance, the sum of two 

457 :class:`~.exp_affine.ComplexAffineExpression` instances could have the 

458 complex part eliminated so that storing the result as an 

459 :class:`~.exp_affine.AffineExpression` would be prefered, but checking 

460 for this case on every addition would be too slow. Refinement is used 

461 sparingly to detect such cases at times where it makes the most sense. 

462 

463 Refinement may be disallowed within a context with the 

464 :func:`no_refinement` context manager. In this case, this property 

465 returns the expression as is. 

466 """ 

467 if not _REFINEMENT.allowed: 

468 return self 

469 

470 fine = self._get_refined() 

471 

472 if fine is not self: 

473 # Recursively refine until the expression doesn't change further. 

474 return fine.refined 

475 else: 

476 return fine 

477 

478 @property 

479 def subtype(self): 

480 """The subtype part of the expression's detailed type. 

481 

482 Returns a hashable object that, together with the Python class part of 

483 the expression's type, is sufficient to predict the constraint outcome 

484 (constraint class and subtype) of any comparison operation with any 

485 other expression. 

486 

487 By convention the object returned is a 

488 :func:`namedtuple <collections.namedtuple>` instance. 

489 """ 

490 return self._get_subtype() 

491 

492 @property 

493 def type(self): 

494 """The expression's detailed type for constraint prediction. 

495 

496 The returned value is suffcient to predict the detailed type of any 

497 constraint that can be created by comparing with another expression. 

498 

499 Since constraints are created from 

500 :attr:`~.expression.Expression.refined` expressions only, the Python 

501 class part of the detailed type may differ from the type of the 

502 expression whose :attr:`type` is queried. 

503 """ 

504 refined = self.refined 

505 return ExpressionType(refined._get_clstype(), refined._get_subtype()) 

506 

507 @classmethod 

508 def make_type(cls, *args, **kwargs): 

509 """Create a detailed expression type from subtype parameters.""" 

510 return ExpressionType(cls, cls.Subtype(*args, **kwargs)) 

511 

512 shape = property( 

513 lambda self: self._get_shape(), 

514 doc=_get_shape.__doc__) 

515 

516 size = property( 

517 lambda self: self._get_shape(), 

518 doc="""The same as :attr:`shape`.""") 

519 

520 @property 

521 def scalar(self): 

522 """Whether the expression is scalar.""" 

523 return self._get_shape() == (1, 1) 

524 

525 @property 

526 def square(self): 

527 """Whether the expression is a square matrix.""" 

528 shape = self._get_shape() 

529 return shape[0] == shape[1] 

530 

531 mutables = property( 

532 lambda self: self._get_mutables(), 

533 doc=_get_mutables.__doc__) 

534 

535 @property 

536 def constant(self): 

537 """Whether the expression involves no mutables.""" 

538 return not self._get_mutables() 

539 

540 @cached_property 

541 def variables(self): 

542 """The set of decision variables that are involved in the expression.""" 

543 from .variables import BaseVariable 

544 

545 return frozenset(mutable for mutable in self._get_mutables() 

546 if isinstance(mutable, BaseVariable)) 

547 

548 @cached_property 

549 def parameters(self): 

550 """The set of parameters that are involved in the expression.""" 

551 from .variables import BaseVariable 

552 

553 return frozenset(mutable for mutable in self._get_mutables() 

554 if not isinstance(mutable, BaseVariable)) 

555 

556 @property 

557 def convex(self): 

558 """Whether the expression is convex.""" 

559 return self.refined._is_convex() 

560 

561 @property 

562 def concave(self): 

563 """Whether the expression is concave.""" 

564 return self.refined._is_concave() 

565 

566 def replace_mutables(self, replacement): 

567 """Return a copy of the expression concerning different mutables. 

568 

569 New mutables must have the same shape and vectorization format as the 

570 mutables that they replace. This means in particular that 

571 :class:`~.variables.RealVariable`, :class:`~.variables.IntegerVariable` 

572 and :class:`~.variables.BinaryVariable` of same shape are 

573 interchangeable. 

574 

575 If the mutables to be replaced do not appear in the expression, then 

576 the expression is not copied but returned as is. 

577 

578 :param replacement: 

579 Either a map from mutables or mutable names to new mutables or an 

580 iterable of new mutables to replace existing mutables of same name 

581 with. See the section on advanced usage for additional options. 

582 :type replacement: 

583 tuple or list or dict 

584 

585 :returns Expression: 

586 The new expression, refined to a more suitable type if possible. 

587 

588 :Advanced replacement: 

589 

590 It is also possible to replace mutables with real affine expressions 

591 concerning pairwise disjoint sets of fresh mutables. This works only on 

592 real-valued mutables that have a trivial internal vectorization format 

593 (i.e. :class:`~.vectorizations.FullVectorization`). The shape of the 

594 replacing expression must match the variable's. Additional limitations 

595 depending on the type of expression that the replacement is invoked on 

596 are possible. The ``replacement`` argument must be a dictionary. 

597 

598 :Example: 

599 

600 >>> import picos 

601 >>> x = picos.RealVariable("x"); x.value = 1 

602 >>> y = picos.RealVariable("y"); y.value = 10 

603 >>> z = picos.RealVariable("z"); z.value = 100 

604 >>> c = picos.Constant("c", 1000) 

605 >>> a = x + 2*y; a 

606 <1×1 Real Linear Expression: x + 2·y> 

607 >>> a.value 

608 21.0 

609 >>> b = a.replace_mutables({y: z}); b # Replace y with z. 

610 <1×1 Real Linear Expression: x + 2·z> 

611 >>> b.value 

612 201.0 

613 >>> d = a.replace_mutables({x: 2*x + z, y: c}); d # Advanced use. 

614 <1×1 Real Affine Expression: 2·x + z + 2·c> 

615 >>> d.value 

616 2102.0 

617 """ 

618 from .exp_biaffine import BiaffineExpression 

619 from .mutable import Mutable 

620 from .vectorizations import FullVectorization 

621 

622 # Change an iterable of mutables to a map from names to mutables. 

623 if not isinstance(replacement, dict): 

624 if not all(isinstance(new, Mutable) for new in replacement): 

625 raise TypeError("If 'replacement' is a non-dictionary iterable," 

626 " then it may only contain mutables.") 

627 

628 new_replacement = {new.name: new for new in replacement} 

629 

630 if len(new_replacement) != len(replacement): 

631 raise TypeError("If 'replacement' is a non-dictionary iterable," 

632 " then the mutables within must have unique names.") 

633 

634 replacement = new_replacement 

635 

636 # Change a map from names to a map from existing mutables. 

637 # Names that reference non-existing mutables are dropped. 

638 old_mtbs_by_name = {mtb.name: mtb for mtb in self.mutables} 

639 replacing_by_name = False 

640 new_replacement = {} 

641 for old, new in replacement.items(): 

642 if isinstance(old, Mutable): 

643 new_replacement[old] = new 

644 elif not isinstance(old, str): 

645 raise TypeError( 

646 "Keys of 'replacement' must be mutables or names thereof.") 

647 else: 

648 replacing_by_name = True 

649 if old in old_mtbs_by_name: 

650 new_replacement[old_mtbs_by_name[old]] = new 

651 replacement = new_replacement 

652 

653 # Check unique naming of existing mutables if it matters. 

654 if replacing_by_name and len(old_mtbs_by_name) != len(self.mutables): 

655 raise RuntimeError("Cannot replace mutables by name in {} as " 

656 "its mutables are not uniquely named.".format(self.string)) 

657 

658 # Remove non-existing sources and identities. 

659 assert all(isinstance(old, Mutable) for old in replacement) 

660 replacement = {old: new for old, new in replacement.items() 

661 if old is not new and old in self.mutables} 

662 

663 # Do nothing if there is nothing to replace. 

664 if not replacement: 

665 return self 

666 

667 # Validate individual replacement requirements. 

668 for old, new in replacement.items(): 

669 # Replacement must be a mutable or biaffine expression. 

670 if not isinstance(new, BiaffineExpression): 

671 raise TypeError("Can only replace mutables with other mutables " 

672 "or affine expressions thereof.") 

673 

674 # Shapes must match. 

675 if old.shape != new.shape: 

676 raise TypeError( 

677 "Cannot replace {} with {} in {}: Differing shape." 

678 .format(old.name, new.name, self.string)) 

679 

680 # Special requirements when replacing with mutables or expressions. 

681 if isinstance(new, Mutable): 

682 # Vectorization formats must match. 

683 if type(old._vec) != type(new._vec): # noqa: E721 

684 raise TypeError("Cannot replace {} with {} in {}: " 

685 "Differing vectorization." 

686 .format(old.name, new.name, self.string)) 

687 else: 

688 # Replaced mutable must use a trivial vectorization. 

689 if not isinstance(old._vec, FullVectorization): 

690 raise TypeError("Can only replace mutables using a trivial " 

691 "vectorization format with affine expressions.") 

692 

693 # Replacing expression must be real-valued and affine. 

694 if new._bilinear_coefs or new.complex: 

695 raise TypeError("Can only replace mutables with real-valued" 

696 " affine expressions.") 

697 

698 old_mtbs_set = set(replacement) 

699 new_mtbs_lst = [mtb # Excludes each mutable being replaced. 

700 for old, new in replacement.items() 

701 for mtb in new.mutables.difference((old,))] 

702 new_mtbs_set = set(new_mtbs_lst) 

703 

704 # New mutables must be fresh. 

705 # It is OK to replace a mutable with itself or an affine expression of 

706 # itself and other fresh mutables, though. 

707 if old_mtbs_set.intersection(new_mtbs_set): 

708 raise ValueError("Can only replace mutables with fresh mutables " 

709 "or affine expressions of all fresh mutables (the old mutable " 

710 "may appear in the expression).") 

711 

712 # New mutables must be unique. 

713 if len(new_mtbs_lst) != len(new_mtbs_set): 

714 raise ValueError("Can only replace multiple mutables at once if " 

715 "the replacing mutables (and/or the mutables in replacing " 

716 "expressions) are all unique.") 

717 

718 # Turn the replacement map into a complete map. 

719 mapping = {mtb: mtb for mtb in self.mutables} 

720 mapping.update(replacement) 

721 

722 # Replace recursively and refine the result. 

723 return self._replace_mutables(mapping).refined 

724 

725 def frozen(self, subset=None): 

726 """The expression with valued mutables frozen to their current value. 

727 

728 If all mutables of the expression are valued (and in the subset unless 

729 ``subset=None``), this is the same as the inversion operation ``~``. 

730 

731 If the mutables to be frozen do not appear in the expression, then the 

732 expression is not copied but returned as is. 

733 

734 :param subset: 

735 An iterable of valued :class:`mutables <.mutable.Mutable>` or names 

736 thereof that should be frozen. If :obj:`None`, then all valued 

737 mutables are frozen to their current value. May include mutables 

738 that are not present in the expression, but may not include mutables 

739 without a value. 

740 

741 :returns Expression: 

742 The frozen expression, refined to a more suitable type if possible. 

743 

744 :Example: 

745 

746 >>> from picos import RealVariable 

747 >>> x, y = RealVariable("x"), RealVariable("y") 

748 >>> f = x + y; f 

749 <1×1 Real Linear Expression: x + y> 

750 >>> sorted(f.mutables, key=lambda mtb: mtb.name) 

751 [<1×1 Real Variable: x>, <1×1 Real Variable: y>] 

752 >>> x.value = 5 

753 >>> g = f.frozen(); g # g is f with x frozen at its current value of 5. 

754 <1×1 Real Affine Expression: [x] + y> 

755 >>> sorted(g.mutables, key=lambda mtb: mtb.name) 

756 [<1×1 Real Variable: y>] 

757 >>> x.value, y.value = 10, 10 

758 >>> f.value # x takes its new value in f. 

759 20.0 

760 >>> g.value # x remains frozen at [x] = 5 in g. 

761 15.0 

762 >>> # If an expression is frozen to a constant, this is reversable: 

763 >>> f.frozen().equals(~f) and ~f.frozen() is f 

764 True 

765 """ 

766 from .mutable import Mutable 

767 

768 # Collect mutables to be frozen in the expression. 

769 if subset is None: 

770 freeze = set(mtb for mtb in self.mutables if mtb.valued) 

771 else: 

772 if not all(isinstance(mtb, (str, Mutable)) for mtb in subset): 

773 raise TypeError("Some element of the subset of mutables to " 

774 "freeze is neither a mutable nor a string.") 

775 

776 subset_mtbs = set(m for m in subset if isinstance(m, Mutable)) 

777 subset_name = set(n for n in subset if isinstance(n, str)) 

778 

779 freeze = set() 

780 if subset_mtbs: 

781 freeze.update(m for m in subset_mtbs if m in self.mutables) 

782 if subset_name: 

783 freeze.update(m for m in self.mutables if m.name in subset_name) 

784 

785 if not all(mtb.valued for mtb in freeze): 

786 raise NotValued( 

787 "Not all mutables in the selected subset are valued.") 

788 

789 if not freeze: 

790 return self 

791 

792 if freeze == self.mutables: 

793 return ~self # Allow ~self.frozen() to return self. 

794 

795 return self._freeze_mutables(freeze).refined 

796 

797 @property 

798 def certain(self): 

799 """Always :obj:`True` for certain expression types. 

800 

801 This can be :obj:`False` for Expression types that inherit from 

802 :class:`~.uexpression.UncertainExpression` (with priority). 

803 """ 

804 return True 

805 

806 @property 

807 def uncertain(self): 

808 """Always :obj:`False` for certain expression types. 

809 

810 This can be :obj:`True` for Expression types that inherit from 

811 :class:`~.uexpression.UncertainExpression` (with priority). 

812 """ 

813 return False 

814 

815 # -------------------------------------------------------------------------- 

816 # Python special method implementations. 

817 # -------------------------------------------------------------------------- 

818 

819 def __len__(self): 

820 """Report the number of entries of the (multidimensional) expression.""" 

821 return self.shape[0] * self.shape[1] 

822 

823 def __le__(self, other): 

824 """Return a constraint that the expression is upper-bounded.""" 

825 # Try to refine self and see if the operation is then supported. 

826 # This allows e.g. a <= 0 if a is a real-valued complex expression. 

827 refined = self.refined 

828 if type(refined) is not type(self): 

829 return refined.__le__(other) 

830 

831 return NotImplemented 

832 

def __ge__(self, other):
    """Return a constraint that the expression is lower-bounded.

    The base class cannot build such a constraint itself. If refining the
    expression yields an object of a different type, the comparison is
    delegated to that refined object. This allows e.g. ``a >= 0`` when ``a``
    is a real-valued complex expression. Otherwise :obj:`NotImplemented` is
    returned so that Python may try the reflected comparison.
    """
    candidate = self.refined
    if type(candidate) is type(self):
        return NotImplemented
    return candidate.__ge__(other)

842 

def __invert__(self):
    """Convert between a valued expression and its value.

    - A frozen constant (one carrying an ``_origin`` attribute) converts
      back to the expression it was created from.
    - A constant expression is returned unchanged.
    - Any other expression is turned into a constant affine expression that
      holds its current value. That constant remembers this expression as
      its ``_origin``, so inverting it again returns the original.
    """
    try:
        return self._origin
    except AttributeError:
        pass

    if self.constant:
        return self

    from .exp_affine import Constant

    frozen = Constant(
        glyphs.frozen(self.string), self.safe_value_as_matrix, self.shape)
    frozen._origin = self
    return frozen

860 

def __contains__(self, mutable):
    """Tell whether the given mutable is among the expression's mutables."""
    concerned = self.mutables
    return mutable in concerned

864 

def __eq__(self, exp):
    """Refuse to build an equality constraint for a general expression.

    Equality constraints are only supported between affine expressions. For
    this base class the operation always raises :class:`NotImplementedError`.
    """
    message = (
        "PICOS supports equality comparison only between affine expressions, "
        "as otherwise the problem would become non-convex. "
        "Choose either <= or >= if possible."
    )
    raise NotImplementedError(message)

870 

def __repr__(self):
    """Return a bracketed description of the expression.

    The description combines the expression's mathematical type string with
    its symbolic string, using :func:`glyphs.repr2` for the layout.
    """
    description = glyphs.repr2(self._typeStr, self._symbStr)
    return str(description)

878 

def __str__(self):
    """Return a value-dependent string description of the expression.

    If the expression has a value, then the stripped string form of that
    value is returned. Otherwise, the symbolic description is returned.
    """
    value = self.value
    return str(self._symbStr) if value is None else str(value).strip()

892 

def __format__(self, format_spec):
    """Format the expression's value, or its symbolic string if unvalued.

    :param format_spec: The format specification, which is forwarded
        unchanged to the chosen object's ``__format__``.
    """
    value = self.value
    target = self._symbStr if value is None else value
    return target.__format__(format_spec)

905 

# Defining __eq__ in a class implicitly sets that class's __hash__ to None,
# so the identity hash is not inherited. Restore object's identity-based
# hash explicitly so that expressions stay hashable (e.g. usable in sets).
__hash__ = object.__hash__

908 

909 # -------------------------------------------------------------------------- 

910 # Fallback algebraic operations: Try again with converted RHS, refined LHS. 

911 # NOTE: The forward operations call the backward operations manually 

912 # (instead of returning NotImplemented) so that they can be performed 

913 # on a converted operand, which is always a PICOS type. The backward 

914 # operations then use WeightedSum as a last fallback where applicable. 

915 # -------------------------------------------------------------------------- 

916 

def _wsum_fallback(self, summands, weights, opstring):
    """Represent an operation's result as a weighted sum of expressions.

    :param summands: The expressions to be summed.
    :param weights: The weight(s) applied to the summands.
    :param opstring: A description of the attempted operation. When it is
        set, :class:`~.exp_wsum.WeightedSum` acts as the final fallback and
        raises a proper exception itself if the result can't be represented.
        That logic lives in WeightedSum rather than here so that operations
        on existing WeightedSum instances, which cannot fall back to
        Expression, produce the same exceptions.
    """
    from .exp_wsum import WeightedSum as _WeightedSum

    return _WeightedSum(summands, weights, opstring)

927 

def _scalar_mult_fallback(self, lhs, rhs):
    """Express a scalar-by-scalar product as a weighted sum, if possible.

    If the left factor is a constant scalar, it becomes the weight of the
    right factor. Failing that, if the right factor is a constant scalar, it
    becomes the weight of the left factor. Otherwise :obj:`NotImplemented` is
    returned. Constant scalars are also affine expressions, so Python's
    default TypeError (the types are fully operation-incompatible) is the
    right outcome then.
    """
    assert isinstance(lhs, Expression) and isinstance(rhs, Expression)

    opstring = "a product between {} and {}".format(repr(lhs), repr(rhs))

    for factor, remainder in ((lhs, rhs), (rhs, lhs)):
        if factor.scalar and factor.constant:
            return self._wsum_fallback(
                (remainder,), factor.safe_value, opstring)

    return NotImplemented

943 

@convert_operands(sameShape=True)
def __add__(self, other):
    """Denote addition with another expression on the right-hand side.

    If refining the expression changes its type, the refined expression
    performs the addition. Otherwise the (converted) right operand's
    :meth:`__radd__` is invoked directly.
    """
    refined = self.refined
    if type(refined) is type(self):
        return other.__radd__(self)
    return refined.__add__(other)

951 

@convert_operands(sameShape=True)
def __radd__(self, other):
    """Denote addition with another expression on the left-hand side.

    Delegates to the refined expression when refinement changes the type.
    Otherwise ``other + self`` is represented as a weighted sum.
    """
    refined = self.refined
    if type(refined) is not type(self):
        return refined.__radd__(other)

    description = "{} plus {}".format(repr(other), repr(self))
    return self._wsum_fallback((other, self), (1, 1), description)

960 

@convert_operands(sameShape=True)
def __sub__(self, other):
    """Denote subtraction of another expression from the expression.

    If refining the expression changes its type, the refined expression
    performs the subtraction. Otherwise the (converted) right operand's
    :meth:`__rsub__` is invoked directly.
    """
    refined = self.refined
    if type(refined) is type(self):
        return other.__rsub__(self)
    return refined.__sub__(other)

968 

@convert_operands(sameShape=True)
def __rsub__(self, other):
    """Denote subtraction of the expression from another expression.

    Delegates to the refined expression when refinement changes the type.
    Otherwise ``other - self`` is represented as a weighted sum with weights
    ``1`` and ``-1``.
    """
    refined = self.refined
    if type(refined) is not type(self):
        return refined.__rsub__(other)

    description = "{} minus {}".format(repr(other), repr(self))
    return self._wsum_fallback((other, self), (1, -1), description)

977 

@convert_operands(sameShape=True)
def __or__(self, other):
    r"""Denote the scalar product with another expression on the right.

    For (complex) vectors :math:`a` and :math:`b` this is the dot product

    .. math::
        (a \mid b)
        &= \langle a \mid b \rangle \\
        &= \langle a, b \rangle \\
        &= a \cdot b \\
        &= a^H b.

    For (complex) matrices :math:`A` and :math:`B` this is the Frobenius
    inner product

    .. math::
        (A \mid B)
        &= \langle A, B \rangle_F \\
        &= A : B \\
        &= \operatorname{tr}(A^H B) \\
        &= \operatorname{vec}(\overline{A})^T \operatorname{vec}(B)

    If refining the expression changes its type, the refined expression
    computes the product. Otherwise the right operand's :meth:`__ror__` is
    invoked.

    .. note::
        Write ``(A|B)`` instead of ``A|B`` for the scalar product of ``A``
        and ``B`` to obtain correct operator binding within a larger
        expression context.
    """
    refined = self.refined
    if type(refined) is type(self):
        return other.__ror__(self)
    return refined.__or__(other)

1010 

@convert_operands(sameShape=True)
def __ror__(self, other):
    """Denote the scalar product with another expression on the left.

    See :meth:`__or__` for details on this operation. When refinement does
    not change the expression's type, the product falls back to the
    scalar-by-scalar weighted-sum representation.
    """
    refined = self.refined
    if type(refined) is not type(self):
        return refined.__ror__(other)

    # NOTE(review): __or__ documents (a|b) = a^H b, but the scalar fallback
    # below multiplies without conjugating the left operand; this is only
    # exact for real operands — confirm complex scalars never reach here.
    return self._scalar_mult_fallback(other, self)

1021 

@convert_operands(rMatMul=True)
def __mul__(self, other):
    """Denote multiplication with another expression on the right.

    If refining the expression changes its type, the refined expression
    performs the multiplication. Otherwise the (converted) right operand's
    :meth:`__rmul__` is invoked directly.
    """
    refined = self.refined
    if type(refined) is type(self):
        return other.__rmul__(self)
    return refined.__mul__(other)

1029 

@convert_operands(lMatMul=True)
def __rmul__(self, other):
    """Denote multiplication with another expression on the left.

    Delegates to the refined expression when refinement changes the type.
    Otherwise the scalar-by-scalar weighted-sum fallback is attempted.
    """
    refined = self.refined
    if type(refined) is not type(self):
        return refined.__rmul__(other)
    return self._scalar_mult_fallback(other, self)

1037 

@convert_operands(sameShape=True)
def __xor__(self, other):
    """Denote the entrywise product with another expression on the right.

    If refining the expression changes its type, the refined expression
    computes the product. Otherwise the (converted) right operand's
    :meth:`__rxor__` is invoked directly.
    """
    refined = self.refined
    if type(refined) is type(self):
        return other.__rxor__(self)
    return refined.__xor__(other)

1045 

@convert_operands(sameShape=True)
def __rxor__(self, other):
    """Denote the entrywise product with another expression on the left.

    Delegates to the refined expression when refinement changes the type.
    Otherwise the scalar-by-scalar weighted-sum fallback is attempted.
    """
    refined = self.refined
    if type(refined) is not type(self):
        return refined.__rxor__(other)
    return self._scalar_mult_fallback(other, self)

1053 

@convert_operands()
def __matmul__(self, other):
    """Denote the Kronecker product with another expression on the right.

    If refining the expression changes its type, the refined expression
    computes the product. Otherwise the (converted) right operand's
    :meth:`__rmatmul__` is invoked directly.
    """
    refined = self.refined
    if type(refined) is type(self):
        return other.__rmatmul__(self)
    return refined.__matmul__(other)

1061 

@convert_operands()
def __rmatmul__(self, other):
    """Denote the Kronecker product with another expression on the left.

    Delegates to the refined expression when refinement changes the type.
    Otherwise the scalar-by-scalar weighted-sum fallback is attempted.
    """
    refined = self.refined
    if type(refined) is not type(self):
        return refined.__rmatmul__(other)
    return self._scalar_mult_fallback(other, self)

1069 

@convert_operands(scalarRHS=True)
def __truediv__(self, other):
    """Denote division by another, scalar expression.

    If refining the expression changes its type, the refined expression
    performs the division. Otherwise the (converted) divisor's
    :meth:`__rtruediv__` is invoked directly.
    """
    refined = self.refined
    if type(refined) is type(self):
        return other.__rtruediv__(self)
    return refined.__truediv__(other)

1077 

@convert_operands(scalarLHS=True)
def __rtruediv__(self, other):
    """Denote scalar division of another expression by the expression.

    If refining the expression changes its type, the refined expression
    handles the division. Otherwise division is only supported by a nonzero
    constant. It is then carried out as multiplication of ``other`` by the
    reciprocal of this expression's value.

    :raises TypeError: If the denominator is nonconstant or zero.
    :raises AssertionError: If multiplying by the nonzero constant
        unexpectedly fails with a :class:`TypeError`.
    """
    if type(self.refined) is not type(self):
        return self.refined.__rtruediv__(other)

    # Short-circuit keeps is0 unevaluated for nonconstant denominators.
    if not self.constant or self.is0:
        reason = "nonconstant" if not self.constant else "zero"
        raise TypeError("Cannot divide {} by {}: The denominator is {}."
            .format(repr(other), repr(self), reason))

    try:
        return other.__mul__(1 / self.safe_value)
    except TypeError as error:
        # Multiplying by a nonzero constant should always succeed, if only by
        # producing a weighted sum. Raise explicitly instead of using an
        # assert statement: asserts are stripped under python -O, which made
        # this method silently return None.
        raise AssertionError(
            "Multiplication of {} by a nonzero constant has unexpectedly "
            "failed; it should have produced a weighted sum."
            .format(repr(other))) from error

1095 

@convert_operands(scalarRHS=True)
def __pow__(self, other):
    """Denote exponentiation with another, scalar expression.

    If refining the expression changes its type, the refined expression
    handles the exponentiation. Otherwise the (converted) exponent's
    :meth:`__rpow__` is invoked directly.
    """
    refined = self.refined
    if type(refined) is type(self):
        return other.__rpow__(self)
    return refined.__pow__(other)

1103 

@convert_operands(scalarLHS=True)
def __rpow__(self, other):
    """Denote taking another expression to the power of the expression.

    Delegates to the refined expression when refinement changes the type.
    Otherwise the operation is unsupported and :obj:`NotImplemented` is
    returned.
    """
    refined = self.refined
    if type(refined) is type(self):
        return NotImplemented
    return refined.__rpow__(other)

1111 

@convert_operands(horiCat=True)
def __and__(self, other):
    """Denote horizontal stacking with another expression on the right.

    If refining the expression changes its type, the refined expression
    performs the stacking. Otherwise the (converted) right operand's
    :meth:`__rand__` is invoked directly.
    """
    refined = self.refined
    if type(refined) is type(self):
        return other.__rand__(self)
    return refined.__and__(other)

1119 

@convert_operands(horiCat=True)
def __rand__(self, other):
    """Denote horizontal stacking with another expression on the left.

    Delegates to the refined expression when refinement changes the type.
    Otherwise :obj:`NotImplemented` is returned.
    """
    refined = self.refined
    if type(refined) is type(self):
        return NotImplemented
    return refined.__rand__(other)

1127 

@convert_operands(vertCat=True)
def __floordiv__(self, other):
    """Denote vertical stacking with another expression below.

    If refining the expression changes its type, the refined expression
    performs the stacking. Otherwise the (converted) lower operand's
    :meth:`__rfloordiv__` is invoked directly.
    """
    refined = self.refined
    if type(refined) is type(self):
        return other.__rfloordiv__(self)
    return refined.__floordiv__(other)

1135 

@convert_operands(vertCat=True)
def __rfloordiv__(self, other):
    """Denote vertical stacking with another expression above.

    Delegates to the refined expression when refinement changes the type.
    Otherwise :obj:`NotImplemented` is returned.
    """
    refined = self.refined
    if type(refined) is type(self):
        return NotImplemented
    return refined.__rfloordiv__(other)

1143 

def __pos__(self):
    """Implement unary plus, which leaves the expression unchanged."""
    unchanged = self
    return unchanged

1147 

def __neg__(self):
    """Denote the negation of the expression.

    Delegates to the refined expression when refinement changes the type.
    Otherwise the negation is represented as a weighted sum with weight
    ``-1``.
    """
    refined = self.refined
    if type(refined) is not type(self):
        return refined.__neg__()

    description = "the negation of {}".format(repr(self))
    return self._wsum_fallback((self,), -1, description)

1155 

def __abs__(self):
    """Denote the default norm of the expression.

    The norm used depends on the expression's domain. It is

    1. the absolute value of a real scalar,
    2. the modulus of a complex scalar,
    3. the Euclidean norm of a vector, and
    4. the Frobenius norm of a matrix.

    The base class delegates to the refined expression when refinement
    changes the type.

    :raises TypeError: If the expression type does not support a norm.
    """
    if type(self.refined) is not type(self):
        return self.refined.__abs__()

    # Unary operators have no reflected-operand protocol: returning
    # NotImplemented here would make abs() hand the NotImplemented singleton
    # back to the caller instead of failing. Raise like Python does for types
    # that lack __abs__ altogether.
    raise TypeError("bad operand type for abs(): '{}'"
        .format(type(self).__name__))

1170 

1171 # -------------------------------------------------------------------------- 

1172 # Turn __lshift__ and __rshift__ into a single binary relation. 

1173 # This is used for both Loewner order (defining LMIs) and set membership. 

1174 # -------------------------------------------------------------------------- 

1175 

def _lshift_implementation(self, other):
    """Build the constraint for ``self << other``; unsupported by default."""
    unsupported = NotImplemented
    return unsupported

1178 

def _rshift_implementation(self, other):
    """Build the constraint for ``self >> other``; unsupported by default."""
    unsupported = NotImplemented
    return unsupported

1181 

@convert_operands(diagBroadcast=True)
@validate_prediction
@refine_operands()
def __lshift__(self, other):
    """Denote either set membership or a linear matrix inequality.

    If the other operand is a set, then this denotes that the expression
    shall be constrained to that set. Otherwise, it is expected that both
    expressions are square matrices of same shape and this denotes that the
    expression is upper-bounded by the other expression with respect to the
    Loewner order (i.e. ``other - self`` is positive semidefinite).

    The expression's own :meth:`_lshift_implementation` is tried first. If it
    returns :obj:`NotImplemented`, the other operand's
    :meth:`_rshift_implementation` is asked instead.
    """
    outcome = self._lshift_implementation(other)
    if outcome is not NotImplemented:
        return outcome
    return other._rshift_implementation(self)

1200 

@convert_operands(diagBroadcast=True)
@validate_prediction
@refine_operands()
def __rshift__(self, other):
    """Denote that the expression is lower-bounded in the Loewner order.

    In other words, return a constraint that ``self - other`` is positive
    semidefinite.

    The expression's own :meth:`_rshift_implementation` is tried first. If it
    returns :obj:`NotImplemented`, the other operand's
    :meth:`_lshift_implementation` is asked to build the constraint instead.
    """
    result = self._rshift_implementation(other)

    if result is NotImplemented:
        result = other._lshift_implementation(self)

    return result

1216 

1217 # -------------------------------------------------------------------------- 

1218 # Backwards compatibility methods. 

1219 # -------------------------------------------------------------------------- 

1220 

@deprecated("2.0", useInstead="~picos.valuable.Valuable.valued")
def is_valued(self):
    """Report whether the expression is valued (deprecated alias)."""
    valued = self.valued
    return valued

1225 

@deprecated("2.0", useInstead="~picos.valuable.Valuable.value")
def set_value(self, value):
    """Assign a value to the expression (deprecated alias)."""
    setattr(self, "value", value)

1230 

@deprecated("2.0", "PICOS treats all inequalities as non-strict. Using the "
    "strict inequality comparison operators may lead to unexpected results "
    "when dealing with integer problems.")
def __lt__(self, exp):
    """Denote a non-strict upper bound; same as ``self <= exp`` (deprecated).

    The comparison is forwarded unchanged to :meth:`__le__`.
    """
    non_strict = self.__le__
    return non_strict(exp)

1236 

@deprecated("2.0", "PICOS treats all inequalities as non-strict. Using the "
    "strict inequality comparison operators may lead to unexpected results "
    "when dealing with integer problems.")
def __gt__(self, exp):
    """Denote a non-strict lower bound; same as ``self >= exp`` (deprecated).

    The comparison is forwarded unchanged to :meth:`__ge__`.
    """
    non_strict = self.__ge__
    return non_strict(exp)

1242 

1243 

# --------------------------------------
# Declare the module's public API from the module globals, relative to the
# _API_START marker taken near the top of this file (presumably the names
# defined between the two markers — TODO confirm against picos.apidoc). Keep
# this as the last statement so that every definition above is already
# present in globals() when it runs.
__all__ = api_end(_API_START, globals())