File size: 3,494 Bytes
5086590
0b2b0ab
 
 
 
 
 
 
d51b792
0b2b0ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d51b792
0b2b0ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d51b792
 
510e898
0b2b0ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d51b792
 
0b2b0ab
 
510e898
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from models.structure.Advanced_Conditional_Unet import Unet
from diffusers import DDPMScheduler
import torch
import os
import glob
from torchvision import transforms
import pathlib
from safetensors.torch import load_model, save_model
import time as tm


# Hyper-parameters used to rebuild the network and the noise scheduler.
# NOTE(review): these presumably must match the values the checkpoint was
# trained with — confirm against the training script.
denoising_timesteps = 4000  # num_train_timesteps for the DDPM scheduler
image_size = 128  # passed to the U-Net as its base width (`dim`)
channels = 3  # number of image channels the U-Net consumes/produces


# Pick the compute device: CUDA if available, else CPU; MPS (Apple silicon)
# takes precedence over both when it is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device

model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4, 8),
    use_convnext=False,
).to(device)

results_folder = pathlib.Path("models")


# Candidate checkpoints: safetensors files (".st") and torch pickles (".pt").
checkpoint_files_st = glob.glob(str(results_folder / "model-epoch_*.st"))
checkpoint_files_pt = glob.glob(str(results_folder / "model-epoch_*.pt"))

if checkpoint_files_st:
    # Prefer safetensors: it loads raw tensors without unpickling anything.
    # Use the most recently modified file.
    checkpoint_files = max(checkpoint_files_st, key=os.path.getmtime)
    load_model(model, checkpoint_files)
    model.eval()
    print("Loaded model from checkpoint", checkpoint_files)

elif checkpoint_files_pt:
    # Fall back to the most recently modified torch checkpoint.
    checkpoint_files = max(checkpoint_files_pt, key=os.path.getmtime)
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code; only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_files, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    epoch = checkpoint["epoch"]
    model.eval()
    print("Loaded model from checkpoint", checkpoint_files)

    # Convert to safetensors once, so later runs take the branch above.
    # Check the concrete target file; a path containing a "*" glob pattern
    # is never expanded by Path.exists().
    safetensor_path = results_folder / "model-epoch_{}.st".format(epoch)
    if not safetensor_path.exists():
        save_model(model, str(safetensor_path))
        print("Saved model as a safetensor", safetensor_path)

else:
    # FileNotFoundError subclasses Exception, so existing handlers still match.
    raise FileNotFoundError("No model files found in the folder.")


def sample(sketch, scribbles, sampling_steps, seed_nr, progress):
    """Generate an image by reverse diffusion conditioned on sketch + scribbles.

    Starts from Gaussian noise shaped like the (batched) sketch and denoises it
    with the module-level ``model`` using a DDPM scheduler with the
    ``squaredcos_cap_v2`` beta schedule.

    Args:
        sketch: Conditioning tensor passed to the model as
            ``explicit_conditioning``. Expected unbatched; a batch dimension
            of 1 is added here (presumably (C, H, W) — TODO confirm).
        scribbles: Conditioning tensor passed as ``implicit_conditioning``;
            also expected unbatched, a batch dimension of 1 is added here.
        sampling_steps: Number of inference steps given to
            ``DDPMScheduler.set_timesteps``.
        seed_nr: Seed for ``torch.manual_seed`` so runs are reproducible.
        progress: Object exposing ``tqdm(iterable, desc=...)`` for progress
            reporting (presumably ``gradio.Progress`` — verify against caller).

    Returns:
        The generated image as a PIL image; it is also written to
        ``results/sample.png``.
    """
    torch.manual_seed(seed_nr)

    noise_scheduler = DDPMScheduler(
        num_train_timesteps=denoising_timesteps, beta_schedule="squaredcos_cap_v2"
    )
    noise_scheduler.set_timesteps(sampling_steps, device=device)

    # Move the conditioning inputs to the device and add a batch dimension of 1.
    sketch = sketch.to(device).unsqueeze(0)
    scribbles = scribbles.to(device).unsqueeze(0)
    batch_size = sketch.shape[0]

    with torch.no_grad():
        latents = torch.randn_like(sketch, device=device)

        for t in progress.tqdm(noise_scheduler.timesteps, desc="Painting 🖌🖌🖌"):
            # scale_model_input may rescale the sample for some schedulers; the
            # scaled tensor is only the model input, while step() must receive
            # the unscaled running sample. (For DDPM the scaling is identity.)
            model_input = noise_scheduler.scale_model_input(latents, t)
            timesteps = t.expand(batch_size).to(device)

            noise_pred = model(
                x=model_input,
                time=timesteps,
                implicit_conditioning=scribbles,
                explicit_conditioning=sketch,
            )

            latents = noise_scheduler.step(
                noise_pred,
                t.long(),
                latents,
            ).prev_sample

            # Short pause per step; presumably gives the progress UI time to
            # refresh — NOTE(review): confirm it is still needed.
            tm.sleep(0.01)

    # Map from the model's output range (presumably [-1, 1]) to [0, 1].
    image_tensor = torch.clamp((latents / 2) + 0.5, 0, 1)

    image = transforms.ToPILImage()(image_tensor[0].cpu())

    # Ensure the output directory exists; save() does not create it.
    os.makedirs("results", exist_ok=True)
    image.save("results/sample.png")

    return image