File size: 2,980 Bytes
2ce4950
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: scatter_plot\n",
    "\n",
    "Pick `cars` or `iris` in the dropdown to redraw the scatter plot with that dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q gradio vega_datasets pandas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "from vega_datasets import data\n",
    "\n",
    "cars = data.cars()\n",
    "iris = data.iris()\n",
    "\n",
    "# # Or generate your own fake data with the columns the plots below use\n",
    "\n",
    "# import pandas as pd\n",
    "# import random\n",
    "\n",
    "# random.seed(0)  # fixed seed so the fake data is reproducible\n",
    "\n",
    "# cars_data = {\n",
    "#     \"Name\": [f\"car name {i // 10}\" for i in range(400)],\n",
    "#     \"Miles_per_Gallon\": [random.randint(10, 30) for _ in range(400)],\n",
    "#     \"Origin\": [random.choice([\"USA\", \"Europe\", \"Japan\"]) for _ in range(400)],\n",
    "#     \"Horsepower\": [random.randint(50, 250) for _ in range(400)],\n",
    "# }\n",
    "\n",
    "# iris_data = {\n",
    "#     \"petalWidth\": [round(random.uniform(0, 2.5), 2) for _ in range(150)],\n",
    "#     \"petalLength\": [round(random.uniform(0, 7), 2) for _ in range(150)],\n",
    "#     \"species\": [\n",
    "#         random.choice([\"setosa\", \"versicolor\", \"virginica\"]) for _ in range(150)\n",
    "#     ],\n",
    "# }\n",
    "\n",
    "# cars = pd.DataFrame(cars_data)\n",
    "# iris = pd.DataFrame(iris_data)\n",
    "\n",
    "\n",
    "def scatter_plot_fn(dataset):\n",
    "    \"\"\"Build the scatter plot for the dataset picked in the dropdown.\n",
    "\n",
    "    Args:\n",
    "        dataset: Dropdown value. \"iris\" selects the iris plot; every other\n",
    "            value (including \"cars\") selects the cars plot.\n",
    "\n",
    "    Returns:\n",
    "        A configured ``gr.ScatterPlot`` used as the update for the ``plot``\n",
    "        output component.\n",
    "    \"\"\"\n",
    "    if dataset == \"iris\":\n",
    "        return gr.ScatterPlot(\n",
    "            value=iris,\n",
    "            x=\"petalWidth\",\n",
    "            y=\"petalLength\",\n",
    "            color=\"species\",\n",
    "            title=\"Iris Dataset\",\n",
    "            color_legend_title=\"Species\",\n",
    "            x_title=\"Petal Width\",\n",
    "            y_title=\"Petal Length\",\n",
    "            tooltip=[\"petalWidth\", \"petalLength\", \"species\"],\n",
    "            # Blank caption so the cars caption is not shown on the iris plot.\n",
    "            caption=\"\",\n",
    "        )\n",
    "    else:\n",
    "        return gr.ScatterPlot(\n",
    "            value=cars,\n",
    "            x=\"Horsepower\",\n",
    "            y=\"Miles_per_Gallon\",\n",
    "            color=\"Origin\",\n",
    "            tooltip=[\"Name\"],\n",
    "            title=\"Car Data\",\n",
    "            # Set explicitly to mirror the iris branch. NOTE(review): a returned\n",
    "            # component may only update the props passed here, so leaving this\n",
    "            # out could keep \"Petal Width\" after switching back from iris.\n",
    "            x_title=\"Horsepower\",\n",
    "            y_title=\"Miles per Gallon\",\n",
    "            color_legend_title=\"Origin of Car\",\n",
    "            caption=\"MPG vs Horsepower of various cars\",\n",
    "        )\n",
    "\n",
    "\n",
    "with gr.Blocks() as scatter_plot:\n",
    "    with gr.Row():\n",
    "        with gr.Column():\n",
    "            dataset = gr.Dropdown(choices=[\"cars\", \"iris\"], value=\"cars\")\n",
    "        with gr.Column():\n",
    "            plot = gr.ScatterPlot()\n",
    "    # Redraw on every selection change, and once on page load so the\n",
    "    # default \"cars\" selection is plotted without user interaction.\n",
    "    dataset.change(scatter_plot_fn, inputs=dataset, outputs=plot)\n",
    "    scatter_plot.load(fn=scatter_plot_fn, inputs=dataset, outputs=plot)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    scatter_plot.launch()\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}