Coverage for picos/containers.py: 80.38%

316 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-26 07:46 +0000

1# ------------------------------------------------------------------------------ 

2# Copyright (C) 2019-2020 Maximilian Stahlberg 

3# 

4# This file is part of PICOS. 

5# 

6# PICOS is free software: you can redistribute it and/or modify it under the 

7# terms of the GNU General Public License as published by the Free Software 

8# Foundation, either version 3 of the License, or (at your option) any later 

9# version. 

10# 

11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY 

12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

13# A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

14# 

15# You should have received a copy of the GNU General Public License along with 

16# this program. If not, see <http://www.gnu.org/licenses/>. 

17# ------------------------------------------------------------------------------ 

18 

19"""Domain-specific container types.""" 

20 

21from collections import OrderedDict 

22from collections.abc import MutableSet 

23from itertools import chain 

24from types import MappingProxyType 

25 

26from .apidoc import api_end, api_start 

27from .caching import cached_property 

28 

# Mark the beginning of this module's public API. NOTE(review): api_start
# comes from picos.apidoc; presumably it snapshots the current global names so
# that the matching api_end call at the bottom of the module can derive
# __all__ from the names defined in between — confirm in picos.apidoc.
_API_START = api_start(globals())
# -------------------------------

31 

32 

class OrderedSet(MutableSet):
    """A set that remembers its insertion order.

    >>> from picos.containers import OrderedSet as oset
    >>> o = oset([4, 3, 2, 1]); o
    OrderedSet([4, 3, 2, 1])
    >>> 3 in o
    True
    >>> o.update([5, 4, 3]); o
    OrderedSet([4, 3, 2, 1, 5])
    >>> list(o)
    [4, 3, 2, 1, 5]
    >>> o.discard(3); o
    OrderedSet([4, 2, 1, 5])
    """

    def __init__(self, iterable=()):
        """Initialize the ordered set.

        :param iterable:
            Iterable to take initial elements from. Repeated elements keep the
            position of their first occurrence.
        """
        # The keys of an ordered dictionary act as the ordered set; the values
        # are unused placeholders.
        self._dict = OrderedDict((element, None) for element in iterable)

    # --------------------------------------------------------------------------
    # Special methods not implemented by MutableSet.
    # --------------------------------------------------------------------------

    def __str__(self):
        return "{{{}}}".format(", ".join(str(element) for element in self))

    def __repr__(self):
        # Use repr for the elements so that e.g. strings stay quoted.
        return "OrderedSet([{}])".format(
            ", ".join(repr(element) for element in self))

    # --------------------------------------------------------------------------
    # Abstract method implementations.
    # --------------------------------------------------------------------------

    def __contains__(self, key):
        return key in self._dict

    def __iter__(self):
        return iter(self._dict.keys())

    def __len__(self):
        return len(self._dict)

    def add(self, element):
        """Add an element to the end of the set unless it is contained."""
        self._dict[element] = None

    def discard(self, element):
        """Discard an element from the set.

        If the element is not contained, do nothing.
        """
        # MutableSet.pop takes no arguments (and is itself implemented through
        # discard), so the removal must go through the dictionary directly.
        self._dict.pop(element, None)

    # --------------------------------------------------------------------------
    # Overridings to improve performance over MutableSet's implementation.
    # --------------------------------------------------------------------------

    def clear(self):
        """Clear the set."""
        self._dict.clear()

    # --------------------------------------------------------------------------
    # Methods provided by set but not by MutableSet.
    # --------------------------------------------------------------------------

    def copy(self):
        """Return a shallow copy of the set."""
        return self.__class__(self)

    def update(self, *iterables):
        """Update the set with elements from a number of iterables."""
        for iterable in iterables:
            for element in iterable:
                self._dict[element] = None

    def union(self, *others):
        """Return the union of the set and all given iterables as a new set.

        Elements of this set come first, followed by new elements in the order
        in which they appear in the given iterables.
        """
        return self.__class__(chain(self, *others))

    def intersection(self, *others):
        """Return the elements common to the set and all given iterables.

        The result preserves the order of this set.
        """
        lookups = [frozenset(other) for other in others]
        return self.__class__(
            element for element in self
            if all(element in lookup for lookup in lookups))

    def intersection_update(self, *others):
        """Keep only the elements that are found in all given iterables."""
        self._dict = self.intersection(*others)._dict

    def difference(self, *others):
        """Return the elements of the set not found in any given iterable.

        The result preserves the order of this set.
        """
        excluded = frozenset(chain.from_iterable(others))
        return self.__class__(
            element for element in self if element not in excluded)

    def difference_update(self, *others):
        """Remove all elements that are found in any of the given iterables."""
        # Materialize first so that passing the set itself is safe.
        for element in frozenset(chain.from_iterable(others)):
            self._dict.pop(element, None)

    def symmetric_difference(self, other):
        """Return the elements found in exactly one of the set and ``other``.

        Elements of this set come first (in their order), followed by those
        only in ``other`` (in the order in which they appear there).
        """
        other = self.__class__(other)
        return self.__class__(chain(
            (element for element in self if element not in other),
            (element for element in other if element not in self)))

    def symmetric_difference_update(self, other):
        """Update the set with the symmetric difference of itself and other.

        Elements of ``other`` that are not in the set are appended at the end.
        """
        # Deduplicate and snapshot first; this also covers ``other is self``.
        for element in self.__class__(other):
            if element in self._dict:
                del self._dict[element]
            else:
                self._dict[element] = None

    def issubset(self, other):
        """Report whether every element of the set is in ``other``."""
        lookup = frozenset(other)
        return all(element in lookup for element in self)

    def issuperset(self, other):
        """Report whether every element of ``other`` is in the set."""
        return all(element in self._dict for element in other)

143 

144 

class frozendict(dict):
    """An immutable, hashable dictionary.

    All mutating :class:`dict` operations raise :exc:`AttributeError`.
    """

    @classmethod
    def fromkeys(cls, iterable, value=None):
        """Overwrite :meth:`dict.fromkeys` to return a :class:`frozendict`."""
        return cls(dict.fromkeys(iterable, value))

    def __hash__(self):
        # The content cannot change, so the hash is computed once and cached.
        # Hashing a frozenset of the items makes the result independent of the
        # insertion order and does not require the keys to be mutually
        # orderable (sorting mixed or partially ordered keys, such as
        # frozensets, could fail or depend on the insertion order).
        if not hasattr(self, "_hash"):
            self._hash = hash(frozenset(self.items()))

        return self._hash

    def __reduce__(self):
        # The default dict pickling protocol restores items via __setitem__,
        # which is blocked; rebuild from a plain dict copy instead. This also
        # makes copy.copy and copy.deepcopy work. The cached hash is not
        # carried over, as hashes may differ between interpreter runs.
        return (self.__class__, (dict(self),))

    def __str__(self):
        return dict.__repr__(self)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self)

    def _modify(self, *args, **kwargs):
        """Refuse any modification by raising :exc:`AttributeError`."""
        raise AttributeError(
            "Cannot modify a {}.".format(self.__class__.__name__))

    __delitem__ = _modify
    __setitem__ = _modify
    __ior__ = _modify  # Blocks in-place merging via ``|=``.
    clear = _modify
    pop = _modify
    popitem = _modify
    setdefault = _modify
    update = _modify

    def copy(self):
        """Since :class:`frozendict` are immutable, returns self."""
        return self

181 

182 

class DetailedType:
    """Container for a pair of Python class and subtype.

    A container for a pair of a python class type and some logical subtype data
    structure, together called a detailed type.

    A detailed type is used when the mathematical type of an object must be
    distinguished more precisely than at the level of the python classes used to
    represent such mathematical objects. For instance, a single python class
    would be used for a type of expressions of varying dimensionality and
    subtypes would be used to distinguish further based on dimension.

    Instances of this class are treated exceptionally when used as a label of
    a :class:`RecordTree`: They are expanded into the class and the subtype as
    two separate labels, making it convenient to store detailed types in trees.
    """

    def __init__(self, theClass, subtype):
        """Construct a :class:`DetailedType`.

        :param type theClass: The Python class part of the detailed type.
        :param object subtype: Additional type information.

        :raises TypeError: If ``subtype`` has no ``_asdict`` method, i.e. does
            not look like a namedtuple instance.
        """
        if not hasattr(subtype, "_asdict"):
            raise TypeError("The given subtype of {} is not a namedtuple "
                "instance.".format(subtype))

        self.clstype = theClass
        self.subtype = subtype

    def __iter__(self):
        yield self.clstype
        yield self.subtype

    def __hash__(self):
        return hash((self.clstype, self.subtype))

    def __eq__(self, other):
        # Compare by value: equal hashes do not imply equality. A plain
        # (class, subtype) tuple still compares equal, as it did before (such a
        # tuple also has the same hash, keeping __eq__ and __hash__ consistent).
        if isinstance(other, DetailedType):
            other = (other.clstype, other.subtype)
        elif not isinstance(other, tuple):
            return NotImplemented

        return (self.clstype, self.subtype) == other

    def equals(self, other):
        """Whether two detailed types are the same."""
        return self == other

    def __repr__(self):
        return "<{}: {}>".format(self.__class__.__name__, str(self))

    def __str__(self):
        subtypeArgsStr = "|".join("{}={}".format(key, val)
            for key, val in self.subtype._asdict().items())

        return "{}[{}]".format(self.clstype.__name__, subtypeArgsStr)

    def __add__(self, other):
        # Concatenation with a tuple or list yields (class, subtype, ...).
        if isinstance(other, tuple):
            return tuple(self) + other
        elif isinstance(other, list):
            return list(self) + other
        else:
            return NotImplemented

    def __radd__(self, other):
        # Concatenation from the left yields (..., class, subtype).
        if isinstance(other, tuple):
            return other + tuple(self)
        elif isinstance(other, list):
            return other + list(self)
        else:
            return NotImplemented

251 

252 

class RecordTreeToken:
    """Base class for special :class:`RecordTree` value tokens.

    Tokens are used as classes (compared by identity) and are never
    instantiated.
    """

    def __init__(self):
        """Refuse instantiation by raising a :exc:`TypeError`."""
        tokenName = type(self).__name__
        message = "{} may not be initialized.".format(tokenName)
        raise TypeError(message)

260 

261 

class RecordTree():
    """Labeled tree for storing records.

    An immutable labeled tree with values at the leaf nodes, where labels and
    values are arbitrary hashable python objects.

    An n-tuple whose first (n-1) elements are labels and whose last element is a
    value is called a record. Thus, every path from the root node to a leaf node
    represents one record.

    :class:`DetailedType` labels are treated exceptionally: They are expanded
    into the class and the subtype as two separate labels.
    """

    class _NodeDict(dict):
        # Marks inner nodes of the tree; anything else found in the tree
        # (including a stored dict value) is treated as a leaf value.
        pass

    class NONE(RecordTreeToken):
        """Special :class:`RecordTree` value: No subtrees.

        If inserted at some (inner) node of the tree, the whole subtree starting
        at that node is deleted. If that node's parent node has no other
        children, then the parent node is deleted as well. This process is
        repeated iteratively up to the root node, which is never deleted.

        This is the only value that may be inserted at an inner node.

        This value cannot itself be stored in the tree as its insertion is
        always read as a deletion.
        """

        pass

    class ALL(RecordTreeToken):
        """Special :class:`RecordTree` value: Any subtrees.

        A special value that behaves as an arbitrary subtree during subtree
        checks.
        """

        pass

    @classmethod
    def _flatten(cls, path):
        """Expand every :class:`DetailedType` label in path into two labels."""
        for index, thing in enumerate(path):
            if isinstance(thing, DetailedType):
                return cls._flatten(path[:index] + thing + path[index+1:])
        return path

    @classmethod
    def _freeze(cls, value):
        """Make a label or value hashable."""
        if isinstance(value, list):
            newValue = tuple(value)
        elif isinstance(value, set):
            newValue = frozenset(value)
        elif isinstance(value, dict):
            newValue = frozendict(value)
        else:
            newValue = value

        try:
            hash(newValue)
        except Exception as error:
            raise TypeError("Failed to freeze {} to a hashable type."
                .format(value)) from error

        return newValue

    @staticmethod
    def _keyval_iterator(recordsOrDict):
        """Return path/value pairs from a dictionary or an iterable of records.

        Read-only mappings, such as the one returned by :attr:`dict`, are
        treated like dictionaries.
        """
        if isinstance(recordsOrDict, (dict, MappingProxyType)):
            return recordsOrDict.items()
        else:
            return ((rec[:-1], rec[-1]) for rec in recordsOrDict)

    @staticmethod
    def _record_iterator(recordsOrDict):
        """Return records from a dictionary or an iterable of records."""
        if isinstance(recordsOrDict, (dict, MappingProxyType)):
            return ((key + (val,)) for key, val in recordsOrDict.items())
        else:
            return recordsOrDict

    def __init__(self, recordsOrDict=(), addValues=False, freeze=True):
        """Construct a :class:`RecordTree`.

        :param recordsOrDict:
            Data stored in the tree.
        :type recordsOrDict:
            dict or list(tuple)

        :param addValues:
            Add the (numeric) values of records stored in the same place in the
            tree, instead of replacing the value. If this is exactly a list of
            path tuples (precise types are required), then add values only for
            records below any of these paths instead. In either case, resulting
            values of zero are not stored in the tree.
        :type addValues:
            bool or list(tuple)

        :param bool freeze:
            Whether to transform mutable labels and values into hashable ones.

        :raises LookupError:
            If a value is to be stored below an existing leaf, or if a value is
            to be added to a path that leads to an inner node.
        """
        self._tree = self._NodeDict()

        if isinstance(addValues, list):
            addValues = [self._flatten(path) for path in addValues]

        def _add_values_at(path):
            # Whether a value stored at path should be summed with an existing
            # one: either everywhere or only below one of the given paths.
            if isinstance(addValues, list):
                return any(path[:end] in addValues for end in range(len(path)))
            else:
                return bool(addValues)

        for path, value in self._keyval_iterator(recordsOrDict):
            # Normalize to a flat tuple, so that the lookups below also work
            # for records given as lists when freeze is False.
            path = tuple(self._flatten(path))

            if freeze:
                path = tuple(self._freeze(thing) for thing in path)
                value = self._freeze(value)

            if value is not self.NONE and _add_values_at(path):
                if value == 0:
                    # Do not add a value equal to zero.
                    continue
                elif path in self:
                    oldValue = self[path]

                    if isinstance(oldValue, RecordTree):
                        raise LookupError("Can't add value '{}' at '{}': Path "
                            "leads to an inner node.".format(value, path))

                    value = oldValue + value

                    # If the sum is zero, delete the record instead.
                    if value == 0:
                        value = self.NONE

            if value is self.NONE:
                self._delete_subtree(path)
            else:
                self._insert_leaf(path, value)

        self._hash = hash(self.set)

    def _delete_subtree(self, path):
        """Delete the node at path and prune ancestors left without children.

        The root node is never deleted. Does nothing if path is not present.
        """
        # Collect (parent, label) pairs from the root down to the target node.
        trail = []
        node = self._tree
        for label in path:
            if not isinstance(node, self._NodeDict) or label not in node:
                return
            trail.append((node, label))
            node = node[label]

        # Walk back up, stopping at the first ancestor with other children.
        for parent, label in reversed(trail):
            del parent[label]
            if parent:
                break

    def _insert_leaf(self, path, value):
        """Store value at path, creating inner nodes as needed.

        A subtree previously stored at path is replaced by the leaf.

        :raises LookupError: If a proper prefix of path leads to a leaf.
        """
        node = self._tree
        for label in path[:-1]:
            node = node.setdefault(label, self._NodeDict())

            if not isinstance(node, self._NodeDict):
                raise LookupError("Can't set value '{}' at '{}': Path "
                    "already contains a leaf.".format(value, path))

        node[path[-1]] = value

    @classmethod
    def _traverse(cls, node):
        """Yield every record (labels followed by value) below node."""
        if not isinstance(node, cls._NodeDict):
            # Not a node but a value.
            yield (node,)
            return
        elif not node:
            # Empty tree.
            return

        for label, child in node.items():
            for labels in cls._traverse(child):
                yield (label,) + labels

    @property
    def records(self):
        """Return an iterator over tuples, each representing one record."""
        return self._traverse(self._tree)

    @property
    def items(self):
        """Return an iterator over path/value pairs representing records."""
        return ((path[:-1], path[-1]) for path in self.records)

    @property
    def paths(self):
        """Return an iterator over paths, each representing one record."""
        return (path[:-1] for path in self.records)

    @cached_property
    def dict(self):
        """Return the tree as a read-only, tuple-indexed dictionary view.

        Every key/value pair of the returned dictionary is a record.
        """
        return MappingProxyType({path[:-1]: path[-1] for path in self.records})

    @cached_property
    def set(self):
        """Return a frozen set of tuples, each representing one record."""
        return frozenset(self.records)

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        # Equal hashes do not imply equality, so the records are compared as
        # well; differing hashes serve as a cheap early exit.
        if not isinstance(other, RecordTree):
            return NotImplemented

        return hash(self) == hash(other) and self.set == other.set

    def __bool__(self):
        return bool(self._tree)

    def __len__(self):
        # Distinct records have distinct paths, so the cached set of records
        # holds exactly one entry per record.
        return len(self.set)

    def __contains__(self, path):
        """Whether the path leads from the root to a node or leaf of the tree.

        :raises TypeError: If path is not a tuple.
        """
        if not isinstance(path, tuple):
            raise TypeError("{} indices must be tuples.".format(
                self.__class__.__name__))

        node = self._tree
        for label in self._flatten(path):
            if not isinstance(node, self._NodeDict) or label not in node:
                return False
            node = node[label]

        return True

    def _get(self, path, errorOnBadPath):
        """Return the leaf value or the subtree (as a new tree) at path.

        :param path: A tuple or list of labels, or a single label.
        :param bool errorOnBadPath: Whether to raise :exc:`LookupError` for a
            missing path instead of returning an empty :class:`RecordTree`.
        """
        # A single label may be given instead of a path.
        if not isinstance(path, (tuple, list)):
            path = (path,)

        path = self._flatten(path)

        node = self._tree
        for label in path:
            if not isinstance(node, self._NodeDict) or label not in node:
                if errorOnBadPath:
                    raise LookupError(str(path))
                else:
                    return RecordTree()
            node = node[label]

        if isinstance(node, self._NodeDict):
            return RecordTree(
                {record[:-1]: record[-1] for record in self._traverse(node)})
        else:
            return node

    def __getitem__(self, path):
        return self._get(path, True)

    def get(self, path):
        """Return an empty :class:`RecordTree` if the path does not exist."""
        return self._get(path, False)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.dict)

    def __str__(self):
        return str(self.dict)

    def __le__(self, other):
        """Perform entrywise lower-or-equal-than comparison.

        Each left hand side path must be present on the right hand side, and the
        associated left hand side value must compare lower-or-equal-than the
        right hand side value.
        """
        if type(self) != type(other):
            return NotImplemented

        for path, value in self.items:
            if path not in other:
                return False

            if not value <= other[path]:
                return False

        return True

    def __ge__(self, other):
        """Perform entrywise greater-or-equal-than comparison.

        Each left hand side path must be present on the right hand side, and the
        associated left hand side value must compare greater-or-equal-than the
        right hand side value.
        """
        if type(self) != type(other):
            return NotImplemented

        for path, value in self.items:
            if path not in other:
                return False

            if not value >= other[path]:
                return False

        return True

    @classmethod
    def _path_covered(cls, path, tree):
        """Whether path is present in tree or absorbed by an :class:`ALL` value.

        :param tuple path: Labels to follow from the root of ``tree``.
        :param RecordTree tree: The tree to look the path up in.
        """
        node = tree._tree
        for label in path:
            if node is cls.ALL:
                return True

            # A leaf value other than ALL cannot contain further labels.
            if not isinstance(node, cls._NodeDict) or label not in node:
                return False

            node = node[label]

        return True

    def __lshift__(self, other):
        """Perform subtree comparison.

        Each left hand side path must be present on the right hand side. If the
        special :class:`ALL` type is found as a value in the right hand side
        tree, it is treated as "all possible subtrees". All other values are not
        considered.
        """
        if not isinstance(other, RecordTree):
            return NotImplemented

        return all(self._path_covered(path, other) for path in self.paths)

    def mismatch(self, other):
        """A subtree of ``self`` that renders ``self << other`` :obj:`False`.

        :returns RecordTree:
            The smallest subtree ``T`` of ``self`` such that ``self`` without
            the records in ``T`` is a subtree of ``other``. The returned tree is
            a direct instance of the :class:`RecordTree` base class.

        :raises TypeError: If ``other`` is not a :class:`RecordTree`.
        """
        if not isinstance(other, RecordTree):
            raise TypeError("The argument must be another record tree.")

        return RecordTree([record for record in self.records
            if not self._path_covered(record[:-1], other)])

    @staticmethod
    def _str(thing):
        return thing.__name__ if hasattr(thing, "__name__") else str(thing)

    @property
    def text(self):
        """Return the full tree as a multiline string."""
        keys, vals = [], []
        for path in self._traverse(self._tree):
            keys.append("/".join(self._str(label) for label in path[:-1]))
            vals.append(self._str(path[-1]))
        if not keys:
            return "Empty {} instance.".format(self.__class__.__name__)
        keyLen = max(len(key) for key in keys)
        valLen = max(len(val) for val in vals)
        return "\n".join(sorted(
            "{{:{}}} = {{:{}}}".format(keyLen, valLen).format(key, val)
            for key, val in zip(keys, vals)))

    def copy(self):
        """Create a shallow copy; the tree is copied, the values are not."""
        return self.__class__(self.records)

    def updated(self, recordsOrDict, addValues=False):
        """Create a shallow copy with modified records.

        :Example:

        >>> from picos.modeling.footprint import RecordTree as T
        >>> t = T({(1, 1, 1): 3, (1, 1, 2): 4, (1, 2, 1): 5}); t
        RecordTree({(1, 1, 1): 3, (1, 1, 2): 4, (1, 2, 1): 5})
        >>> t.updated({(1, 1, 1): "a", (2, 2): "b"}) # Change or add a record.
        RecordTree({(1, 1, 1): 'a', (1, 1, 2): 4, (1, 2, 1): 5, (2, 2): 'b'})
        >>> t.updated({(1,1,1): T.NONE}) # Delete a single record.
        RecordTree({(1, 1, 2): 4, (1, 2, 1): 5})
        >>> t.updated({(1,1): T.NONE}) # Delete multiple records.
        RecordTree({(1, 2, 1): 5})
        >>> t.updated([(1, 1, 1, T.NONE), (1, 1, 1, 1, 6)]) # Delete, then add.
        RecordTree({(1, 1, 2): 4, (1, 1, 1, 1): 6, (1, 2, 1): 5})
        >>> try: # Not possible to implicitly turn a leaf into an inner node.
        ...     t.updated([(1, 1, 1, 1, 6)])
        ... except LookupError as error:
        ...     print(error)
        Can't set value '6' at '(1, 1, 1, 1)': Path already contains a leaf.
        """
        return self.__class__(
            chain(self.records, self._record_iterator(recordsOrDict)),
            addValues)

687 

688 

# --------------------------------------
# Derive the module's public names from the api_start marker near the top of
# the module. NOTE(review): the exact selection rules are implemented in
# picos.apidoc.api_end — confirm there.
__all__ = api_end(_API_START, globals())