# Residue from the page this file was copied from (kept for reference):
# @mokollmar
# test #5
# 40c4e66
import streamlit as st
import torch
import clip
from PIL import Image
import numpy as np
import hmac
# Load CLIP model and preprocessing.
# Streamlit re-executes this script on every widget interaction, so the load
# goes through st.cache_resource to reuse the loaded objects across reruns
# instead of calling clip.load again each time.
@st.cache_resource
def _load_clip(device_name):
    """Return the ``(model, preprocess)`` pair from ``clip.load`` for ViT-B/32."""
    return clip.load("ViT-B/32", device=device_name)


device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = _load_clip(device)
# Function to predict descriptions and probabilities
def predict(image, descriptions):
    """Return the description the model scores highest for *image*.

    Args:
        image: Image accepted by the module-level ``preprocess`` transform
            (``main`` passes a PIL image).
        descriptions: Non-empty sequence of candidate text descriptions.

    Returns:
        A ``(best_description, probability)`` tuple. ``probability`` is the
        softmax score of the winning description over all candidates, as a
        plain ``float``.

    Raises:
        ValueError: If ``descriptions`` is empty.
    """
    if not descriptions:
        raise ValueError("descriptions must contain at least one entry")

    image_input = preprocess(image).unsqueeze(0).to(device)
    text_input = clip.tokenize(descriptions).to(device)
    with torch.no_grad():
        # Only the logits from the joint forward pass are needed; the
        # separately encoded image/text features were never used.
        logits_per_image, _ = model(image_input, text_input)
        # A single image was passed, so row 0 holds one score per description.
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0]

    best_index = int(np.argmax(probs))
    return descriptions[best_index], float(probs[best_index])
# Streamlit app
def main():
    """Render the upload form and show the model's pick among three descriptions.

    The user uploads an image and writes three descriptions. When "Predict" is
    pressed, the description with the highest probability from ``predict`` is
    shown together with that probability and a progress bar.
    """
    st.title("Image understanding model test")

    # Instructions for the user
    st.markdown("---")
    st.markdown("### Upload an image to test how well the model understands it")

    # Upload image through Streamlit with a unique key
    uploaded_image = st.file_uploader(
        "Upload an image...", type=["jpg", "png", "jpeg"], key="uploaded_image"
    )
    if uploaded_image is None:
        return

    # Convert the uploaded image to PIL Image
    pil_image = Image.open(uploaded_image)
    # use_column_width=True takes precedence over an explicit width, so the
    # previously passed width=200 had no effect and is omitted.
    st.image(pil_image, caption="Uploaded Image.", use_column_width=True)

    # Instructions for the user
    st.markdown("### 2 Lies and 1 Truth")
    st.markdown("Write 3 descriptions about the image, 1 must be true.")

    # Get user input for descriptions; surrounding whitespace is dropped so a
    # blank-looking field counts as empty.
    raw_descriptions = [
        st.text_input("Description 1:", placeholder='A red apple'),
        st.text_input("Description 2:", placeholder='A car parked in a garage'),
        st.text_input("Description 3:", placeholder='An orange fruit on a tree'),
    ]
    descriptions = [text.strip() for text in raw_descriptions]

    # Button to trigger prediction
    if not st.button("Predict"):
        return

    if not all(descriptions):
        # Tell the user why nothing happened instead of silently ignoring the click.
        st.warning("Please fill in all 3 descriptions before predicting.")
        return

    best_description, best_prob = predict(pil_image, descriptions)

    # Display the highest probability description and its probability
    st.write(f"**Best Description:** {best_description}")
    st.write(f"**Prediction Probability:** {best_prob:.2%}")

    # Display progress bar for the highest probability
    st.progress(float(best_prob))
# user has correct Password?
def check_password():
    """Return True if the user has entered the password stored in ``st.secrets``.

    Until the password has been validated, a password input is rendered and
    False is returned. After a failed attempt an error message is shown as well.
    """

    def password_entered():
        """Compare the submitted password with the secret and record the outcome."""
        # hmac.compare_digest raises TypeError for str arguments that contain
        # non-ASCII characters, so compare the UTF-8 encoded bytes instead.
        entered = st.session_state["password"].encode("utf-8")
        expected = st.secrets["password"].encode("utf-8")
        if hmac.compare_digest(entered, expected):
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # Don't store the password.
        else:
            st.session_state["password_correct"] = False

    # Return True if the password is validated.
    if st.session_state.get("password_correct", False):
        return True

    # Show input for password.
    st.markdown("## Prove you're root.")
    st.text_input("Password", type="password", on_change=password_entered, key="password")
    # The flag only exists after at least one attempt; reaching this point
    # with it set means the last attempt failed.
    if "password_correct" in st.session_state:
        st.error("Secret handshake required. Keyboard-only members club.")
    return False
if __name__ == "__main__":
    # Gate the app behind the password prompt.
    password_ok = check_password()
    if not password_ok:
        st.stop()  # Do not continue if check_password is not True.

    # Main Streamlit app starts here
    main()