# Scraped from a Hugging Face Spaces page (page status: "Build error"; file size: 2,402 bytes).
# Revisions shown in the page's blame column: b41a54a, 35ed471.
# system
import functools
import html
import os
import subprocess
from pathlib import Path
# Fetch the model checkpoints from Google Drive if *either* one is missing
# (not only when both are), so a run that obtained just one file can still
# recover the other.
# NOTE(review): gdown saves under the Drive file name; presumably these two ids
# map to ./Text2Punk-final-7.pt and ./clip-final.pt — confirm.
if not (Path('./Text2Punk-final-7.pt').exists() and Path('./clip-final.pt').exists()):
    for _checkpoint_url in (
        "https://drive.google.com/uc?id=1--27E5dk8GzgvpVL0ofr-m631iymBpUH",
        "https://drive.google.com/uc?id=191a5lTsUPQ1hXaeo6kVNbo_W3WYuXsmF",
    ):
        # argv list (no shell) so the `?` in the URL is never glob-expanded;
        # check=True makes a failed download stop the app at startup instead
        # of being silently ignored.
        subprocess.run(["gdown", _checkpoint_url], check=True)
# plot
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
# gradio
import gradio as gr
# text2punks utils
from text2punks.utils import resize, to_pil_image, model_loader, generate_image
# Sampling knobs to tune; both are forwarded to generate_image by run_inference.
# NOTE(review): top_k is a float, so presumably a fraction/threshold rather
# than an integer count — confirm in text2punks.utils.generate_image.
top_k, temperature = 0.8, 1.25
# helper functions
def compose_predictions(images, scale=4):
    """Tile a batch of generated images side by side into one RGB strip.

    Args:
        images: batch indexed as ``images[i]`` whose ``.shape`` unpacks to
            ``(b, c, h, w)`` (presumably a torch tensor from
            ``generate_image`` — assumption, confirm).
        scale: integer upscaling factor applied to every image. Defaults to 4,
            matching the canvas size this function has always allocated.

    Returns:
        A ``PIL.Image`` in RGB mode of size ``(b * w * scale, h * scale)``,
        with image ``i`` occupying the cell at x-offset ``i * w * scale``.
    """
    b, _, h, w = images.shape
    cell_w, cell_h = w * scale, h * scale
    image_grid = Image.new("RGB", (b * cell_w, cell_h), color=0)
    for i in range(b):
        # Each cell is (w*scale, h*scale), so the image must be upscaled to fill
        # it; nearest-neighbour keeps the pixel-art edges crisp instead of
        # blurring them.
        tile = to_pil_image(images[i]).resize((cell_w, cell_h), Image.NEAREST)
        image_grid.paste(tile, (i * cell_w, 0))
    return image_grid
@functools.lru_cache(maxsize=1)
def _load_models(t2p_path, clip_path):
    """Load the Text2Punk and CLIP checkpoints once and reuse them.

    Returns whatever ``model_loader`` returns; ``run_inference`` unpacks it
    as ``(text2punk, clip)``.
    """
    return model_loader(t2p_path, clip_path)


def run_inference(prompt, num_images=32, batch_size=32, num_preds=8):
    """Generate CryptoPunk images for ``prompt`` and package them for the UI.

    Args:
        prompt: free-text description entered by the user.
        num_images: forwarded to ``generate_image`` as ``num_images``.
        batch_size: forwarded to ``generate_image`` as ``batch_size``.
        num_preds: forwarded to ``generate_image`` as ``top_prediction``.

    Returns:
        ``(output_title, predictions)``: an HTML snippet showing the (escaped)
        prompt in bold, and the image strip built by ``compose_predictions``.
    """
    t2p_path, clip_path = './Text2Punk-final-7.pt', './clip-final.pt'
    # Loading checkpoints is expensive; the cache makes it a one-time cost
    # instead of a per-request one.
    text2punk, clip = _load_models(t2p_path, clip_path)
    images, _ = generate_image(
        prompt_text=prompt,
        top_k=top_k,
        temperature=temperature,
        num_images=num_images,
        batch_size=batch_size,
        top_prediction=num_preds,
        text2punk_model=text2punk,
        clip_model=clip,
    )
    predictions = compose_predictions(images)
    # The prompt is user-typed text rendered into an HTML output component, so
    # escape it to keep markup/script in the prompt from being injected.
    output_title = f"""
<b>{html.escape(prompt)}</b>
"""
    return (output_title, predictions)
# Output components, in the order run_inference returns its values:
# first the prompt rendered as an HTML title, then the composed image strip.
_title_output = gr.outputs.HTML(label="")
_image_output = gr.outputs.Image(label='')
outputs = [_title_output, _image_output]
# Blurb passed to gr.Interface(description=...).
description = (
    "\n"
    "Text2Cryptopunks is an AI model that generates Cryptopunks images from text prompt:\n"
)
# Build the Gradio UI around run_inference and start serving it.
# NOTE(review): gr.inputs / gr.outputs, `layout`, `theme='huggingface'` and a
# boolean `allow_flagging` appear to be the legacy Gradio API — confirm the
# pinned gradio version still accepts them.
gr.Interface(
    run_inference,
    inputs=[
        gr.inputs.Textbox(
            label='Type something like this: "An Ape CryptoPunk that has 2 Attributes, a Pigtails and a Medical Mask."'
        )
    ],
    outputs=outputs,
    title='Text2Cryptopunks',
    description=description,
    # The closing </p> balances the opening <p> tag.
    article="<p style='text-align: center'> Created by kTonpa | <a href='https://github.com/kTonpa/Text2CryptoPunks'>GitHub</a></p>",
    layout='vertical',
    theme='huggingface',
    examples=[
        ['Cute Alien cryptopunk that has a 2 Attributes, a Pipe, and a Beanie.'],
        ['A low resolution photo of punky-looking Ape that has 2 Attributes, a Beanie, and a Medical Mask.'],
    ],
    allow_flagging=False,
    live=False,
).launch(share=True)
|