File size: 2,837 Bytes
7baf792
 
dac160c
7baf792
f33e554
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cd7994
605b8ba
7baf792
dac160c
 
605b8ba
c25f048
7baf792
 
605b8ba
 
 
 
 
7baf792
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from aura_sr import AuraSR
import gradio as gr
import spaces


class ZeroGPUAuraSR(AuraSR):
    """AuraSR subclass that overrides ``from_pretrained``.

    The override builds the model with ``cls(config)``, i.e. the base
    constructor with only the config argument.
    NOTE(review): presumably overridden for compatibility with ZeroGPU
    Spaces (see the class name) — confirm against upstream
    ``AuraSR.from_pretrained``.
    """

    @classmethod
    def from_pretrained(cls, model_id: str = "fal-ai/AuraSR", use_safetensors: bool = True):
        """Build a model and load its upsampler weights.

        Args:
            model_id: Either a Hugging Face Hub repo id, which is fetched with
                ``snapshot_download``, or a path to a local ``.safetensors`` /
                ``.ckpt`` file. A local file must have a ``config.json`` in the
                same directory.
            use_safetensors: For Hub repos, load ``model.safetensors`` when
                True and ``model.ckpt`` otherwise. For local files the file
                suffix decides and this argument is ignored.

        Returns:
            A new ``cls`` instance whose ``upsampler`` state dict has been
            loaded with ``strict=True``.

        Raises:
            ValueError: The local file suffix is neither ``.safetensors`` nor
                ``.ckpt``.
            FileNotFoundError: The local file has no sibling ``config.json``.
            ImportError: safetensors loading was requested but the library is
                not installed.
        """
        import json
        import torch
        from pathlib import Path
        from huggingface_hub import snapshot_download

        local_file = Path(model_id)
        # Stat the path exactly once so the branch taken and the weights path
        # chosen below can never disagree.
        if local_file.is_file():
            if local_file.suffix == '.safetensors':
                use_safetensors = True
            elif local_file.suffix == '.ckpt':
                use_safetensors = False
            else:
                raise ValueError(f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files.")

            # For local files, we need to provide the config separately
            config_path = local_file.with_name('config.json')
            if not config_path.exists():
                raise FileNotFoundError(
                    f"Config file not found: {config_path}. "
                    f"When loading from a local file, ensure that 'config.json' "
                    f"is present in the same directory as '{local_file.name}'. "
                    f"If you're trying to load a model from Hugging Face, "
                    f"please provide the model ID instead of a file path."
                )
            weights_path = local_file
        else:
            hf_model_path = Path(snapshot_download(model_id))
            config_path = hf_model_path / "config.json"
            weights_path = hf_model_path / ("model.safetensors" if use_safetensors else "model.ckpt")

        # JSON is UTF-8 by specification; don't depend on the platform locale.
        config = json.loads(config_path.read_text(encoding="utf-8"))
        model = cls(config)

        if use_safetensors:
            # Keep the try narrow: only the optional import may raise here.
            try:
                from safetensors.torch import load_file
            except ImportError as err:
                raise ImportError(
                    "The safetensors library is not installed. "
                    "Please install it with `pip install safetensors` "
                    "or use `use_safetensors=False` to load the model with PyTorch."
                ) from err
            checkpoint = load_file(weights_path)
        else:
            # Deserialize onto the CPU, as safetensors' load_file does by
            # default; load_state_dict copies the tensors into the model's
            # parameters wherever they live.
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # trusted .ckpt files.
            checkpoint = torch.load(weights_path, map_location="cpu")

        model.upsampler.load_state_dict(checkpoint, strict=True)
        return model



# Load both model versions once at import time, v2 first and then v1, so
# predict() only has to run the upscaling step.
aura_sr, aura_sr_v1 = (
    ZeroGPUAuraSR.from_pretrained(repo_id)
    for repo_id in ("fal/AuraSR-v2", "fal-ai/AuraSR")
)


@spaces.GPU()
def predict(img, model_selection):
    """Upscale ``img`` with the AuraSR model chosen in the dropdown.

    Args:
        img: Image from the ``gr.Image`` input, or None when nothing was
            uploaded. NOTE(review): this is a numpy array under Gradio's
            default ``gr.Image`` type — confirm it matches what ``upscale_4x``
            expects.
        model_selection: ``'v1'`` or ``'v2'`` from the dropdown.

    Returns:
        The result of the selected model's ``upscale_4x``.

    Raises:
        gr.Error: No image was provided, or the model selection is unknown.
            Gradio shows this message to the user instead of a traceback.
    """
    models = {'v1': aura_sr_v1, 'v2': aura_sr}
    if img is None:
        raise gr.Error("Please upload an image to upscale.")
    model = models.get(model_selection)
    if model is None:
        # Previously this fell through to None.upscale_4x -> AttributeError.
        raise gr.Error(f"Unknown model selection: {model_selection!r}. Choose 'v1' or 'v2'.")
    return model.upscale_4x(img)


# Web UI: an image plus a model-version dropdown ('v2' preselected) are
# passed to predict(), and its result is displayed as an image.
demo = gr.Interface(
    predict,
    inputs=[gr.Image(), gr.Dropdown(value='v2', choices=['v1', 'v2'])],
    outputs=gr.Image()
)


# Start the web server only when this file is executed as a script, not as a
# side effect of importing the module.
if __name__ == "__main__":
    demo.launch()