from functools import reduce
from inspect import isfunction
from math import ceil, floor, log2
from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

import torch
from typing_extensions import TypeGuard

T = TypeVar("T")


def exists(val: Optional[T]) -> TypeGuard[T]:
    """Return True if ``val`` is not None (narrows ``Optional[T]`` to ``T``)."""
    return val is not None


def iff(condition: bool, value: T) -> Optional[T]:
    """Return ``value`` if ``condition`` is truthy, otherwise None."""
    return value if condition else None


def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
    """Return True if ``obj`` is a list or a tuple (subclasses included)."""
    return isinstance(obj, (list, tuple))


def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return ``val`` unless it is None, in which case fall back to ``d``.

    If ``d`` is a plain Python function (``inspect.isfunction``: ``def``
    functions and lambdas, but NOT classes, builtins or ``functools.partial``
    objects) it is called with no arguments to build the fallback lazily;
    otherwise ``d`` itself is returned.
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d


def to_list(val: Union[T, Sequence[T]]) -> List[T]:
    """Return ``val`` as a list.

    Tuples are converted to a new list, lists are returned as-is (the same
    object, not a copy), and any other value is wrapped in a one-element list.
    """
    if isinstance(val, tuple):
        return list(val)
    if isinstance(val, list):
        return val
    return [val]  # type: ignore


def prod(vals: Sequence[int]) -> int:
    """Return the product of all elements of ``vals``.

    The empty product is 1.
    """
    # The initializer 1 makes an empty sequence return 1 instead of making
    # ``reduce`` raise TypeError; for non-empty input the result is unchanged.
    return reduce(lambda x, y: x * y, vals, 1)


def closest_power_2(x: float) -> int:
    """Return the power of two nearest to ``x``.

    Only ``2**floor(log2(x))`` and ``2**ceil(log2(x))`` are compared; on a tie
    the smaller one wins (``min`` keeps the first minimum). ``x`` must be
    positive because ``math.log2`` raises ValueError otherwise. For ``x < 1``
    the exponent is negative, so the result is a fractional power of two
    (a float).
    """
    exponent = log2(x)

    def distance(exp: int) -> float:
        # Absolute gap between x and the candidate power of two.
        return abs(x - 2**exp)

    exponent_closest = min((floor(exponent), ceil(exponent)), key=distance)
    return 2 ** int(exponent_closest)


def rand_bool(shape, proba, device=None):
    """Return a bool tensor of ``shape`` whose elements are True with probability ``proba``.

    Each element is drawn with ``torch.bernoulli``. A ``proba`` of exactly 1 or
    0 short-circuits to ``torch.ones`` / ``torch.zeros``, so no random numbers
    are drawn in those cases.
    """
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)


# Kwargs Utils


def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Split ``d`` into ``(entries whose key starts with prefix, all other entries)``.

    Keys are kept unchanged in both dicts, and the iteration order of ``d`` is
    preserved. Keys must support ``str.startswith`` (i.e. be strings).
    """
    with_prefix: Dict = {}
    without_prefix: Dict = {}
    for key, value in d.items():
        target = with_prefix if key.startswith(prefix) else without_prefix
        target[key] = value
    return with_prefix, without_prefix


def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Split ``d`` into ``(prefixed entries, remaining entries)``.

    Unless ``keep_prefix`` is True, ``prefix`` is stripped from the keys of
    the first dict (e.g. ``{"enc_dim": 4}`` -> ``{"dim": 4}`` for ``"enc_"``).
    """
    kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
    if keep_prefix:
        return kwargs_with_prefix, kwargs
    kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
    return kwargs_no_prefix, kwargs


def prefix_dict(prefix: str, d: Dict) -> Dict:
    """Return a copy of ``d`` with ``prefix`` prepended to every key (keys are ``str()``-ed)."""
    return {prefix + str(k): v for k, v in d.items()}