Coverage for picos/containers.py: 80.38%
316 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-26 07:46 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-26 07:46 +0000
1# ------------------------------------------------------------------------------
2# Copyright (C) 2019-2020 Maximilian Stahlberg
3#
4# This file is part of PICOS.
5#
6# PICOS is free software: you can redistribute it and/or modify it under the
7# terms of the GNU General Public License as published by the Free Software
8# Foundation, either version 3 of the License, or (at your option) any later
9# version.
10#
11# PICOS is distributed in the hope that it will be useful, but WITHOUT ANY
12# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License along with
16# this program. If not, see <http://www.gnu.org/licenses/>.
17# ------------------------------------------------------------------------------
19"""Domain-specific container types."""
21from collections import OrderedDict
22from collections.abc import MutableSet
23from itertools import chain
24from types import MappingProxyType
26from .apidoc import api_end, api_start
27from .caching import cached_property
# NOTE(review): presumably captures the module namespace before any public
# definitions so that the api_end() call at the end of the file can derive
# __all__ from what was defined in between -- confirm against picos.apidoc.
29_API_START = api_start(globals())
30# -------------------------------
class OrderedSet(MutableSet):
    """A set that remembers its insertion order.

    Elements are stored as the keys of an internal
    :class:`~collections.OrderedDict`; the associated values are unused.

    >>> from picos.containers import OrderedSet as oset
    >>> o = oset([4, 3, 2, 1]); o
    OrderedSet([4, 3, 2, 1])
    >>> 3 in o
    True
    >>> o.update([5, 4, 3]); o
    OrderedSet([4, 3, 2, 1, 5])
    >>> list(o)
    [4, 3, 2, 1, 5]
    """

    def __init__(self, iterable=()):
        """Initialize the ordered set.

        :param iterable:
            Iterable to take initial elements from.
        """
        self._dict = OrderedDict((element, None) for element in iterable)

    # --------------------------------------------------------------------------
    # Special methods not implemented by MutableSet.
    # --------------------------------------------------------------------------

    def __str__(self):
        return "{{{}}}".format(", ".join(str(element) for element in self))

    def __repr__(self):
        return "OrderedSet([{}])".format(
            ", ".join(str(element) for element in self))

    # --------------------------------------------------------------------------
    # Abstract method implementations.
    # --------------------------------------------------------------------------

    def __contains__(self, key):
        return key in self._dict

    def __iter__(self):
        return iter(self._dict.keys())

    def __len__(self):
        return len(self._dict)

    def add(self, element):
        """Add an element to the set."""
        self._dict[element] = None

    def discard(self, element):
        """Discard an element from the set.

        If the element is not contained, do nothing.
        """
        # Must go through the backing dictionary: MutableSet.pop accepts no
        # arguments and is itself implemented in terms of discard.
        self._dict.pop(element, None)

    # --------------------------------------------------------------------------
    # Overrides to improve performance over MutableSet's implementation.
    # --------------------------------------------------------------------------

    def clear(self):
        """Clear the set."""
        self._dict.clear()

    # --------------------------------------------------------------------------
    # Methods provided by set but not by MutableSet.
    # --------------------------------------------------------------------------

    def update(self, *iterables):
        """Update the set with elements from a number of iterables."""
        for iterable in iterables:
            for element in iterable:
                self._dict[element] = None

    # The named set operations below are exposed as properties that return the
    # bound operator method, so that ``o.union(x)`` evaluates ``o.__or__(x)``.

    difference = property(
        lambda self: self.__sub__,
        doc=set.difference.__doc__)

    difference_update = property(
        lambda self: self.__isub__,
        doc=set.difference_update.__doc__)

    intersection = property(
        lambda self: self.__and__,
        doc=set.intersection.__doc__)

    intersection_update = property(
        lambda self: self.__iand__,
        doc=set.intersection_update.__doc__)

    issubset = property(
        lambda self: self.__le__,
        doc=set.issubset.__doc__)

    issuperset = property(
        lambda self: self.__ge__,
        doc=set.issuperset.__doc__)

    symmetric_difference = property(
        lambda self: self.__xor__,
        doc=set.symmetric_difference.__doc__)

    symmetric_difference_update = property(
        lambda self: self.__ixor__,
        doc=set.symmetric_difference_update.__doc__)

    union = property(
        lambda self: self.__or__,
        doc=set.union.__doc__)
class frozendict(dict):
    """An immutable, hashable dictionary.

    All mutating :class:`dict` methods and operators are replaced by a property
    that raises :exc:`AttributeError` as soon as it is looked up. The hash is
    computed from the items on first use and then cached on the instance, so
    all values must be hashable for the dictionary to be hashable.
    """

    @classmethod
    def fromkeys(cls, iterable, value=None):
        """Overwrite :meth:`dict.fromkeys`."""
        return cls(dict.fromkeys(iterable, value))

    def __hash__(self):
        if not hasattr(self, "_hash"):
            # A frozenset of the items is independent of insertion order and,
            # unlike sorting the items, does not require keys to be orderable.
            self._hash = hash(frozenset(self.items()))

        return self._hash

    def __str__(self):
        return dict.__repr__(self)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self)

    @property
    def _modify(self, *args, **kwargs):
        # Being a property, this raises on attribute lookup already, which also
        # covers the implicit lookups made for ``d[k] = v`` and ``del d[k]``.
        raise AttributeError(
            "Cannot modify a {}.".format(self.__class__.__name__))

    __delitem__ = _modify
    __setitem__ = _modify
    # In-place union (``d |= other``) would otherwise mutate the dictionary on
    # Python versions where dict implements it.
    __ior__ = _modify
    clear = _modify
    pop = _modify
    popitem = _modify
    setdefault = _modify
    update = _modify

    def copy(self):
        """Since :class:`frozendict` are immutable, returns self."""
        return self
class DetailedType:
    """Container for a pair of Python class and subtype.

    A container for a pair of a python class type and some logical subtype data
    structure, together called a detailed type.

    A detailed type is used when the mathematical type of an object must be
    distinguished more precisely than at the level of the python classes used to
    represent such mathematical objects. For instance, a single python class
    would be used for a type of expressions of varying dimensionality and
    subtypes would be used to distinguish further based on dimension.

    Instances of this class are treated exceptionally when used as a label of
    a :class:`RecordTree`: They are expanded into the class and the subtype as
    two separate labels, making it convenient to store detailed types in trees.
    """

    def __init__(self, theClass, subtype):
        """Construct a :class:`DetailedType`.

        :param type theClass: The Python class part of the detailed type.
        :param object subtype: Additional type information.

        :raises TypeError: If the subtype has no ``_asdict`` attribute, which is
            used to detect namedtuple instances.
        """
        if not hasattr(subtype, "_asdict"):
            raise TypeError("The given subtype of {} is not a namedtuple "
                "instance.".format(subtype))

        self.clstype = theClass
        self.subtype = subtype

    def __iter__(self):
        yield self.clstype
        yield self.subtype

    def __hash__(self):
        return hash((self.clstype, self.subtype))

    def __eq__(self, other):
        # Compare the actual pair instead of hashes, so that hash collisions
        # cannot produce false positives and unhashable operands do not raise.
        # An equal (class, subtype) tuple has the same hash and is still
        # treated as equal.
        if isinstance(other, (DetailedType, tuple)):
            return tuple(self) == tuple(other)

        return NotImplemented

    def equals(self, other):
        """Whether two detailed types are the same."""
        return bool(self == other)

    def __repr__(self):
        return "<{}: {}>".format(self.__class__.__name__, str(self))

    def __str__(self):
        subtypeArgsStr = "|".join("{}={}".format(key, val)
            for key, val in self.subtype._asdict().items())

        return "{}[{}]".format(self.clstype.__name__, subtypeArgsStr)

    def __add__(self, other):
        # Concatenation with a tuple or list expands the detailed type into its
        # two components and yields a plain tuple or list, respectively.
        if isinstance(other, tuple):
            return tuple(self) + other
        elif isinstance(other, list):
            return list(self) + other
        else:
            return NotImplemented

    def __radd__(self, other):
        if isinstance(other, tuple):
            return other + tuple(self)
        elif isinstance(other, list):
            return other + list(self)
        else:
            return NotImplemented
class RecordTreeToken:
    """Common parent of the special value tokens of :class:`RecordTree`.

    Tokens are used as classes only; creating an instance always fails.
    """

    def __init__(self):
        """Raise a :exc:`TypeError` on instantiation."""
        message = "{} may not be initialized.".format(type(self).__name__)
        raise TypeError(message)
class RecordTree():
    """Labeled tree for storing records.

    An immutable labeled tree with values at the leaf nodes, where labels and
    values are arbitrary hashable python objects.

    An n-tuple whose first (n-1) elements are labels and whose last element is a
    value is called a record. Thus, every path from the root node to a leaf node
    represents one record.

    :class:`DetailedType` labels are treated exceptionally: They are expanded
    into the class and the subtype as two separate labels.
    """

    class _NodeDict(dict):
        # Dedicated dict subclass used for inner nodes, so that inner nodes can
        # be told apart from leaf values that happen to be dictionaries.
        pass

    class NONE(RecordTreeToken):
        """Special :class:`RecordTree` value: No subtrees.

        If inserted at some (inner) node of the tree, the whole subtree starting
        at that node is deleted. If that node's parent node has no other
        children, then the parent node is deleted as well. This process is
        repeated iteratively up to the root node, which is never deleted.

        This is the only value that may be inserted at an inner node.

        This value cannot itself be stored in the tree as its insertion is
        always read as a deletion.
        """

        pass

    class ALL(RecordTreeToken):
        """Special :class:`RecordTree` value: Any subtrees.

        A special value that behaves as an arbitrary subtree during subtree
        checks.
        """

        pass

    @classmethod
    def _flatten(cls, path):
        """Expand every :class:`DetailedType` in a path into two labels."""
        for index, thing in enumerate(path):
            if isinstance(thing, DetailedType):
                # DetailedType implements concatenation with tuples and lists,
                # so the sequence type of the path is preserved.
                return cls._flatten(path[:index] + thing + path[index+1:])
        return path

    @classmethod
    def _freeze(cls, value):
        """Make a label or value hashable.

        Lists, sets and dictionaries are converted to :class:`tuple`,
        :class:`frozenset` and :class:`frozendict`, respectively.

        :raises TypeError: If the (converted) object is still not hashable.
        """
        if isinstance(value, list):
            newValue = tuple(value)
        elif isinstance(value, set):
            newValue = frozenset(value)
        elif isinstance(value, dict):
            newValue = frozendict(value)
        else:
            newValue = value

        try:
            hash(newValue)
        except Exception as error:
            raise TypeError("Failed to freeze {} to a hashable type."
                .format(value)) from error

        return newValue

    @staticmethod
    def _keyval_iterator(recordsOrDict):
        """Iterate over path/value pairs of a dictionary or of records."""
        if isinstance(recordsOrDict, dict):
            return recordsOrDict.items()
        else:
            return ((rec[:-1], rec[-1]) for rec in recordsOrDict)

    @staticmethod
    def _record_iterator(recordsOrDict):
        """Iterate over records given as a dictionary or as records."""
        if isinstance(recordsOrDict, dict):
            return ((key + (val,)) for key, val in recordsOrDict.items())
        else:
            return recordsOrDict

    def __init__(self, recordsOrDict=(), addValues=False, freeze=True):
        """Construct a :class:`RecordTree`.

        Records are processed in order, so later records may replace, add to or
        delete what earlier records have stored.

        :param recordsOrDict:
            Data stored in the tree.
        :type recordsOrDict:
            dict or list(tuple)

        :param addValues:
            Add the (numeric) values of records stored in the same place in the
            tree, instead of replacing the value. If this is exactly a list of
            path tuples (precise types are required), then add values only for
            records below any of these paths instead. In either case, resulting
            values of zero are not stored in the tree.
        :type addValues:
            bool or list(tuple)

        :param bool freeze:
            Whether to transform mutable labels and values into hashable ones.

        :raises LookupError:
            If a value is to be added at an inner node or if a record would
            have to pass through an existing leaf.
        """
        self._tree = self._NodeDict()

        if isinstance(addValues, list):
            addValues = [self._flatten(path) for path in addValues]

        def _add_values_at(path):
            if isinstance(addValues, list):
                # Only proper prefixes of the path are tested, so a listed path
                # matches records that are stored strictly below it.
                return any(path[:end] in addValues for end in range(len(path)))
            else:
                return bool(addValues)

        for path, value in self._keyval_iterator(recordsOrDict):
            path = self._flatten(path)
            node = self._tree

            if freeze:
                path = tuple(self._freeze(thing) for thing in path)
                value = self._freeze(value)

            if value is not self.NONE and _add_values_at(path):
                if value == 0:
                    # Do not add a value equal to zero.
                    continue
                elif path in self:
                    oldValue = self[path]

                    if isinstance(oldValue, RecordTree):
                        raise LookupError("Can't add value '{}' at '{}': Path "
                            "leads to an inner node.".format(value, path))

                    value = oldValue + value

                    # If the sum is zero, delete the record instead.
                    if value == 0:
                        value = self.NONE

            if value is self.NONE:
                # Handle deletion of a subtree: Walk down the path, remembering
                # each (parent, label) pair from the root downwards.
                trail = []
                for label in path:
                    if not isinstance(node, self._NodeDict) \
                    or label not in node:
                        # The path does not exist (it may run into a leaf), so
                        # there is nothing to delete.
                        trail = None
                        break

                    trail.append((node, label))
                    node = node[label]

                if not trail:
                    continue

                # Remove the target and then every ancestor left without
                # children. The root is never part of a removal.
                for parent, label in reversed(trail):
                    parent.pop(label)
                    if parent:
                        break
            else:
                # Handle insertion of a leaf (may replace a subtree).
                for label in path[:-1]:
                    node.setdefault(label, self._NodeDict())
                    node = node[label]

                    # Checked at every step so that a leaf in the middle of the
                    # path is reported instead of being descended into.
                    if not isinstance(node, self._NodeDict):
                        raise LookupError("Can't set value '{}' at '{}': Path "
                            "already contains a leaf.".format(value, path))

                node[path[-1]] = value

        self._hash = hash(self.set)

    @classmethod
    def _traverse(cls, node):
        """Yield all records below a node as label tuples ending in a value."""
        if not isinstance(node, cls._NodeDict):
            # Not a node but a value.
            yield (node,)
            return
        elif not node:
            # Empty tree.
            return

        for label, child in node.items():
            for labels in cls._traverse(child):
                yield (label,) + labels

    @property
    def records(self):
        """Return an iterator over tuples, each representing one record."""
        return self._traverse(self._tree)

    @property
    def items(self):
        """Return an iterator over path/value pairs representing records."""
        return ((path[:-1], path[-1]) for path in self.records)

    @property
    def paths(self):
        """Return an iterator over paths, each representing one record."""
        return (path[:-1] for path in self.records)

    @cached_property
    def dict(self):
        """Return the tree as a read-only, tuple-indexed dictionary view.

        Every key/value pair of the returned dictionary is a record.
        """
        return MappingProxyType({path[:-1]: path[-1] for path in self.records})

    @cached_property
    def set(self):
        """Return a frozen set of tuples, each representing one record."""
        return frozenset(self.records)

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        # Equal hashes are necessary but not sufficient, so the record sets are
        # compared as well. Other types are left to the other operand.
        if not isinstance(other, RecordTree):
            return NotImplemented

        return self._hash == other._hash and self.set == other.set

    def __bool__(self):
        return bool(self._tree)

    def __len__(self):
        # Every record is a distinct tuple, so the record set built during
        # initialization already has one entry per record.
        return len(self.set)

    def __contains__(self, path):
        if not isinstance(path, tuple):
            raise TypeError("{} indices must be tuples.".format(
                self.__class__.__name__))

        node = self._tree
        for label in path:
            if not isinstance(node, self._NodeDict) or label not in node:
                return False
            node = node[label]

        return True

    def _get(self, path, errorOnBadPath):
        """Return the value or subtree stored at a path.

        :param path:
            A tuple or list of labels; any other object is treated as a single
            label.
        :param bool errorOnBadPath:
            Whether to raise a :exc:`LookupError` if the path does not exist,
            as opposed to returning an empty :class:`RecordTree`.
        """
        if not isinstance(path, tuple) and not isinstance(path, list):
            path = (path,)

        node = self._tree
        for label in path:
            if not isinstance(node, self._NodeDict) or label not in node:
                if errorOnBadPath:
                    raise LookupError(str(path))
                else:
                    return RecordTree()
            node = node[label]

        if isinstance(node, self._NodeDict):
            # The path leads to an inner node: Return its subtree.
            return RecordTree(
                {path[:-1]: path[-1] for path in self._traverse(node)})
        else:
            return node

    def __getitem__(self, path):
        return self._get(path, True)

    def get(self, path):
        """Return an empty :class:`RecordTree` if the path does not exist."""
        return self._get(path, False)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.dict)

    def __str__(self):
        return str(self.dict)

    def __le__(self, other):
        """Perform entrywise lower-or-equal-than comparison.

        Each left hand side path must be present on the right hand side, and the
        associated left hand side value must compare lower-or-equal-than the
        right hand side value.
        """
        if type(self) != type(other):
            return NotImplemented

        for path, value in self.items:
            if path not in other:
                return False

            if not value <= other[path]:
                return False

        return True

    def __ge__(self, other):
        """Perform entrywise greater-or-equal-than comparison.

        Each left hand side path must be present on the right hand side, and the
        associated left hand side value must compare greater-or-equal-than the
        right hand side value.
        """
        if type(self) != type(other):
            return NotImplemented

        for path, value in self.items:
            if path not in other:
                return False

            if not value >= other[path]:
                return False

        return True

    def __lshift__(self, other):
        """Perform subtree comparison.

        Each left hand side path must be present on the right hand side. If the
        special :class:`ALL` type is found as a value in the right hand side
        tree, it is treated as "all possible subtrees". All other values are not
        considered.
        """
        if not isinstance(other, RecordTree):
            return NotImplemented

        for path in self.paths:
            rhsNode = other._tree

            for label in path:
                if rhsNode is self.ALL:
                    break

                # A right hand side leaf other than ALL has no children, so the
                # remaining labels cannot be present below it.
                if not isinstance(rhsNode, self._NodeDict) \
                or label not in rhsNode:
                    return False

                rhsNode = rhsNode[label]

        return True

    def mismatch(self, other):
        """A subtree of ``self`` that renders ``self << other`` :obj:`False`.

        :returns RecordTree:
            The smallest subtree ``T`` of ``self`` such that ``self`` without
            the records in ``T`` is a subtree of ``other``. The returned tree is
            a direct instance of the :class:`RecordTree` base class.

        :raises TypeError:
            If the argument is not a :class:`RecordTree`.
        """
        if not isinstance(other, RecordTree):
            raise TypeError("The argument must be another record tree.")

        records = []

        for record in self.records:
            rhsNode = other._tree

            for label in record[:-1]:
                if rhsNode is self.ALL:
                    break

                # As for the subtree comparison, running into a right hand side
                # leaf other than ALL means that the path is not present.
                if not isinstance(rhsNode, self._NodeDict) \
                or label not in rhsNode:
                    records.append(record)
                    break

                rhsNode = rhsNode[label]

        return RecordTree(records)

    @staticmethod
    def _str(thing):
        """Return the ``__name__`` of an object, if any, else its string."""
        return thing.__name__ if hasattr(thing, "__name__") else str(thing)

    @property
    def text(self):
        """Return the full tree as a multiline string."""
        keys, vals = [], []
        for path in self._traverse(self._tree):
            keys.append("/".join(self._str(label) for label in path[:-1]))
            vals.append(self._str(path[-1]))
        if not keys:
            return "Empty {} instance.".format(self.__class__.__name__)
        # Pad keys and values to a common width so that the lines align.
        keyLen = max(len(key) for key in keys)
        valLen = max(len(val) for val in vals)
        return "\n".join(sorted(
            "{{:{}}} = {{:{}}}".format(keyLen, valLen).format(key, val)
            for key, val in zip(keys, vals)))

    def copy(self):
        """Create a shallow copy; the tree is copied, the values are not."""
        return self.__class__(self.records)

    def updated(self, recordsOrDict, addValues=False):
        """Create a shallow copy with modified records.

        The new records are processed after the existing ones, see the
        constructor for the meaning of ``addValues``.

        :Example:

        >>> from picos.modeling.footprint import RecordTree as T
        >>> t = T({(1, 1, 1): 3, (1, 1, 2): 4, (1, 2, 1): 5}); t
        RecordTree({(1, 1, 1): 3, (1, 1, 2): 4, (1, 2, 1): 5})
        >>> t.updated({(1, 1, 1): "a", (2, 2): "b"}) # Change or add a record.
        RecordTree({(1, 1, 1): 'a', (1, 1, 2): 4, (1, 2, 1): 5, (2, 2): 'b'})
        >>> t.updated({(1,1,1): T.NONE}) # Delete a single record.
        RecordTree({(1, 1, 2): 4, (1, 2, 1): 5})
        >>> t.updated({(1,1): T.NONE}) # Delete multiple records.
        RecordTree({(1, 2, 1): 5})
        >>> t.updated([(1, 1, 1, T.NONE), (1, 1, 1, 1, 6)]) # Delete, then add.
        RecordTree({(1, 1, 2): 4, (1, 1, 1, 1): 6, (1, 2, 1): 5})
        >>> try: # Not possible to implicitly turn a leaf into an inner node.
        ...     t.updated([(1, 1, 1, 1, 6)])
        ... except LookupError as error:
        ...     print(error)
        Can't set value '6' at '(1, 1, 1, 1)': Path already contains a leaf.
        """
        return self.__class__(
            chain(self.records, self._record_iterator(recordsOrDict)),
            addValues)
689# --------------------------------------
# NOTE(review): presumably derives the list of public names from the module
# namespace and the _API_START marker created near the top of the file --
# confirm against picos.apidoc.
690__all__ = api_end(_API_START, globals())