sayakpaul HF Staff commited on
Commit
d77c5aa
·
verified ·
1 Parent(s): 901d154

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import upload_file, create_repo
2
+ import gradio as gr
3
+ import os
4
+ import requests
5
+ import tempfile
6
+ import requests
7
+ import re
8
+
9
+
10
+ article = """
11
+ Some things to note:
12
+
13
+ * To obtain the download link of the state dict hosted on CivitAI, right click on the "Download" button, visible on the model page.
14
+ * If the creator of the state dict requires the users to login to CivitAI first, that means downloading the state dict will require the CivitAI API keys.
15
+ * To obtain the API key of your CivitAI account, head to https://civitai.com/user/account and scroll to "API Keys".
16
+ * For such state dicts, it is mandatory to pass `civitai_api_key`.
17
+ * If you are getting "429 Client Error: Too Many Requests for url" error, retry passing your CivitAI API key.
18
+
19
+ """
20
+
21
+ def download_locally_and_upload_to_hub(civit_url, repo_id, hf_token=None, civitai_api_key=None):
22
+ if not civitai_api_key:
23
+ civitai_api_key = None
24
+ if civitai_api_key:
25
+ headers = {
26
+ "Authorization": f"Bearer {civitai_api_key}",
27
+ "Accept": "application/json"
28
+ }
29
+ else:
30
+ headers = None
31
+
32
+ response = requests.get(civit_url, headers=headers, stream=True)
33
+ response.raise_for_status()
34
+
35
+ cd = response.headers.get("Content-Disposition")
36
+ if cd:
37
+ # This regular expression will try to find a filename attribute in the header
38
+ fname = re.findall('filename="?([^"]+)"?', cd)
39
+ if fname:
40
+ filename = fname[0]
41
+ else:
42
+ filename = civit_url.split("/")[-1]
43
+ else:
44
+ filename = civit_url.split("/")[-1]
45
+
46
+ with tempfile.TemporaryDirectory() as local_path:
47
+ local_path = os.path.join(local_path, filename)
48
+ with open(local_path, "wb") as file:
49
+ for chunk in response.iter_content(chunk_size=8192):
50
+ if chunk: # filter out keep-alive new chunks
51
+ file.write(chunk)
52
+
53
+ if repo_id:
54
+ repo_successfully_created = False
55
+ if not hf_token:
56
+ hf_token = None
57
+ try:
58
+ repo_id = create_repo(repo_id=repo_id, exist_ok=True, token=hf_token).repo_id
59
+ repo_successfully_created = True
60
+ except Exception as e:
61
+ error_message_on_repo_creation = e
62
+
63
+ file_successfully_committed = False
64
+ try:
65
+ if repo_successfully_created:
66
+ commit_info = upload_file(repo_id=repo_id, path_or_fileobj=local_path, path_in_repo=filename, token=hf_token)
67
+ file_successfully_committed = True
68
+ except Exception as e:
69
+ error_message_on_file_commit = e
70
+
71
+
72
+ if repo_successfully_created and file_successfully_committed:
73
+ return f"Pushed the checkpoint here: [{commit_info._url}]({commit_info._url})"
74
+ elif not repo_successfully_created:
75
+ return f"Error happened during repo creation: {error_message_on_repo_creation}"
76
+ elif not file_successfully_committed:
77
+ return f"Error happened during committing the file: {error_message_on_file_commit}"
78
+
79
+
80
+ def get_gradio_demo():
81
+ demo = gr.Interface(
82
+ title="Upload CivitAI checkpoints to the HF Hub 🤗",
83
+ article=article,
84
+ description="**See instructions below the form.**",
85
+ fn=download_locally_and_upload_to_hub,
86
+ inputs=[
87
+ gr.Textbox(lines=1, info="Download URL of the CivitAI checkpoint."),
88
+ gr.Textbox(lines=1, info="Repo ID for the checkpoint to upload on the Hub."),
89
+ gr.TextArea(lines=1, info="Your HF token. Generate one from https://huggingface.co/settings/tokens."),
90
+ gr.TextArea(lines=1, info="Civit API key to download the checkpoint if needed. Should be left otherwise.")
91
+ ],
92
+ outputs="markdown",
93
+ examples=[
94
+ ['https://civitai.com/api/download/models/1432115?type=Model&format=SafeTensor', 'sayakpaul/civitai-test', '', ''],
95
+ ],
96
+ allow_flagging="never"
97
+ )
98
+ return demo
99
+
100
+ if __name__ == "__main__":
101
+ demo = get_gradio_demo()
102
+ demo.launch(show_error=True)