# File size: 776 Bytes
# 8aed5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from dataclasses import dataclass
import os
from dotenv import load_dotenv

# Populate os.environ from a .env file (if one is found) at import time, so
# the config classes below see those values when their class bodies read
# os.environ.
load_dotenv()


@dataclass
class PetSegTrainConfig:
    """Training defaults, exposed as plain class attributes.

    None of the names below carries a type annotation, so ``@dataclass``
    registers no fields: the generated ``__init__`` takes no arguments and
    every value is read straight off the class (or any instance of it).
    ``CHANNELS_LIST`` is a single list object shared class-wide; copy it
    before mutating.

    NOTE(review): the grouping below is inferred from the setting names; the
    code that consumes these values is not part of this module.
    """

    # --- run control ---
    FAST_DEV_RUN = False
    EPOCHS = 5
    DESCRIPTION_TEXT = None  # unset by default

    # --- data ---
    TRAIN_VAL_TEST_DATA_PATH = "./data/train_val_test"
    TOTAL_SAMPLES = 100
    BATCH_SIZE = 8

    # --- optimisation ---
    LEARNING_RATE = 0.001

    # --- model ---
    CHANNELS_LIST = [16, 32, 64, 128, 256]
    DEPTHWISE_SEP = False


@dataclass
class PetSegWebappConfig:
    """Web-app settings resolved from environment variables.

    The attributes are unannotated, so ``@dataclass`` registers no fields:
    they are plain class attributes evaluated once, when this class body
    runs. Changing the environment afterwards does not update them.
    """

    # Value of the MODEL_WEIGHTS_GDRIVE_FILE_ID variable; ``None`` when unset.
    MODEL_WEIGHTS_GDRIVE_FILE_ID = os.environ.get("MODEL_WEIGHTS_GDRIVE_FILE_ID")

    # Checkpoint path; falls back to the default checkpoint file name when the
    # variable is unset.
    MODEL_WEIGHTS_LOCAL_PATH = os.environ.get(
        "MODEL_WEIGHTS_LOCAL_PATH", "pet-segmentation-pytorch_epoch=4-step=1840.ckpt"
    )

    # Boolean flag, enabled when the variable is unset. Parsed
    # case-insensitively and ignoring surrounding whitespace so that common
    # spellings ("true", "TRUE", "1", "yes", "on") are not silently treated
    # as False; any other value disables the flag. The "WEIGTHS" spelling is
    # kept on purpose: it is the existing attribute and variable name.
    DOWNLOAD_MODEL_WEIGTHS_FROM_GDRIVE = os.environ.get(
        "DOWNLOAD_MODEL_WEIGTHS_FROM_GDRIVE", "True"
    ).strip().lower() in {"1", "true", "yes", "on"}