Spaces:
Running
on
Zero
Running
on
Zero
from collections import OrderedDict | |
from typing import Any, Callable, Dict, List, Optional | |
from typing import OrderedDict, Generic, TypeVar | |
K = TypeVar('K') | |
V = TypeVar('V') | |
class AttrDict(OrderedDict[K, V], Generic[K, V]): | |
""" | |
An attribute dictionary that automatically handles nested keys joined by "/". | |
Originally copied from: https://stackoverflow.com/questions/3031219/recursively-access-dict-via-attributes-as-well-as-index-access | |
""" | |
MARKER = object() | |
# pylint: disable=super-init-not-called | |
def __init__(self, *args, **kwargs): | |
if len(args) == 0: | |
for key, value in kwargs.items(): | |
self.__setitem__(key, value) | |
else: | |
assert len(args) == 1 | |
assert isinstance(args[0], (dict, AttrDict)) | |
for key, value in args[0].items(): | |
self.__setitem__(key, value) | |
def __contains__(self, key): | |
if "/" in key: | |
keys = key.split("/") | |
key, next_key = keys[0], "/".join(keys[1:]) | |
return key in self and next_key in self[key] | |
return super(AttrDict, self).__contains__(key) | |
def __setitem__(self, key, value): | |
if "/" in key: | |
keys = key.split("/") | |
key, next_key = keys[0], "/".join(keys[1:]) | |
if key not in self: | |
self[key] = AttrDict() | |
self[key].__setitem__(next_key, value) | |
return | |
if isinstance(value, dict) and not isinstance(value, AttrDict): | |
value = AttrDict(**value) | |
if isinstance(value, list): | |
value = [AttrDict(val) if isinstance(val, dict) else val for val in value] | |
super(AttrDict, self).__setitem__(key, value) | |
def __getitem__(self, key): | |
if "/" in key: | |
keys = key.split("/") | |
key, next_key = keys[0], "/".join(keys[1:]) | |
val = self[key] | |
if not isinstance(val, AttrDict): | |
raise ValueError | |
return val.__getitem__(next_key) | |
return self.get(key, None) | |
def all_keys( | |
self, | |
leaves_only: bool = False, | |
parent: Optional[str] = None, | |
) -> List[str]: | |
keys = [] | |
for key in self.keys(): | |
cur = key if parent is None else f"{parent}/{key}" | |
if not leaves_only or not isinstance(self[key], dict): | |
keys.append(cur) | |
if isinstance(self[key], dict): | |
keys.extend(self[key].all_keys(leaves_only=leaves_only, parent=cur)) | |
return keys | |
def dumpable(self, strip=True): | |
""" | |
Casts into OrderedDict and removes internal attributes | |
""" | |
def _dump(val): | |
if isinstance(val, AttrDict): | |
return val.dumpable() | |
elif isinstance(val, list): | |
return [_dump(v) for v in val] | |
return val | |
if strip: | |
return {k: _dump(v) for k, v in self.items() if not k.startswith("_")} | |
return {k: _dump(v if not k.startswith("_") else repr(v)) for k, v in self.items()} | |
def map( | |
self, | |
map_fn: Callable[[Any, Any], Any], | |
should_map: Optional[Callable[[Any, Any], bool]] = None, | |
) -> "AttrDict": | |
""" | |
Creates a copy of self where some or all values are transformed by | |
map_fn. | |
:param should_map: If provided, only those values that evaluate to true | |
are converted; otherwise, all values are mapped. | |
""" | |
def _apply(key, val): | |
if isinstance(val, AttrDict): | |
return val.map(map_fn, should_map) | |
elif should_map is None or should_map(key, val): | |
return map_fn(key, val) | |
return val | |
return AttrDict({k: _apply(k, v) for k, v in self.items()}) | |
def __eq__(self, other): | |
return self.keys() == other.keys() and all(self[k] == other[k] for k in self.keys()) | |
def combine( | |
self, | |
other: Dict[str, Any], | |
combine_fn: Callable[[Optional[Any], Optional[Any]], Any], | |
) -> "AttrDict": | |
""" | |
Some values may be missing, but the dictionary structures must be the | |
same. | |
:param combine_fn: a (possibly non-commutative) function to combine the | |
values | |
""" | |
def _apply(val, other_val): | |
if val is not None and isinstance(val, AttrDict): | |
assert isinstance(other_val, AttrDict) | |
return val.combine(other_val, combine_fn) | |
return combine_fn(val, other_val) | |
# TODO nit: this changes the ordering.. | |
keys = self.keys() | other.keys() | |
return AttrDict({k: _apply(self[k], other[k]) for k in keys}) | |
__setattr__, __getattr__ = __setitem__, __getitem__ | |