File size: 3,579 Bytes
1c35f64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import tempfile
import csv
import pandas as pd
import gradio as gr
from huggingface_hub import HfApi
from pathlib import Path

def get_model_stats(search_term):
    """Fetch download statistics for Hub models matching *search_term*.

    Lists the models returned by ``HfApi.list_models(search=...)`` with their
    30-day and all-time download counts. It writes one CSV row per model to a
    fresh temporary directory and summarises the totals.

    Args:
        search_term: Free-text search passed to ``HfApi.list_models``.
            Leading/trailing whitespace is ignored.

    Returns:
        A ``(dataframe, status_message, csv_path)`` tuple:
        - ``dataframe`` has the columns "Model ID", "Downloads (30 days)" and
          "Downloads (All Time)".
        - ``csv_path`` is the written CSV's path as a string.
        - When the search term is empty, no request is made, the DataFrame is
          empty and ``csv_path`` is ``None``. An empty search would otherwise
          list every model on the Hub.
    """
    header = ["Model ID", "Downloads (30 days)", "Downloads (All Time)"]

    search_term = (search_term or "").strip()
    if not search_term:
        return pd.DataFrame(columns=header), "Please enter a search term.", None

    # Initialize the API
    api = HfApi()

    # The search term is user input and often contains "/" (e.g. "google/gemma").
    # Used verbatim it would name a sub-directory that does not exist, or escape
    # the temp dir via "..". Keep only filename-safe characters.
    safe_term = "".join(
        ch if ch.isalnum() or ch in "-_." else "_" for ch in search_term
    )
    # NOTE(review): mkdtemp() directories are never cleaned up. The file has to
    # outlive this call so the UI can serve it; consider periodic cleanup.
    output_file = Path(tempfile.mkdtemp()) / f"{safe_term}_models_alltime.csv"

    print(f"Fetching {search_term} models with download statistics...")
    models_generator = api.list_models(
        search=search_term,
        expand=["downloads", "downloadsAllTime"],  # Get both 30-day and all-time downloads
        sort="_id",  # Sort by ID to avoid timeout issues
    )

    rows = []
    for model in models_generator:
        # ModelInfo fields are Optional: the attribute can exist with value
        # None, so a getattr default alone would not prevent None here.
        downloads_30day = getattr(model, "downloads", None) or 0
        downloads_alltime = getattr(model, "downloads_all_time", None) or 0
        rows.append(
            [getattr(model, "id", "Unknown"), downloads_30day, downloads_alltime]
        )

    with open(output_file, "w", newline="", encoding="utf-8") as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(header)
        csv_writer.writerows(rows)

    # Build the table from the collected rows rather than re-reading the CSV.
    df = pd.DataFrame(rows, columns=header)

    total_30day_downloads = sum(row[1] for row in rows)
    total_alltime_downloads = sum(row[2] for row in rows)
    status_message = (
        f"Found {len(rows)} models for search term '{search_term}'\n"
        f"Total 30-day downloads: {total_30day_downloads:,}\n"
        f"Total all-time downloads: {total_alltime_downloads:,}"
    )

    return df, status_message, str(output_file)

# Create the Gradio interface
with gr.Blocks(title="Hugging Face Model Statistics") as demo:
    gr.Markdown("# Hugging Face Model Statistics")
    gr.Markdown("Enter a search term to find model statistics from Hugging Face Hub")
    
    with gr.Row():
        search_input = gr.Textbox(
            label="Search Term",
            placeholder="Enter a model name or keyword (e.g., 'gemma', 'llama')",
            value="gemma"
        )
        search_button = gr.Button("Search")
    
    with gr.Row():
        output_table = gr.Dataframe(
            headers=["Model ID", "Downloads (30 days)", "Downloads (All Time)"],
            datatype=["str", "number", "number"],
            label="Model Statistics"
        )
        status_message = gr.Textbox(label="Status", lines=3)  # Increased lines to show all stats
    
    with gr.Row():
        download_button = gr.Button("Download CSV")
        csv_file = gr.File(label="CSV File", visible=False)
    
    # Store the CSV file path in a state
    csv_path = gr.State()
    
    search_button.click(
        fn=get_model_stats,
        inputs=search_input,
        outputs=[output_table, status_message, csv_path]
    )
    
    download_button.click(
        fn=lambda x: x,
        inputs=csv_path,
        outputs=csv_file
    )

if __name__ == "__main__":
    demo.launch()