|
|
|
|
|
|
|
import torch |
|
import random |
|
import numpy as np |
|
import os |
|
|
|
|
|
if torch.cuda.is_available(): |
|
DEVICE = torch.device('cuda') |
|
MP_CONTEXT = None |
|
PIN_MEM = True |
|
elif torch.backends.mps.is_available(): |
|
DEVICE = torch.device('mps') |
|
MP_CONTEXT = 'forkserver' |
|
PIN_MEM = False |
|
else: |
|
DEVICE = torch.device('cpu') |
|
MP_CONTEXT = None |
|
PIN_MEM = False |
|
|
|
|
|
|
|
|
|
|
|
def set_seed(seed: int = 0): |
|
''' |
|
Sets random seed and deterministic settings for reproducibility across: |
|
- PyTorch |
|
- NumPy |
|
- Python's random module |
|
|
|
Args: |
|
seed (int): The seed value to set. |
|
''' |
|
torch.manual_seed(seed) |
|
np.random.seed(seed) |
|
random.seed(seed) |
|
torch.cuda.manual_seed_all(seed) |
|
|
|
torch.use_deterministic_algorithms(True) |
|
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' |
|
|
|
def save_model(model: torch.nn.Module, |
|
save_dir: str, |
|
mod_name: str): |
|
''' |
|
Saves the `state_dict()` of a model to the directory 'save_dir.' |
|
|
|
Args: |
|
model (torch.nn.Module): The PyTorch model whose state dict and keyword arguments will be saved. |
|
save_dir (str): Directory to save the model to. |
|
mod_name (str): Filename for the saved model. If this doesn't end with '.pth' or '.pt,' it will be added on for the state_dict. |
|
|
|
''' |
|
|
|
os.makedirs(save_dir, exist_ok = True) |
|
|
|
|
|
if not mod_name.endswith('.pth') and not mod_name.endswith('.pt'): |
|
mod_name += '.pth' |
|
|
|
|
|
save_path = os.path.join(save_dir, mod_name) |
|
|
|
|
|
torch.save(obj = model.state_dict(), f = save_path) |