from dataclasses import dataclass import os from dotenv import load_dotenv load_dotenv() @dataclass class PetSegTrainConfig: EPOCHS = 5 BATCH_SIZE = 8 FAST_DEV_RUN = False TOTAL_SAMPLES = 100 LEARNING_RATE = 1e-3 TRAIN_VAL_TEST_DATA_PATH = "./data/train_val_test" DEPTHWISE_SEP = False CHANNELS_LIST = [16, 32, 64, 128, 256] DESCRIPTION_TEXT = None @dataclass class PetSegWebappConfig: MODEL_WEIGHTS_GDRIVE_FILE_ID = os.environ.get("MODEL_WEIGHTS_GDRIVE_FILE_ID") MODEL_WEIGHTS_LOCAL_PATH = os.environ.get( "MODEL_WEIGHTS_LOCAL_PATH", "pet-segmentation-pytorch_epoch=4-step=1840.ckpt" ) DOWNLOAD_MODEL_WEIGTHS_FROM_GDRIVE = ( os.environ.get("DOWNLOAD_MODEL_WEIGTHS_FROM_GDRIVE", "True") == "True" )