File size: 2,227 Bytes
635f007 1373f78 635f007 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
from functools import reduce
from inspect import isfunction
from math import ceil, floor, log2
from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
import torch
from typing_extensions import TypeGuard
# Generic type variable shared by the helper signatures in this module.
T = TypeVar("T")
def exists(val: Optional[T]) -> TypeGuard[T]:
return val is not None
def iff(condition: bool, value: T) -> Optional[T]:
    """Return ``value`` when ``condition`` is truthy, otherwise ``None``."""
    if not condition:
        return None
    return value
def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
return isinstance(obj, list) or isinstance(obj, tuple)
def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    ``d`` is invoked with no arguments only when ``inspect.isfunction`` accepts
    it, i.e. for plain ``def`` functions and lambdas. Other callables (classes,
    builtins, bound methods, ``functools.partial`` objects) are returned as-is
    rather than called.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
def to_list(val: Union[T, Sequence[T]]) -> List[T]:
    """Coerce ``val`` into a list.

    A list is returned unchanged (the same object, not a copy), a tuple is
    converted element-wise, and anything else (strings, sets, scalars, None)
    is wrapped in a one-element list.
    """
    if isinstance(val, list):
        return val
    if isinstance(val, tuple):
        return list(val)
    return [val]  # type: ignore
def prod(vals: Sequence[int]) -> int:
    """Return the product of every item in ``vals``.

    An empty input yields the multiplicative identity ``1`` instead of raising
    ``TypeError`` (which is what ``reduce`` does without an initial value).
    Any iterable is accepted, not only sequences.
    """
    # Seeding the fold with 1 makes empty input well defined; for non-empty
    # input ``1 * v0 * v1 * ...`` equals the unseeded product.
    return reduce(lambda acc, item: acc * item, vals, 1)
def closest_power_2(x: float) -> int:
    """Return the power of two nearest to ``x``.

    Only the two candidates bracketing ``log2(x)`` are compared; on an exact
    tie the lower power wins. ``x`` must be positive (``log2`` raises
    ``ValueError`` otherwise), and for ``0 < x < 1`` the result is a
    fractional power such as ``0.25``.
    """
    exponent = log2(x)
    lower = floor(exponent)
    upper = ceil(exponent)
    # Pick the upper power only when it is strictly closer, so ties resolve
    # to the lower candidate.
    if abs(x - 2**upper) < abs(x - 2**lower):
        chosen = upper
    else:
        chosen = lower
    return 2 ** int(chosen)
def rand_bool(shape, proba, device=None):
    """Return a boolean tensor of ``shape`` whose entries are True with probability ``proba``.

    ``proba`` equal to 1 or 0 short-circuits to an all-True / all-False tensor
    without sampling; any other value draws each element independently with
    ``torch.bernoulli``. The result is allocated on ``device``.
    """
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    probabilities = torch.full(shape, proba, device=device)
    samples = torch.bernoulli(probabilities)
    return samples.to(torch.bool)
"""
Kwargs Utils
"""
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Partition ``d`` by whether each key starts with ``prefix``.

    Returns ``(matching, rest)``: entries whose key starts with ``prefix`` and
    all remaining entries. Keys are kept verbatim and insertion order is
    preserved within each dict. Keys must support ``str.startswith``.
    """
    matching: Dict = {}
    rest: Dict = {}
    for key, value in d.items():
        bucket = matching if key.startswith(prefix) else rest
        bucket[key] = value
    return matching, rest
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Split ``d`` into ``(prefixed, others)`` using ``group_dict_by_prefix``.

    Unless ``keep_prefix`` is true, ``prefix`` is sliced off the keys of the
    first dict. Keys of the second dict are always left as-is.
    """
    matched, others = group_dict_by_prefix(prefix, d)
    if not keep_prefix:
        cut = len(prefix)
        matched = {key[cut:]: value for key, value in matched.items()}
    return matched, others
def prefix_dict(prefix: str, d: Dict) -> Dict:
    """Return a new dict whose keys are ``prefix + str(key)`` for each key of ``d``.

    Values are carried over unchanged. If two keys stringify identically,
    the later one wins.
    """
    renamed = {
        prefix + str(key): value
        for key, value in d.items()
    }
    return renamed
|