Spaces:
Running
Running
Merge remote-tracking branch 'upstream/main' into fp/applicatives
Browse files- .github/workflows/deploy.yml +56 -0
- .github/workflows/hf_sync.yml +27 -1
- .gitignore +4 -1
- Dockerfile +13 -0
- Makefile +9 -0
- README.md +15 -0
- _server/README.md +3 -0
- polars/01_why_polars.py +1 -1
- probability/11_expectation.py +860 -0
- probability/12_variance.py +631 -0
- probability/13_bernoulli_distribution.py +427 -0
- probability/14_binomial_distribution.py +545 -0
- probability/15_poisson_distribution.py +805 -0
- probability/16_continuous_distribution.py +979 -0
- probability/17_normal_distribution.py +1127 -0
- probability/18_central_limit_theorem.py +943 -0
- probability/19_maximum_likelihood_estimation.py +1231 -0
- python/006_dictionaries.py +2 -2
- scripts/build.py +281 -0
- scripts/preview.py +76 -0
- scripts/templates/index.html +174 -0
.github/workflows/deploy.yml
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Deploy to GitHub Pages
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches: ['main']
|
6 |
+
workflow_dispatch:
|
7 |
+
|
8 |
+
concurrency:
|
9 |
+
group: 'pages'
|
10 |
+
cancel-in-progress: false
|
11 |
+
|
12 |
+
env:
|
13 |
+
UV_SYSTEM_PYTHON: 1
|
14 |
+
|
15 |
+
jobs:
|
16 |
+
build:
|
17 |
+
runs-on: ubuntu-latest
|
18 |
+
steps:
|
19 |
+
- uses: actions/checkout@v4
|
20 |
+
|
21 |
+
- name: 🚀 Install uv
|
22 |
+
uses: astral-sh/setup-uv@v4
|
23 |
+
|
24 |
+
- name: 🐍 Set up Python
|
25 |
+
uses: actions/setup-python@v5
|
26 |
+
with:
|
27 |
+
python-version: 3.12
|
28 |
+
|
29 |
+
- name: 📦 Install dependencies
|
30 |
+
run: |
|
31 |
+
uv pip install marimo jinja2
|
32 |
+
|
33 |
+
- name: 🛠️ Export notebooks
|
34 |
+
run: |
|
35 |
+
python scripts/build.py
|
36 |
+
|
37 |
+
- name: 📤 Upload artifact
|
38 |
+
uses: actions/upload-pages-artifact@v3
|
39 |
+
with:
|
40 |
+
path: _site
|
41 |
+
|
42 |
+
deploy:
|
43 |
+
needs: build
|
44 |
+
|
45 |
+
permissions:
|
46 |
+
pages: write
|
47 |
+
id-token: write
|
48 |
+
|
49 |
+
environment:
|
50 |
+
name: github-pages
|
51 |
+
url: ${{ steps.deployment.outputs.page_url }}
|
52 |
+
runs-on: ubuntu-latest
|
53 |
+
steps:
|
54 |
+
- name: 🚀 Deploy to GitHub Pages
|
55 |
+
id: deployment
|
56 |
+
uses: actions/deploy-pages@v4
|
.github/workflows/hf_sync.yml
CHANGED
@@ -13,7 +13,33 @@ jobs:
|
|
13 |
- uses: actions/checkout@v4
|
14 |
with:
|
15 |
fetch-depth: 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
- name: Push to hub
|
17 |
env:
|
18 |
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
19 |
-
run: git push https://mylessss:[email protected]/spaces/marimo-team/marimo-learn main
|
|
|
13 |
- uses: actions/checkout@v4
|
14 |
with:
|
15 |
fetch-depth: 0
|
16 |
+
|
17 |
+
- name: Configure Git
|
18 |
+
run: |
|
19 |
+
git config --global user.name "GitHub Action"
|
20 |
+
git config --global user.email "[email protected]"
|
21 |
+
|
22 |
+
- name: Prepend frontmatter to README
|
23 |
+
run: |
|
24 |
+
if [ -f README.md ] && ! grep -q "^---" README.md; then
|
25 |
+
FRONTMATTER="---
|
26 |
+
title: marimo learn
|
27 |
+
emoji: 🧠
|
28 |
+
colorFrom: blue
|
29 |
+
colorTo: indigo
|
30 |
+
sdk: docker
|
31 |
+
sdk_version: \"latest\"
|
32 |
+
app_file: app.py
|
33 |
+
pinned: false
|
34 |
+
---
|
35 |
+
|
36 |
+
"
|
37 |
+
echo "$FRONTMATTER$(cat README.md)" > README.md
|
38 |
+
git add README.md
|
39 |
+
git commit -m "Add HF frontmatter to README" || echo "No changes to commit"
|
40 |
+
fi
|
41 |
+
|
42 |
- name: Push to hub
|
43 |
env:
|
44 |
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
45 |
+
run: git push -f https://mylessss:[email protected]/spaces/marimo-team/marimo-learn main
|
.gitignore
CHANGED
@@ -168,4 +168,7 @@ cython_debug/
|
|
168 |
#.idea/
|
169 |
|
170 |
# PyPI configuration file
|
171 |
-
.pypirc
|
|
|
|
|
|
|
|
168 |
#.idea/
|
169 |
|
170 |
# PyPI configuration file
|
171 |
+
.pypirc
|
172 |
+
|
173 |
+
# Generated site content
|
174 |
+
_site/
|
Dockerfile
CHANGED
@@ -1,9 +1,22 @@
|
|
1 |
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
COPY _server/main.py _server/main.py
|
4 |
COPY polars/ polars/
|
5 |
COPY duckdb/ duckdb/
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
RUN uv venv
|
8 |
RUN uv export --script _server/main.py | uv pip install -r -
|
9 |
|
|
|
1 |
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
|
2 |
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
# Create a non-root user
|
6 |
+
RUN useradd -m appuser
|
7 |
+
|
8 |
+
# Copy application files
|
9 |
COPY _server/main.py _server/main.py
|
10 |
COPY polars/ polars/
|
11 |
COPY duckdb/ duckdb/
|
12 |
|
13 |
+
# Set proper ownership
|
14 |
+
RUN chown -R appuser:appuser /app
|
15 |
+
|
16 |
+
# Switch to non-root user
|
17 |
+
USER appuser
|
18 |
+
|
19 |
+
# Create virtual environment and install dependencies
|
20 |
RUN uv venv
|
21 |
RUN uv export --script _server/main.py | uv pip install -r -
|
22 |
|
Makefile
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
install:
|
2 |
+
uv pip install marimo jinja2 markdown
|
3 |
+
|
4 |
+
build:
|
5 |
+
rm -rf _site
|
6 |
+
uv run scripts/build.py
|
7 |
+
|
8 |
+
serve:
|
9 |
+
uv run python -m http.server --directory _site
|
README.md
CHANGED
@@ -58,6 +58,21 @@ Here's a contribution checklist:
|
|
58 |
If you aren't comfortable adding a new notebook or course, you can also request
|
59 |
what you'd like to see by [filing an issue](https://github.com/marimo-team/learn/issues/new?template=example_request.yaml).
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
## Community
|
62 |
|
63 |
We're building a community. Come hang out with us!
|
|
|
58 |
If you aren't comfortable adding a new notebook or course, you can also request
|
59 |
what you'd like to see by [filing an issue](https://github.com/marimo-team/learn/issues/new?template=example_request.yaml).
|
60 |
|
61 |
+
## Building and Previewing
|
62 |
+
|
63 |
+
The site is built using a Python script that exports marimo notebooks to HTML and generates an index page.
|
64 |
+
|
65 |
+
```bash
|
66 |
+
# Build the site
|
67 |
+
python scripts/build.py --output-dir _site
|
68 |
+
|
69 |
+
# Preview the site (builds first)
|
70 |
+
python scripts/preview.py
|
71 |
+
|
72 |
+
# Preview without rebuilding
|
73 |
+
python scripts/preview.py --no-build
|
74 |
+
```
|
75 |
+
|
76 |
## Community
|
77 |
|
78 |
We're building a community. Come hang out with us!
|
_server/README.md
CHANGED
@@ -5,11 +5,14 @@ This folder contains server code for hosting marimo apps.
|
|
5 |
## Running the server
|
6 |
|
7 |
```bash
|
|
|
8 |
uv run --no-project main.py
|
9 |
```
|
10 |
|
11 |
## Building a Docker image
|
12 |
|
|
|
|
|
13 |
```bash
|
14 |
docker build -t marimo-learn .
|
15 |
```
|
|
|
5 |
## Running the server
|
6 |
|
7 |
```bash
|
8 |
+
cd _server
|
9 |
uv run --no-project main.py
|
10 |
```
|
11 |
|
12 |
## Building a Docker image
|
13 |
|
14 |
+
From the root directory, run:
|
15 |
+
|
16 |
```bash
|
17 |
docker build -t marimo-learn .
|
18 |
```
|
polars/01_why_polars.py
CHANGED
@@ -57,7 +57,7 @@ def _(mo):
|
|
57 |
"""
|
58 |
Unlike Python's earliest DataFrame library Pandas, Polars was designed with performance and usability in mind — Polars can scale to large datasets with ease while maintaining a simple and intuitive API.
|
59 |
|
60 |
-
Polars' performance is due to a number of factors, including its implementation
|
61 |
|
62 |
With its focus on speed, scalability, and ease of use, Polars is quickly becoming a go-to choice for data professionals looking to streamline their data processing pipelines and tackle large-scale data challenges.
|
63 |
"""
|
|
|
57 |
"""
|
58 |
Unlike Python's earliest DataFrame library Pandas, Polars was designed with performance and usability in mind — Polars can scale to large datasets with ease while maintaining a simple and intuitive API.
|
59 |
|
60 |
+
Polars' performance is due to a number of factors, including its implementation in rust and its ability to perform operations in a parallelized and vectorized manner. It supports a wide range of data types, advanced query optimizations, and seamless integration with other Python libraries, making it a versatile tool for data scientists, engineers, and analysts. Additionally, Polars provides a lazy API for deferred execution, allowing users to optimize their workflows by chaining operations and executing them in a single pass.
|
61 |
|
62 |
With its focus on speed, scalability, and ease of use, Polars is quickly becoming a go-to choice for data professionals looking to streamline their data processing pipelines and tackle large-scale data challenges.
|
63 |
"""
|
probability/11_expectation.py
ADDED
@@ -0,0 +1,860 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.0",
|
6 |
+
# "numpy==2.2.3",
|
7 |
+
# "scipy==1.15.2",
|
8 |
+
# ]
|
9 |
+
# ///
|
10 |
+
|
11 |
+
import marimo
|
12 |
+
|
13 |
+
__generated_with = "0.11.19"
|
14 |
+
app = marimo.App(width="medium", app_title="Expectation")
|
15 |
+
|
16 |
+
|
17 |
+
@app.cell(hide_code=True)
|
18 |
+
def _(mo):
|
19 |
+
mo.md(
|
20 |
+
r"""
|
21 |
+
# Expectation
|
22 |
+
|
23 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/expectation/), by Stanford professor Chris Piech._
|
24 |
+
|
25 |
+
A random variable is fully represented by its Probability Mass Function (PMF), which describes each value the random variable can take on and the corresponding probabilities. However, a PMF can contain a lot of information. Sometimes it's useful to summarize a random variable with a single value!
|
26 |
+
|
27 |
+
The most common, and arguably the most useful, summary of a random variable is its **Expectation** (also called the expected value or mean).
|
28 |
+
"""
|
29 |
+
)
|
30 |
+
return
|
31 |
+
|
32 |
+
|
33 |
+
@app.cell(hide_code=True)
|
34 |
+
def _(mo):
|
35 |
+
mo.md(
|
36 |
+
r"""
|
37 |
+
## Definition of Expectation
|
38 |
+
|
39 |
+
The expectation of a random variable $X$, written $E[X]$, is the average of all the values the random variable can take on, each weighted by the probability that the random variable will take on that value.
|
40 |
+
|
41 |
+
$$E[X] = \sum_x x \cdot P(X=x)$$
|
42 |
+
|
43 |
+
Expectation goes by many other names: Mean, Weighted Average, Center of Mass, 1st Moment. All of these are calculated using the same formula.
|
44 |
+
"""
|
45 |
+
)
|
46 |
+
return
|
47 |
+
|
48 |
+
|
49 |
+
@app.cell(hide_code=True)
|
50 |
+
def _(mo):
|
51 |
+
mo.md(
|
52 |
+
r"""
|
53 |
+
## Intuition Behind Expectation
|
54 |
+
|
55 |
+
The expected value represents the long-run average value of a random variable over many independent repetitions of an experiment.
|
56 |
+
|
57 |
+
For example, if you roll a fair six-sided die many times and calculate the average of all rolls, that average will approach the expected value of 3.5 as the number of rolls increases.
|
58 |
+
|
59 |
+
Let's visualize this concept:
|
60 |
+
"""
|
61 |
+
)
|
62 |
+
return
|
63 |
+
|
64 |
+
|
65 |
+
@app.cell(hide_code=True)
|
66 |
+
def _(np, plt):
|
67 |
+
# Set random seed for reproducibility
|
68 |
+
np.random.seed(42)
|
69 |
+
|
70 |
+
# Simulate rolling a die many times
|
71 |
+
exp_num_rolls = 1000
|
72 |
+
exp_die_rolls = np.random.randint(1, 7, size=exp_num_rolls)
|
73 |
+
|
74 |
+
# Calculate the running average
|
75 |
+
exp_running_avg = np.cumsum(exp_die_rolls) / np.arange(1, exp_num_rolls + 1)
|
76 |
+
|
77 |
+
# Create the plot
|
78 |
+
plt.figure(figsize=(10, 5))
|
79 |
+
plt.plot(range(1, exp_num_rolls + 1), exp_running_avg, label='Running Average')
|
80 |
+
plt.axhline(y=3.5, color='r', linestyle='--', label='Expected Value (3.5)')
|
81 |
+
plt.xlabel('Number of Rolls')
|
82 |
+
plt.ylabel('Average Value')
|
83 |
+
plt.title('Running Average of Die Rolls Approaching Expected Value')
|
84 |
+
plt.legend()
|
85 |
+
plt.grid(alpha=0.3)
|
86 |
+
plt.xscale('log') # Log scale to better see convergence
|
87 |
+
|
88 |
+
# Add annotations
|
89 |
+
plt.annotate('As the number of rolls increases,\nthe average approaches the expected value',
|
90 |
+
xy=(exp_num_rolls, exp_running_avg[-1]), xytext=(exp_num_rolls/3, 4),
|
91 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1.5))
|
92 |
+
|
93 |
+
plt.gca()
|
94 |
+
return exp_die_rolls, exp_num_rolls, exp_running_avg
|
95 |
+
|
96 |
+
|
97 |
+
@app.cell(hide_code=True)
|
98 |
+
def _(mo):
|
99 |
+
mo.md(r"""## Properties of Expectation""")
|
100 |
+
return
|
101 |
+
|
102 |
+
|
103 |
+
@app.cell(hide_code=True)
|
104 |
+
def _(mo):
|
105 |
+
mo.accordion(
|
106 |
+
{
|
107 |
+
"1. Linearity of Expectation": mo.md(
|
108 |
+
r"""
|
109 |
+
$$E[aX + b] = a \cdot E[X] + b$$
|
110 |
+
|
111 |
+
Where $a$ and $b$ are constants (not random variables).
|
112 |
+
|
113 |
+
This means that if you multiply a random variable by a constant, the expectation is multiplied by that constant. And if you add a constant to a random variable, the expectation increases by that constant.
|
114 |
+
"""
|
115 |
+
),
|
116 |
+
"2. Expectation of the Sum of Random Variables": mo.md(
|
117 |
+
r"""
|
118 |
+
$$E[X + Y] = E[X] + E[Y]$$
|
119 |
+
|
120 |
+
This is true regardless of the relationship between $X$ and $Y$. They can be dependent, and they can have different distributions. This also applies with more than two random variables:
|
121 |
+
|
122 |
+
$$E\left[\sum_{i=1}^n X_i\right] = \sum_{i=1}^n E[X_i]$$
|
123 |
+
"""
|
124 |
+
),
|
125 |
+
"3. Law of the Unconscious Statistician (LOTUS)": mo.md(
|
126 |
+
r"""
|
127 |
+
$$E[g(X)] = \sum_x g(x) \cdot P(X=x)$$
|
128 |
+
|
129 |
+
This allows us to calculate the expected value of a function $g(X)$ of a random variable $X$ when we know the probability distribution of $X$ but don't explicitly know the distribution of $g(X)$.
|
130 |
+
|
131 |
+
This theorem has the humorous name "Law of the Unconscious Statistician" (LOTUS) because it's so useful that you should be able to employ it unconsciously.
|
132 |
+
"""
|
133 |
+
),
|
134 |
+
"4. Expectation of a Constant": mo.md(
|
135 |
+
r"""
|
136 |
+
$$E[a] = a$$
|
137 |
+
|
138 |
+
Sometimes in proofs, you'll end up with the expectation of a constant (rather than a random variable). Since a constant doesn't change, its expected value is just the constant itself.
|
139 |
+
"""
|
140 |
+
),
|
141 |
+
}
|
142 |
+
)
|
143 |
+
return
|
144 |
+
|
145 |
+
|
146 |
+
@app.cell(hide_code=True)
|
147 |
+
def _(mo):
|
148 |
+
mo.md(
|
149 |
+
r"""
|
150 |
+
## Calculating Expectation
|
151 |
+
|
152 |
+
Let's calculate the expected value for some common examples:
|
153 |
+
|
154 |
+
### Example 1: Fair Die Roll
|
155 |
+
|
156 |
+
For a fair six-sided die, the PMF is:
|
157 |
+
|
158 |
+
$$P(X=x) = \frac{1}{6} \text{ for } x \in \{1, 2, 3, 4, 5, 6\}$$
|
159 |
+
|
160 |
+
The expected value is:
|
161 |
+
|
162 |
+
$$E[X] = 1 \cdot \frac{1}{6} + 2 \cdot \frac{1}{6} + 3 \cdot \frac{1}{6} + 4 \cdot \frac{1}{6} + 5 \cdot \frac{1}{6} + 6 \cdot \frac{1}{6} = \frac{21}{6} = 3.5$$
|
163 |
+
|
164 |
+
Let's implement this calculation in Python:
|
165 |
+
"""
|
166 |
+
)
|
167 |
+
return
|
168 |
+
|
169 |
+
|
170 |
+
@app.cell
|
171 |
+
def _():
|
172 |
+
def calc_expectation_die():
|
173 |
+
"""Calculate the expected value of a fair six-sided die roll."""
|
174 |
+
exp_die_values = range(1, 7)
|
175 |
+
exp_die_probs = [1/6] * 6
|
176 |
+
|
177 |
+
exp_die_expected = sum(x * p for x, p in zip(exp_die_values, exp_die_probs))
|
178 |
+
return exp_die_expected
|
179 |
+
|
180 |
+
exp_die_result = calc_expectation_die()
|
181 |
+
print(f"Expected value of a fair die roll: {exp_die_result}")
|
182 |
+
return calc_expectation_die, exp_die_result
|
183 |
+
|
184 |
+
|
185 |
+
@app.cell(hide_code=True)
|
186 |
+
def _(mo):
|
187 |
+
mo.md(
|
188 |
+
r"""
|
189 |
+
### Example 2: Sum of Two Dice
|
190 |
+
|
191 |
+
Now let's calculate the expected value for the sum of two fair dice. First, we need the PMF:
|
192 |
+
"""
|
193 |
+
)
|
194 |
+
return
|
195 |
+
|
196 |
+
|
197 |
+
@app.cell
|
198 |
+
def _():
|
199 |
+
def pmf_sum_two_dice(y_val):
|
200 |
+
"""Returns the probability that the sum of two dice is y."""
|
201 |
+
# Count the number of ways to get sum y
|
202 |
+
exp_count = 0
|
203 |
+
for dice1 in range(1, 7):
|
204 |
+
for dice2 in range(1, 7):
|
205 |
+
if dice1 + dice2 == y_val:
|
206 |
+
exp_count += 1
|
207 |
+
return exp_count / 36 # There are 36 possible outcomes (6×6)
|
208 |
+
|
209 |
+
# Test the function for a few values
|
210 |
+
exp_test_values = [2, 7, 12]
|
211 |
+
for exp_test_y in exp_test_values:
|
212 |
+
print(f"P(Y = {exp_test_y}) = {pmf_sum_two_dice(exp_test_y)}")
|
213 |
+
return exp_test_values, exp_test_y, pmf_sum_two_dice
|
214 |
+
|
215 |
+
|
216 |
+
@app.cell
|
217 |
+
def _(pmf_sum_two_dice):
|
218 |
+
def calc_expectation_sum_two_dice():
|
219 |
+
"""Calculate the expected value of the sum of two dice."""
|
220 |
+
exp_sum_two_dice = 0
|
221 |
+
# Sum of dice can take on the values 2 through 12
|
222 |
+
for exp_x in range(2, 13):
|
223 |
+
exp_pr_x = pmf_sum_two_dice(exp_x) # PMF gives P(sum is x)
|
224 |
+
exp_sum_two_dice += exp_x * exp_pr_x
|
225 |
+
return exp_sum_two_dice
|
226 |
+
|
227 |
+
exp_sum_result = calc_expectation_sum_two_dice()
|
228 |
+
|
229 |
+
# Round to 2 decimal places for display
|
230 |
+
exp_sum_result_rounded = round(exp_sum_result, 2)
|
231 |
+
|
232 |
+
print(f"Expected value of the sum of two dice: {exp_sum_result_rounded}")
|
233 |
+
|
234 |
+
# Let's also verify this with a direct calculation
|
235 |
+
exp_direct_calc = sum(x * pmf_sum_two_dice(x) for x in range(2, 13))
|
236 |
+
exp_direct_calc_rounded = round(exp_direct_calc, 2)
|
237 |
+
|
238 |
+
print(f"Direct calculation: {exp_direct_calc_rounded}")
|
239 |
+
|
240 |
+
# Verify that this equals 7
|
241 |
+
print(f"Is the expected value exactly 7? {abs(exp_sum_result - 7) < 1e-10}")
|
242 |
+
return (
|
243 |
+
calc_expectation_sum_two_dice,
|
244 |
+
exp_direct_calc,
|
245 |
+
exp_direct_calc_rounded,
|
246 |
+
exp_sum_result,
|
247 |
+
exp_sum_result_rounded,
|
248 |
+
)
|
249 |
+
|
250 |
+
|
251 |
+
@app.cell(hide_code=True)
|
252 |
+
def _(mo):
|
253 |
+
mo.md(
|
254 |
+
r"""
|
255 |
+
### Visualizing Expectation
|
256 |
+
|
257 |
+
Let's visualize the expectation for the sum of two dice. The expected value is the "center of mass" of the PMF:
|
258 |
+
"""
|
259 |
+
)
|
260 |
+
return
|
261 |
+
|
262 |
+
|
263 |
+
@app.cell(hide_code=True)
|
264 |
+
def _(plt, pmf_sum_two_dice):
|
265 |
+
# Create the visualization
|
266 |
+
exp_y_values = list(range(2, 13))
|
267 |
+
exp_probabilities = [pmf_sum_two_dice(y) for y in exp_y_values]
|
268 |
+
|
269 |
+
dice_fig, dice_ax = plt.subplots(figsize=(10, 5))
|
270 |
+
dice_ax.bar(exp_y_values, exp_probabilities, width=0.4)
|
271 |
+
dice_ax.axvline(x=7, color='r', linestyle='--', linewidth=2, label='Expected Value (7)')
|
272 |
+
|
273 |
+
dice_ax.set_xticks(exp_y_values)
|
274 |
+
dice_ax.set_xlabel('Sum of two dice (y)')
|
275 |
+
dice_ax.set_ylabel('Probability: P(Y = y)')
|
276 |
+
dice_ax.set_title('PMF of Sum of Two Dice with Expected Value')
|
277 |
+
dice_ax.grid(alpha=0.3)
|
278 |
+
dice_ax.legend()
|
279 |
+
|
280 |
+
# Add probability values on top of bars
|
281 |
+
for exp_i, exp_prob in enumerate(exp_probabilities):
|
282 |
+
dice_ax.text(exp_y_values[exp_i], exp_prob + 0.001, f'{exp_prob:.3f}', ha='center')
|
283 |
+
|
284 |
+
plt.tight_layout()
|
285 |
+
plt.gca()
|
286 |
+
return dice_ax, dice_fig, exp_i, exp_prob, exp_probabilities, exp_y_values
|
287 |
+
|
288 |
+
|
289 |
+
@app.cell(hide_code=True)
|
290 |
+
def _(mo):
|
291 |
+
mo.md(
|
292 |
+
r"""
|
293 |
+
## Demonstrating the Properties of Expectation
|
294 |
+
|
295 |
+
Let's demonstrate some of these properties with examples:
|
296 |
+
"""
|
297 |
+
)
|
298 |
+
return
|
299 |
+
|
300 |
+
|
301 |
+
@app.cell
|
302 |
+
def _(exp_die_result):
|
303 |
+
# Demonstrate linearity of expectation (1)
|
304 |
+
# E[aX + b] = a*E[X] + b
|
305 |
+
|
306 |
+
# For a die roll X with E[X] = 3.5
|
307 |
+
prop_a = 2
|
308 |
+
prop_b = 10
|
309 |
+
|
310 |
+
# Calculate E[2X + 10] using the property
|
311 |
+
prop_expected_using_property = prop_a * exp_die_result + prop_b
|
312 |
+
prop_expected_using_property_rounded = round(prop_expected_using_property, 2)
|
313 |
+
|
314 |
+
print(f"Using linearity property: E[{prop_a}X + {prop_b}] = {prop_a} * E[X] + {prop_b} = {prop_expected_using_property_rounded}")
|
315 |
+
|
316 |
+
# Calculate E[2X + 10] directly
|
317 |
+
prop_expected_direct = sum((prop_a * x + prop_b) * (1/6) for x in range(1, 7))
|
318 |
+
prop_expected_direct_rounded = round(prop_expected_direct, 2)
|
319 |
+
|
320 |
+
print(f"Direct calculation: E[{prop_a}X + {prop_b}] = {prop_expected_direct_rounded}")
|
321 |
+
|
322 |
+
# Verify they match
|
323 |
+
print(f"Do they match? {abs(prop_expected_using_property - prop_expected_direct) < 1e-10}")
|
324 |
+
return (
|
325 |
+
prop_a,
|
326 |
+
prop_b,
|
327 |
+
prop_expected_direct,
|
328 |
+
prop_expected_direct_rounded,
|
329 |
+
prop_expected_using_property,
|
330 |
+
prop_expected_using_property_rounded,
|
331 |
+
)
|
332 |
+
|
333 |
+
|
334 |
+
@app.cell(hide_code=True)
|
335 |
+
def _(mo):
|
336 |
+
mo.md(
|
337 |
+
r"""
|
338 |
+
### Law of the Unconscious Statistician (LOTUS)
|
339 |
+
|
340 |
+
Let's use LOTUS to calculate $E[X^2]$ for a die roll, which will be useful when we study variance:
|
341 |
+
"""
|
342 |
+
)
|
343 |
+
return
|
344 |
+
|
345 |
+
|
346 |
+
@app.cell
|
347 |
+
def _():
|
348 |
+
# Calculate E[X^2] for a die roll using LOTUS (3)
|
349 |
+
lotus_die_values = range(1, 7)
|
350 |
+
lotus_die_probs = [1/6] * 6
|
351 |
+
|
352 |
+
# Using LOTUS: E[X^2] = sum(x^2 * P(X=x))
|
353 |
+
lotus_expected_x_squared = sum(x**2 * p for x, p in zip(lotus_die_values, lotus_die_probs))
|
354 |
+
lotus_expected_x_squared_rounded = round(lotus_expected_x_squared, 2)
|
355 |
+
|
356 |
+
expected_x_squared = 3.5**2
|
357 |
+
expected_x_squared_rounded = round(expected_x_squared, 2)
|
358 |
+
|
359 |
+
print(f"E[X^2] for a die roll = {lotus_expected_x_squared_rounded}")
|
360 |
+
print(f"(E[X])^2 for a die roll = {expected_x_squared_rounded}")
|
361 |
+
return (
|
362 |
+
expected_x_squared,
|
363 |
+
expected_x_squared_rounded,
|
364 |
+
lotus_die_probs,
|
365 |
+
lotus_die_values,
|
366 |
+
lotus_expected_x_squared,
|
367 |
+
lotus_expected_x_squared_rounded,
|
368 |
+
)
|
369 |
+
|
370 |
+
|
371 |
+
@app.cell(hide_code=True)
|
372 |
+
def _(mo):
|
373 |
+
mo.md(
|
374 |
+
r"""
|
375 |
+
/// Note
|
376 |
+
Note that E[X^2] != (E[X])^2
|
377 |
+
"""
|
378 |
+
)
|
379 |
+
return
|
380 |
+
|
381 |
+
|
382 |
+
@app.cell(hide_code=True)
|
383 |
+
def _(mo):
|
384 |
+
mo.md(
|
385 |
+
r"""
|
386 |
+
## Interactive Example
|
387 |
+
|
388 |
+
Let's explore how the expected value changes as we adjust the parameters of common probability distributions. This interactive visualization focuses specifically on the relationship between distribution parameters and expected values.
|
389 |
+
|
390 |
+
Use the controls below to select a distribution and adjust its parameters. The graph will show how the expected value changes across a range of parameter values.
|
391 |
+
"""
|
392 |
+
)
|
393 |
+
return
|
394 |
+
|
395 |
+
|
396 |
+
@app.cell(hide_code=True)
|
397 |
+
def _(mo):
|
398 |
+
# Create UI elements for distribution selection
|
399 |
+
dist_selection = mo.ui.dropdown(
|
400 |
+
options=[
|
401 |
+
"bernoulli",
|
402 |
+
"binomial",
|
403 |
+
"geometric",
|
404 |
+
"poisson"
|
405 |
+
],
|
406 |
+
value="bernoulli",
|
407 |
+
label="Select a distribution"
|
408 |
+
)
|
409 |
+
return (dist_selection,)
|
410 |
+
|
411 |
+
|
412 |
+
@app.cell(hide_code=True)
|
413 |
+
def _(dist_selection):
|
414 |
+
dist_selection.center()
|
415 |
+
return
|
416 |
+
|
417 |
+
|
418 |
+
@app.cell(hide_code=True)
|
419 |
+
def _(dist_description):
|
420 |
+
dist_description
|
421 |
+
return
|
422 |
+
|
423 |
+
|
424 |
+
@app.cell(hide_code=True)
|
425 |
+
def _(mo):
|
426 |
+
mo.md("""### Adjust Parameters""")
|
427 |
+
return
|
428 |
+
|
429 |
+
|
430 |
+
@app.cell(hide_code=True)
|
431 |
+
def _(controls):
|
432 |
+
controls
|
433 |
+
return
|
434 |
+
|
435 |
+
|
436 |
+
@app.cell(hide_code=True)
|
437 |
+
def _(
|
438 |
+
dist_selection,
|
439 |
+
lambda_range,
|
440 |
+
np,
|
441 |
+
param_lambda,
|
442 |
+
param_n,
|
443 |
+
param_p,
|
444 |
+
param_range,
|
445 |
+
plt,
|
446 |
+
):
|
447 |
+
# Calculate expected values based on the selected distribution
|
448 |
+
if dist_selection.value == "bernoulli":
|
449 |
+
# Get parameter range for visualization
|
450 |
+
p_min, p_max = param_range.value
|
451 |
+
param_values = np.linspace(p_min, p_max, 100)
|
452 |
+
|
453 |
+
# E[X] = p for Bernoulli
|
454 |
+
expected_values = param_values
|
455 |
+
current_param = param_p.value
|
456 |
+
current_expected = round(current_param, 2)
|
457 |
+
x_label = "p (probability of success)"
|
458 |
+
title = "Expected Value of Bernoulli Distribution"
|
459 |
+
formula = "E[X] = p"
|
460 |
+
|
461 |
+
elif dist_selection.value == "binomial":
|
462 |
+
p_min, p_max = param_range.value
|
463 |
+
param_values = np.linspace(p_min, p_max, 100)
|
464 |
+
|
465 |
+
# E[X] = np for Binomial
|
466 |
+
n = int(param_n.value)
|
467 |
+
expected_values = [n * p for p in param_values]
|
468 |
+
current_param = param_p.value
|
469 |
+
current_expected = round(n * current_param, 2)
|
470 |
+
x_label = "p (probability of success)"
|
471 |
+
title = f"Expected Value of Binomial Distribution (n={n})"
|
472 |
+
formula = f"E[X] = n × p = {n} × p"
|
473 |
+
|
474 |
+
elif dist_selection.value == "geometric":
|
475 |
+
p_min, p_max = param_range.value
|
476 |
+
# Ensure p is not 0 for geometric distribution
|
477 |
+
p_min = max(0.01, p_min)
|
478 |
+
param_values = np.linspace(p_min, p_max, 100)
|
479 |
+
|
480 |
+
# E[X] = 1/p for Geometric
|
481 |
+
expected_values = [1/p for p in param_values]
|
482 |
+
current_param = param_p.value
|
483 |
+
current_expected = round(1 / current_param, 2)
|
484 |
+
x_label = "p (probability of success)"
|
485 |
+
title = "Expected Value of Geometric Distribution"
|
486 |
+
formula = "E[X] = 1/p"
|
487 |
+
|
488 |
+
else: # Poisson
|
489 |
+
lambda_min, lambda_max = lambda_range.value
|
490 |
+
param_values = np.linspace(lambda_min, lambda_max, 100)
|
491 |
+
|
492 |
+
# E[X] = lambda for Poisson
|
493 |
+
expected_values = param_values
|
494 |
+
current_param = param_lambda.value
|
495 |
+
current_expected = round(current_param, 2)
|
496 |
+
x_label = "λ (rate parameter)"
|
497 |
+
title = "Expected Value of Poisson Distribution"
|
498 |
+
formula = "E[X] = λ"
|
499 |
+
|
500 |
+
# Create the plot
|
501 |
+
dist_fig, dist_ax = plt.subplots(figsize=(10, 6))
|
502 |
+
|
503 |
+
# Plot the expected value function
|
504 |
+
dist_ax.plot(param_values, expected_values, 'b-', linewidth=2, label="Expected Value Function")
|
505 |
+
|
506 |
+
dist_ax.plot(current_param, current_expected, 'ro', markersize=10, label=f"Current Value: E[X] = {current_expected}")
|
507 |
+
|
508 |
+
dist_ax.hlines(current_expected, param_values[0], current_param, colors='r', linestyles='dashed')
|
509 |
+
|
510 |
+
dist_ax.vlines(current_param, 0, current_expected, colors='r', linestyles='dashed')
|
511 |
+
|
512 |
+
dist_ax.fill_between(param_values, 0, expected_values, alpha=0.2, color='blue')
|
513 |
+
|
514 |
+
dist_ax.set_xlabel(x_label, fontsize=12)
|
515 |
+
dist_ax.set_ylabel("Expected Value: E[X]", fontsize=12)
|
516 |
+
dist_ax.set_title(title, fontsize=14, fontweight='bold')
|
517 |
+
dist_ax.grid(True, alpha=0.3)
|
518 |
+
|
519 |
+
# Move legend to lower right to avoid overlap with formula
|
520 |
+
dist_ax.legend(loc='lower right', fontsize=10)
|
521 |
+
|
522 |
+
# Add formula text box in upper left
|
523 |
+
dist_props = dict(boxstyle='round', facecolor='white', alpha=0.8)
|
524 |
+
dist_ax.text(0.02, 0.95, formula, transform=dist_ax.transAxes, fontsize=12,
|
525 |
+
verticalalignment='top', bbox=dist_props)
|
526 |
+
|
527 |
+
if dist_selection.value == "geometric":
|
528 |
+
max_y = min(50, 2/max(0.01, param_values[0]))
|
529 |
+
dist_ax.set_ylim(0, max_y)
|
530 |
+
elif dist_selection.value == "binomial":
|
531 |
+
dist_ax.set_ylim(0, int(param_n.value) + 1)
|
532 |
+
else:
|
533 |
+
dist_ax.set_ylim(0, max(expected_values) * 1.1)
|
534 |
+
|
535 |
+
annotation_x = current_param + (param_values[-1] - param_values[0]) * 0.05
|
536 |
+
annotation_y = current_expected
|
537 |
+
|
538 |
+
# Adjust annotation position if it would go off the chart
|
539 |
+
if annotation_x > param_values[-1] * 0.9:
|
540 |
+
annotation_x = current_param - (param_values[-1] - param_values[0]) * 0.2
|
541 |
+
|
542 |
+
dist_ax.annotate(
|
543 |
+
f"Parameter: {current_param:.2f}\nE[X] = {current_expected}",
|
544 |
+
xy=(current_param, current_expected),
|
545 |
+
xytext=(annotation_x, annotation_y),
|
546 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1.5, alpha=0.7),
|
547 |
+
bbox=dist_props
|
548 |
+
)
|
549 |
+
|
550 |
+
plt.tight_layout()
|
551 |
+
plt.gca()
|
552 |
+
return (
|
553 |
+
annotation_x,
|
554 |
+
annotation_y,
|
555 |
+
current_expected,
|
556 |
+
current_param,
|
557 |
+
dist_ax,
|
558 |
+
dist_fig,
|
559 |
+
dist_props,
|
560 |
+
expected_values,
|
561 |
+
formula,
|
562 |
+
lambda_max,
|
563 |
+
lambda_min,
|
564 |
+
max_y,
|
565 |
+
n,
|
566 |
+
p_max,
|
567 |
+
p_min,
|
568 |
+
param_values,
|
569 |
+
title,
|
570 |
+
x_label,
|
571 |
+
)
|
572 |
+
|
573 |
+
|
574 |
+
@app.cell(hide_code=True)
|
575 |
+
def _(mo):
|
576 |
+
mo.md(
|
577 |
+
r"""
|
578 |
+
## Expectation vs. Mode
|
579 |
+
|
580 |
+
The expected value (mean) of a random variable is not always the same as its most likely value (mode). Let's explore this with an example:
|
581 |
+
"""
|
582 |
+
)
|
583 |
+
return
|
584 |
+
|
585 |
+
|
586 |
+
@app.cell(hide_code=True)
|
587 |
+
def _(np, plt, stats):
|
588 |
+
# Create a skewed distribution
|
589 |
+
skew_n = 10
|
590 |
+
skew_p = 0.25
|
591 |
+
|
592 |
+
# Binomial PMF
|
593 |
+
skew_x_values = np.arange(0, skew_n+1)
|
594 |
+
skew_pmf_values = stats.binom.pmf(skew_x_values, skew_n, skew_p)
|
595 |
+
|
596 |
+
# Find the mode (most likely value)
|
597 |
+
skew_mode = skew_x_values[np.argmax(skew_pmf_values)]
|
598 |
+
|
599 |
+
# Calculate the expected value
|
600 |
+
skew_expected = skew_n * skew_p
|
601 |
+
skew_expected_rounded = round(skew_expected, 2)
|
602 |
+
|
603 |
+
skew_fig, skew_ax = plt.subplots(figsize=(10, 5))
|
604 |
+
skew_ax.bar(skew_x_values, skew_pmf_values, alpha=0.7, width=0.4)
|
605 |
+
|
606 |
+
# Add vertical lines for mode and expected value
|
607 |
+
skew_ax.axvline(x=skew_mode, color='g', linestyle='--', linewidth=2,
|
608 |
+
label=f'Mode = {skew_mode} (Most likely value)')
|
609 |
+
skew_ax.axvline(x=skew_expected, color='r', linestyle='--', linewidth=2,
|
610 |
+
label=f'Expected Value = {skew_expected_rounded} (Mean)')
|
611 |
+
|
612 |
+
skew_ax.annotate('Mode', xy=(skew_mode, 0.05), xytext=(skew_mode-2.0, 0.1),
|
613 |
+
arrowprops=dict(facecolor='green', shrink=0.05, width=1.5), color='green')
|
614 |
+
skew_ax.annotate('Expected Value', xy=(skew_expected, 0.05), xytext=(skew_expected+1, 0.15),
|
615 |
+
arrowprops=dict(facecolor='red', shrink=0.05, width=1.5), color='red')
|
616 |
+
|
617 |
+
if skew_mode != int(skew_expected):
|
618 |
+
min_x = min(skew_mode, skew_expected)
|
619 |
+
max_x = max(skew_mode, skew_expected)
|
620 |
+
skew_ax.axvspan(min_x, max_x, alpha=0.2, color='purple')
|
621 |
+
|
622 |
+
# Add text explaining the difference
|
623 |
+
mid_x = (skew_mode + skew_expected) / 2
|
624 |
+
skew_ax.text(mid_x, max(skew_pmf_values) * 0.5,
|
625 |
+
f"Difference: {abs(skew_mode - skew_expected_rounded):.2f}",
|
626 |
+
ha='center', va='center', bbox=dict(facecolor='white', alpha=0.7))
|
627 |
+
|
628 |
+
skew_ax.set_xlabel('Number of Successes')
|
629 |
+
skew_ax.set_ylabel('Probability')
|
630 |
+
skew_ax.set_title(f'Binomial Distribution (n={skew_n}, p={skew_p})')
|
631 |
+
skew_ax.grid(alpha=0.3)
|
632 |
+
skew_ax.legend()
|
633 |
+
|
634 |
+
plt.tight_layout()
|
635 |
+
plt.gca()
|
636 |
+
return (
|
637 |
+
max_x,
|
638 |
+
mid_x,
|
639 |
+
min_x,
|
640 |
+
skew_ax,
|
641 |
+
skew_expected,
|
642 |
+
skew_expected_rounded,
|
643 |
+
skew_fig,
|
644 |
+
skew_mode,
|
645 |
+
skew_n,
|
646 |
+
skew_p,
|
647 |
+
skew_pmf_values,
|
648 |
+
skew_x_values,
|
649 |
+
)
|
650 |
+
|
651 |
+
|
652 |
+
@app.cell(hide_code=True)
|
653 |
+
def _(mo):
|
654 |
+
mo.md(
|
655 |
+
r"""
|
656 |
+
/// NOTE
|
657 |
+
For the sum of two dice we calculated earlier, we found the expected value to be exactly 7. In that case, 7 also happens to be the mode (most likely outcome) of the distribution. However, this is just a coincidence for this particular example!
|
658 |
+
|
659 |
+
As we can see from the binomial distribution above, the expected value (2.50) and the mode (2) are often different values (this is common in skewed distributions). The expected value represents the "center of mass" of the distribution, while the mode represents the most likely single outcome.
|
660 |
+
"""
|
661 |
+
)
|
662 |
+
return
|
663 |
+
|
664 |
+
|
665 |
+
@app.cell(hide_code=True)
|
666 |
+
def _(mo):
|
667 |
+
mo.md(
|
668 |
+
r"""
|
669 |
+
## 🤔 Test Your Understanding
|
670 |
+
|
671 |
+
Choose what you believe are the correct options in the questions below:
|
672 |
+
|
673 |
+
<details>
|
674 |
+
<summary>The expected value of a random variable is always one of the possible values the random variable can take.</summary>
|
675 |
+
❌ False! The expected value is a weighted average and may not be a value the random variable can actually take. For example, the expected value of a fair die roll is 3.5, which is not a possible outcome.
|
676 |
+
</details>
|
677 |
+
|
678 |
+
<details>
|
679 |
+
<summary>If X and Y are independent random variables, then E[X·Y] = E[X]·E[Y].</summary>
|
680 |
+
✅ True! For independent random variables, the expectation of their product equals the product of their expectations.
|
681 |
+
</details>
|
682 |
+
|
683 |
+
<details>
|
684 |
+
<summary>The expected value of a constant random variable (one that always takes the same value) is that constant.</summary>
|
685 |
+
✅ True! If X = c with probability 1, then E[X] = c.
|
686 |
+
</details>
|
687 |
+
|
688 |
+
<details>
|
689 |
+
<summary>The expected value of the sum of two random variables is always the sum of their expected values, regardless of whether they are independent.</summary>
|
690 |
+
✅ True! This is the linearity of expectation property: E[X + Y] = E[X] + E[Y], which holds regardless of dependence.
|
691 |
+
</details>
|
692 |
+
"""
|
693 |
+
)
|
694 |
+
return
|
695 |
+
|
696 |
+
|
697 |
+
@app.cell(hide_code=True)
|
698 |
+
def _(mo):
|
699 |
+
mo.md(
|
700 |
+
r"""
|
701 |
+
## Practical Applications of Expectation
|
702 |
+
|
703 |
+
Expected values show up everywhere - from investment decisions and insurance pricing to machine learning algorithms and game design. Engineers use them to predict system reliability, data scientists to understand customer behavior, and economists to model market outcomes. They're essential for risk assessment in project management and for optimizing resource allocation in operations research.
|
704 |
+
"""
|
705 |
+
)
|
706 |
+
return
|
707 |
+
|
708 |
+
|
709 |
+
@app.cell(hide_code=True)
|
710 |
+
def _(mo):
|
711 |
+
mo.md(
|
712 |
+
r"""
|
713 |
+
## Key Takeaways
|
714 |
+
|
715 |
+
Expectation gives us a single value that summarizes a random variable's central tendency - it's the weighted average of all possible outcomes, where the weights are probabilities. The linearity property makes expectations easy to work with, even for complex combinations of random variables. While a PMF gives the complete probability picture, expectation provides an essential summary that helps us make decisions under uncertainty. In our next notebook, we'll explore variance, which measures how spread out a random variable's values are around its expectation.
|
716 |
+
"""
|
717 |
+
)
|
718 |
+
return
|
719 |
+
|
720 |
+
|
721 |
+
@app.cell(hide_code=True)
|
722 |
+
def _(mo):
|
723 |
+
mo.md(r"""#### Appendix (containing helper code)""")
|
724 |
+
return
|
725 |
+
|
726 |
+
|
727 |
+
@app.cell(hide_code=True)
|
728 |
+
def _():
|
729 |
+
import marimo as mo
|
730 |
+
return (mo,)
|
731 |
+
|
732 |
+
|
733 |
+
@app.cell(hide_code=True)
|
734 |
+
def _():
|
735 |
+
import matplotlib.pyplot as plt
|
736 |
+
import numpy as np
|
737 |
+
from scipy import stats
|
738 |
+
import collections
|
739 |
+
return collections, np, plt, stats
|
740 |
+
|
741 |
+
|
742 |
+
@app.cell(hide_code=True)
|
743 |
+
def _(dist_selection, mo):
|
744 |
+
# Parameter controls for probability-based distributions
|
745 |
+
param_p = mo.ui.slider(
|
746 |
+
start=0.01,
|
747 |
+
stop=0.99,
|
748 |
+
step=0.01,
|
749 |
+
value=0.5,
|
750 |
+
label="p (probability of success)",
|
751 |
+
full_width=True
|
752 |
+
)
|
753 |
+
|
754 |
+
# Parameter control for binomial distribution
|
755 |
+
param_n = mo.ui.slider(
|
756 |
+
start=1,
|
757 |
+
stop=50,
|
758 |
+
step=1,
|
759 |
+
value=10,
|
760 |
+
label="n (number of trials)",
|
761 |
+
full_width=True
|
762 |
+
)
|
763 |
+
|
764 |
+
# Parameter control for Poisson distribution
|
765 |
+
param_lambda = mo.ui.slider(
|
766 |
+
start=0.1,
|
767 |
+
stop=20,
|
768 |
+
step=0.1,
|
769 |
+
value=5,
|
770 |
+
label="λ (rate parameter)",
|
771 |
+
full_width=True
|
772 |
+
)
|
773 |
+
|
774 |
+
# Parameter range sliders for visualization
|
775 |
+
param_range = mo.ui.range_slider(
|
776 |
+
start=0,
|
777 |
+
stop=1,
|
778 |
+
step=0.01,
|
779 |
+
value=[0, 1],
|
780 |
+
label="Parameter range to visualize",
|
781 |
+
full_width=True
|
782 |
+
)
|
783 |
+
|
784 |
+
lambda_range = mo.ui.range_slider(
|
785 |
+
start=0,
|
786 |
+
stop=20,
|
787 |
+
step=0.1,
|
788 |
+
value=[0, 20],
|
789 |
+
label="λ range to visualize",
|
790 |
+
full_width=True
|
791 |
+
)
|
792 |
+
|
793 |
+
# Display appropriate controls based on the selected distribution
|
794 |
+
if dist_selection.value == "bernoulli":
|
795 |
+
controls = mo.hstack([param_p, param_range], justify="space-around")
|
796 |
+
elif dist_selection.value == "binomial":
|
797 |
+
controls = mo.hstack([param_p, param_n, param_range], justify="space-around")
|
798 |
+
elif dist_selection.value == "geometric":
|
799 |
+
controls = mo.hstack([param_p, param_range], justify="space-around")
|
800 |
+
else: # poisson
|
801 |
+
controls = mo.hstack([param_lambda, lambda_range], justify="space-around")
|
802 |
+
return controls, lambda_range, param_lambda, param_n, param_p, param_range
|
803 |
+
|
804 |
+
|
805 |
+
@app.cell(hide_code=True)
|
806 |
+
def _(dist_selection, mo):
|
807 |
+
# Create distribution descriptions based on selection
|
808 |
+
if dist_selection.value == "bernoulli":
|
809 |
+
dist_description = mo.md(
|
810 |
+
r"""
|
811 |
+
**Bernoulli Distribution**
|
812 |
+
|
813 |
+
A Bernoulli distribution models a single trial with two possible outcomes: success (1) or failure (0).
|
814 |
+
|
815 |
+
- Parameter: $p$ = probability of success
|
816 |
+
- Expected Value: $E[X] = p$
|
817 |
+
- Example: Flipping a coin once (p = 0.5 for a fair coin)
|
818 |
+
"""
|
819 |
+
)
|
820 |
+
elif dist_selection.value == "binomial":
|
821 |
+
dist_description = mo.md(
|
822 |
+
r"""
|
823 |
+
**Binomial Distribution**
|
824 |
+
|
825 |
+
A Binomial distribution models the number of successes in $n$ independent trials.
|
826 |
+
|
827 |
+
- Parameters: $n$ = number of trials, $p$ = probability of success
|
828 |
+
- Expected Value: $E[X] = np$
|
829 |
+
- Example: Number of heads in 10 coin flips
|
830 |
+
"""
|
831 |
+
)
|
832 |
+
elif dist_selection.value == "geometric":
|
833 |
+
dist_description = mo.md(
|
834 |
+
r"""
|
835 |
+
**Geometric Distribution**
|
836 |
+
|
837 |
+
A Geometric distribution models the number of trials until the first success.
|
838 |
+
|
839 |
+
- Parameter: $p$ = probability of success
|
840 |
+
- Expected Value: $E[X] = \frac{1}{p}$
|
841 |
+
- Example: Number of coin flips until first heads
|
842 |
+
"""
|
843 |
+
)
|
844 |
+
else: # poisson
|
845 |
+
dist_description = mo.md(
|
846 |
+
r"""
|
847 |
+
**Poisson Distribution**
|
848 |
+
|
849 |
+
A Poisson distribution models the number of events occurring in a fixed interval.
|
850 |
+
|
851 |
+
- Parameter: $\lambda$ = average rate of events
|
852 |
+
- Expected Value: $E[X] = \lambda$
|
853 |
+
- Example: Number of emails received per hour
|
854 |
+
"""
|
855 |
+
)
|
856 |
+
return (dist_description,)
|
857 |
+
|
858 |
+
|
859 |
+
if __name__ == "__main__":
|
860 |
+
app.run()
|
probability/12_variance.py
ADDED
@@ -0,0 +1,631 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.0",
|
6 |
+
# "numpy==2.2.3",
|
7 |
+
# "scipy==1.15.2",
|
8 |
+
# "wigglystuff==0.1.10",
|
9 |
+
# ]
|
10 |
+
# ///
|
11 |
+
|
12 |
+
import marimo
|
13 |
+
|
14 |
+
__generated_with = "0.11.20"
|
15 |
+
app = marimo.App(width="medium", app_title="Variance")
|
16 |
+
|
17 |
+
|
18 |
+
@app.cell(hide_code=True)
|
19 |
+
def _(mo):
|
20 |
+
mo.md(
|
21 |
+
r"""
|
22 |
+
# Variance
|
23 |
+
|
24 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/variance/), by Stanford professor Chris Piech._
|
25 |
+
|
26 |
+
In our previous exploration of random variables, we learned about expectation - a measure of central tendency. However, knowing the average value alone doesn't tell us everything about a distribution. Consider these questions:
|
27 |
+
|
28 |
+
- How spread out are the values around the mean?
|
29 |
+
- How reliable is the expectation as a predictor of individual outcomes?
|
30 |
+
- How much do individual samples typically deviate from the average?
|
31 |
+
|
32 |
+
This is where **variance** comes in - it measures the spread or dispersion of a random variable around its expected value.
|
33 |
+
"""
|
34 |
+
)
|
35 |
+
return
|
36 |
+
|
37 |
+
|
38 |
+
@app.cell(hide_code=True)
|
39 |
+
def _(mo):
|
40 |
+
mo.md(
|
41 |
+
r"""
|
42 |
+
## Definition of Variance
|
43 |
+
|
44 |
+
The variance of a random variable $X$ with expected value $\mu = E[X]$ is defined as:
|
45 |
+
|
46 |
+
$$\text{Var}(X) = E[(X-\mu)^2]$$
|
47 |
+
|
48 |
+
This definition captures the average squared deviation from the mean. There's also an equivalent, often more convenient formula:
|
49 |
+
|
50 |
+
$$\text{Var}(X) = E[X^2] - (E[X])^2$$
|
51 |
+
|
52 |
+
/// tip
|
53 |
+
The second formula is usually easier to compute, as it only requires calculating $E[X^2]$ and $E[X]$, rather than working with deviations from the mean.
|
54 |
+
"""
|
55 |
+
)
|
56 |
+
return
|
57 |
+
|
58 |
+
|
59 |
+
@app.cell(hide_code=True)
|
60 |
+
def _(mo):
|
61 |
+
mo.md(
|
62 |
+
r"""
|
63 |
+
## Intuition Through Example
|
64 |
+
|
65 |
+
Let's look at a real-world example that illustrates why variance is important. Consider three different groups of graders evaluating assignments in a massive online course. Each grader has their own "grading distribution" - their pattern of assigning scores to work that deserves a 70/100.
|
66 |
+
|
67 |
+
The visualization below shows the probability distributions for three types of graders. Try clicking and dragging the blue numbers to adjust the parameters and see how they affect the variance.
|
68 |
+
"""
|
69 |
+
)
|
70 |
+
return
|
71 |
+
|
72 |
+
|
73 |
+
@app.cell(hide_code=True)
|
74 |
+
def _(mo):
|
75 |
+
mo.md(
|
76 |
+
r"""
|
77 |
+
/// TIP
|
78 |
+
Try adjusting the blue numbers above to see how:
|
79 |
+
|
80 |
+
- Increasing spread increases variance
|
81 |
+
- The mixture ratio affects how many outliers appear in Grader C's distribution
|
82 |
+
- Changing the true grade shifts all distributions but maintains their relative variances
|
83 |
+
"""
|
84 |
+
)
|
85 |
+
return
|
86 |
+
|
87 |
+
|
88 |
+
@app.cell(hide_code=True)
|
89 |
+
def _(controls):
|
90 |
+
controls
|
91 |
+
return
|
92 |
+
|
93 |
+
|
94 |
+
@app.cell(hide_code=True)
|
95 |
+
def _(
|
96 |
+
grader_a_spread,
|
97 |
+
grader_b_spread,
|
98 |
+
grader_c_mix,
|
99 |
+
np,
|
100 |
+
plt,
|
101 |
+
stats,
|
102 |
+
true_grade,
|
103 |
+
):
|
104 |
+
# Create data for three grader distributions
|
105 |
+
_grader_x = np.linspace(40, 100, 200)
|
106 |
+
|
107 |
+
# Calculate actual variances
|
108 |
+
var_a = grader_a_spread.amount**2
|
109 |
+
var_b = grader_b_spread.amount**2
|
110 |
+
var_c = (1-grader_c_mix.amount) * 3**2 + grader_c_mix.amount * 8**2 + \
|
111 |
+
grader_c_mix.amount * (1-grader_c_mix.amount) * (8-3)**2 # Mixture variance formula
|
112 |
+
|
113 |
+
# Grader A: Wide spread around true grade
|
114 |
+
grader_a = stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=grader_a_spread.amount)
|
115 |
+
|
116 |
+
# Grader B: Narrow spread around true grade
|
117 |
+
grader_b = stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=grader_b_spread.amount)
|
118 |
+
|
119 |
+
# Grader C: Mixture of distributions
|
120 |
+
grader_c = (1-grader_c_mix.amount) * stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=3) + \
|
121 |
+
grader_c_mix.amount * stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=8)
|
122 |
+
|
123 |
+
grader_fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
|
124 |
+
|
125 |
+
# Plot each distribution
|
126 |
+
ax1.fill_between(_grader_x, grader_a, alpha=0.3, color='green', label=f'Var ≈ {var_a:.2f}')
|
127 |
+
ax1.axvline(x=true_grade.amount, color='black', linestyle='--', label='True Grade')
|
128 |
+
ax1.set_title('Grader A: High Variance')
|
129 |
+
ax1.set_xlabel('Grade')
|
130 |
+
ax1.set_ylabel('Pr(G = g)')
|
131 |
+
ax1.set_ylim(0, max(grader_a)*1.1)
|
132 |
+
|
133 |
+
ax2.fill_between(_grader_x, grader_b, alpha=0.3, color='blue', label=f'Var ≈ {var_b:.2f}')
|
134 |
+
ax2.axvline(x=true_grade.amount, color='black', linestyle='--')
|
135 |
+
ax2.set_title('Grader B: Low Variance')
|
136 |
+
ax2.set_xlabel('Grade')
|
137 |
+
ax2.set_ylim(0, max(grader_b)*1.1)
|
138 |
+
|
139 |
+
ax3.fill_between(_grader_x, grader_c, alpha=0.3, color='purple', label=f'Var ≈ {var_c:.2f}')
|
140 |
+
ax3.axvline(x=true_grade.amount, color='black', linestyle='--')
|
141 |
+
ax3.set_title('Grader C: Mixed Distribution')
|
142 |
+
ax3.set_xlabel('Grade')
|
143 |
+
ax3.set_ylim(0, max(grader_c)*1.1)
|
144 |
+
|
145 |
+
# Add annotations to explain what's happening
|
146 |
+
ax1.annotate('Wide spread = high variance',
|
147 |
+
xy=(true_grade.amount, max(grader_a)*0.5),
|
148 |
+
xytext=(true_grade.amount-15, max(grader_a)*0.7),
|
149 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
150 |
+
|
151 |
+
ax2.annotate('Narrow spread = low variance',
|
152 |
+
xy=(true_grade.amount, max(grader_b)*0.5),
|
153 |
+
xytext=(true_grade.amount+8, max(grader_b)*0.7),
|
154 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
155 |
+
|
156 |
+
ax3.annotate('Mixture creates outliers',
|
157 |
+
xy=(true_grade.amount+15, grader_c[np.where(_grader_x >= true_grade.amount+15)[0][0]]),
|
158 |
+
xytext=(true_grade.amount+5, max(grader_c)*0.7),
|
159 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
160 |
+
|
161 |
+
# Add legends and adjust layout
|
162 |
+
for _ax in [ax1, ax2, ax3]:
|
163 |
+
_ax.legend()
|
164 |
+
_ax.grid(alpha=0.2)
|
165 |
+
|
166 |
+
plt.tight_layout()
|
167 |
+
plt.gca()
|
168 |
+
return (
|
169 |
+
ax1,
|
170 |
+
ax2,
|
171 |
+
ax3,
|
172 |
+
grader_a,
|
173 |
+
grader_b,
|
174 |
+
grader_c,
|
175 |
+
grader_fig,
|
176 |
+
var_a,
|
177 |
+
var_b,
|
178 |
+
var_c,
|
179 |
+
)
|
180 |
+
|
181 |
+
|
182 |
+
@app.cell(hide_code=True)
|
183 |
+
def _(mo):
|
184 |
+
mo.md(
|
185 |
+
r"""
|
186 |
+
/// note
|
187 |
+
All three distributions have the same expected value (the true grade), but they differ significantly in their spread:
|
188 |
+
|
189 |
+
- **Grader A** has high variance - grades vary widely from the true value
|
190 |
+
- **Grader B** has low variance - grades consistently stay close to the true value
|
191 |
+
- **Grader C** has a mixture distribution - mostly consistent but with occasional extreme values
|
192 |
+
|
193 |
+
This illustrates why variance is crucial: two distributions can have the same mean but behave very differently in practice.
|
194 |
+
"""
|
195 |
+
)
|
196 |
+
return
|
197 |
+
|
198 |
+
|
199 |
+
@app.cell(hide_code=True)
|
200 |
+
def _(mo):
|
201 |
+
mo.md(
|
202 |
+
r"""
|
203 |
+
## Computing Variance
|
204 |
+
|
205 |
+
Let's work through some concrete examples to understand how to calculate variance.
|
206 |
+
|
207 |
+
### Example 1: Fair Die Roll
|
208 |
+
|
209 |
+
Consider rolling a fair six-sided die. We'll calculate its variance step by step:
|
210 |
+
"""
|
211 |
+
)
|
212 |
+
return
|
213 |
+
|
214 |
+
|
215 |
+
@app.cell
|
216 |
+
def _(np):
|
217 |
+
# Define the die values and probabilities
|
218 |
+
die_values = np.array([1, 2, 3, 4, 5, 6])
|
219 |
+
die_probs = np.array([1/6] * 6)
|
220 |
+
|
221 |
+
# Calculate E[X]
|
222 |
+
expected_value = np.sum(die_values * die_probs)
|
223 |
+
|
224 |
+
# Calculate E[X^2]
|
225 |
+
expected_square = np.sum(die_values**2 * die_probs)
|
226 |
+
|
227 |
+
# Calculate Var(X) = E[X^2] - (E[X])^2
|
228 |
+
variance = expected_square - expected_value**2
|
229 |
+
|
230 |
+
# Calculate standard deviation
|
231 |
+
std_dev = np.sqrt(variance)
|
232 |
+
|
233 |
+
print(f"E[X] = {expected_value:.2f}")
|
234 |
+
print(f"E[X^2] = {expected_square:.2f}")
|
235 |
+
print(f"Var(X) = {variance:.2f}")
|
236 |
+
print(f"Standard Deviation = {std_dev:.2f}")
|
237 |
+
return (
|
238 |
+
die_probs,
|
239 |
+
die_values,
|
240 |
+
expected_square,
|
241 |
+
expected_value,
|
242 |
+
std_dev,
|
243 |
+
variance,
|
244 |
+
)
|
245 |
+
|
246 |
+
|
247 |
+
@app.cell(hide_code=True)
|
248 |
+
def _(mo):
|
249 |
+
mo.md(
|
250 |
+
r"""
|
251 |
+
/// NOTE
|
252 |
+
For a fair die:
|
253 |
+
|
254 |
+
- The expected value (3.50) tells us the average roll
|
255 |
+
- The variance (2.92) tells us how much typical rolls deviate from this average
|
256 |
+
- The standard deviation (1.71) gives us this spread in the original units
|
257 |
+
"""
|
258 |
+
)
|
259 |
+
return
|
260 |
+
|
261 |
+
|
262 |
+
@app.cell(hide_code=True)
|
263 |
+
def _(mo):
|
264 |
+
mo.md(
|
265 |
+
r"""
|
266 |
+
## Properties of Variance
|
267 |
+
|
268 |
+
Variance has several important properties that make it useful for analyzing random variables:
|
269 |
+
|
270 |
+
1. **Non-negativity**: $\text{Var}(X) \geq 0$ for any random variable $X$
|
271 |
+
2. **Variance of a constant**: $\text{Var}(c) = 0$ for any constant $c$
|
272 |
+
3. **Scaling**: $\text{Var}(aX) = a^2\text{Var}(X)$ for any constant $a$
|
273 |
+
4. **Translation**: $\text{Var}(X + b) = \text{Var}(X)$ for any constant $b$
|
274 |
+
5. **Independence**: If $X$ and $Y$ are independent, then $\text{Var}(X + Y) = \text{Var}(X) + \text{Var}(Y)$
|
275 |
+
|
276 |
+
Let's verify a property with an example.
|
277 |
+
"""
|
278 |
+
)
|
279 |
+
return
|
280 |
+
|
281 |
+
|
282 |
+
@app.cell(hide_code=True)
|
283 |
+
def _(mo):
|
284 |
+
mo.md(
|
285 |
+
r"""
|
286 |
+
## Proof of Variance Formula
|
287 |
+
|
288 |
+
The equivalence of the two variance formulas is a fundamental result in probability theory. Here's the proof:
|
289 |
+
|
290 |
+
Starting with the definition $\text{Var}(X) = E[(X-\mu)^2]$ where $\mu = E[X]$:
|
291 |
+
|
292 |
+
\begin{align}
|
293 |
+
\text{Var}(X) &= E[(X-\mu)^2] \\
|
294 |
+
&= \sum_x(x-\mu)^2P(x) && \text{Definition of Expectation}\\
|
295 |
+
&= \sum_x (x^2 -2\mu x + \mu^2)P(x) && \text{Expanding the square}\\
|
296 |
+
&= \sum_x x^2P(x)- 2\mu \sum_x xP(x) + \mu^2 \sum_x P(x) && \text{Distributing the sum}\\
|
297 |
+
&= E[X^2]- 2\mu E[X] + \mu^2 && \text{Definition of expectation}\\
|
298 |
+
&= E[X^2]- 2(E[X])^2 + (E[X])^2 && \text{Since }\mu = E[X]\\
|
299 |
+
&= E[X^2]- (E[X])^2 && \text{Simplifying}
|
300 |
+
\end{align}
|
301 |
+
|
302 |
+
/// tip
|
303 |
+
This proof shows why the formula $\text{Var}(X) = E[X^2] - (E[X])^2$ is so useful - it's much easier to compute $E[X^2]$ and $E[X]$ separately than to work with deviations directly.
|
304 |
+
"""
|
305 |
+
)
|
306 |
+
return
|
307 |
+
|
308 |
+
|
309 |
+
@app.cell
|
310 |
+
def _(die_probs, die_values, np):
|
311 |
+
# Demonstrate scaling property
|
312 |
+
a = 2 # Scale factor
|
313 |
+
|
314 |
+
# Original variance
|
315 |
+
original_var = np.sum(die_values**2 * die_probs) - (np.sum(die_values * die_probs))**2
|
316 |
+
|
317 |
+
# Scaled random variable variance
|
318 |
+
scaled_values = a * die_values
|
319 |
+
scaled_var = np.sum(scaled_values**2 * die_probs) - (np.sum(scaled_values * die_probs))**2
|
320 |
+
|
321 |
+
print(f"Original Variance: {original_var:.2f}")
|
322 |
+
print(f"Scaled Variance (a={a}): {scaled_var:.2f}")
|
323 |
+
print(f"a^2 * Original Variance: {a**2 * original_var:.2f}")
|
324 |
+
print(f"Property holds: {abs(scaled_var - a**2 * original_var) < 1e-10}")
|
325 |
+
return a, original_var, scaled_values, scaled_var
|
326 |
+
|
327 |
+
|
328 |
+
@app.cell
|
329 |
+
def _():
|
330 |
+
# DIY : Prove more properties as shown above
|
331 |
+
return
|
332 |
+
|
333 |
+
|
334 |
+
@app.cell(hide_code=True)
|
335 |
+
def _(mo):
|
336 |
+
mo.md(
|
337 |
+
r"""
|
338 |
+
## Standard Deviation
|
339 |
+
|
340 |
+
While variance is mathematically convenient, it has one practical drawback: its units are squared. For example, if we're measuring grades (0-100), the variance is in "grade points squared." This makes it hard to interpret intuitively.
|
341 |
+
|
342 |
+
The **standard deviation**, denoted by $\sigma$ or $\text{SD}(X)$, is the square root of variance:
|
343 |
+
|
344 |
+
$$\sigma = \sqrt{\text{Var}(X)}$$
|
345 |
+
|
346 |
+
/// tip
|
347 |
+
Standard deviation is often more intuitive because it's in the same units as the original data. For a normal distribution, approximately:
|
348 |
+
- 68% of values fall within 1 standard deviation of the mean
|
349 |
+
- 95% of values fall within 2 standard deviations
|
350 |
+
- 99.7% of values fall within 3 standard deviations
|
351 |
+
"""
|
352 |
+
)
|
353 |
+
return
|
354 |
+
|
355 |
+
|
356 |
+
@app.cell(hide_code=True)
|
357 |
+
def _(controls1):
|
358 |
+
controls1
|
359 |
+
return
|
360 |
+
|
361 |
+
|
362 |
+
@app.cell(hide_code=True)
|
363 |
+
def _(TangleSlider, mo):
|
364 |
+
normal_mean = mo.ui.anywidget(TangleSlider(
|
365 |
+
amount=0,
|
366 |
+
min_value=-5,
|
367 |
+
max_value=5,
|
368 |
+
step=0.5,
|
369 |
+
digits=1,
|
370 |
+
suffix=" units"
|
371 |
+
))
|
372 |
+
|
373 |
+
normal_std = mo.ui.anywidget(TangleSlider(
|
374 |
+
amount=1,
|
375 |
+
min_value=0.1,
|
376 |
+
max_value=3,
|
377 |
+
step=0.1,
|
378 |
+
digits=1,
|
379 |
+
suffix=" units"
|
380 |
+
))
|
381 |
+
|
382 |
+
# Create a grid layout for the controls
|
383 |
+
controls1 = mo.vstack([
|
384 |
+
mo.md("### Interactive Normal Distribution"),
|
385 |
+
mo.hstack([
|
386 |
+
mo.md("Adjust the parameters to see how standard deviation affects the shape of the distribution:"),
|
387 |
+
]),
|
388 |
+
mo.hstack([
|
389 |
+
mo.md("Mean (μ): "),
|
390 |
+
normal_mean,
|
391 |
+
mo.md(" Standard deviation (σ): "),
|
392 |
+
normal_std
|
393 |
+
], justify="start"),
|
394 |
+
])
|
395 |
+
return controls1, normal_mean, normal_std
|
396 |
+
|
397 |
+
|
398 |
+
@app.cell(hide_code=True)
|
399 |
+
def _(normal_mean, normal_std, np, plt, stats):
|
400 |
+
# data for normal distribution
|
401 |
+
_normal_x = np.linspace(-10, 10, 1000)
|
402 |
+
_normal_y = stats.norm.pdf(_normal_x, loc=normal_mean.amount, scale=normal_std.amount)
|
403 |
+
|
404 |
+
# ranges for standard deviation intervals
|
405 |
+
one_sigma_left = normal_mean.amount - normal_std.amount
|
406 |
+
one_sigma_right = normal_mean.amount + normal_std.amount
|
407 |
+
two_sigma_left = normal_mean.amount - 2 * normal_std.amount
|
408 |
+
two_sigma_right = normal_mean.amount + 2 * normal_std.amount
|
409 |
+
three_sigma_left = normal_mean.amount - 3 * normal_std.amount
|
410 |
+
three_sigma_right = normal_mean.amount + 3 * normal_std.amount
|
411 |
+
|
412 |
+
# Create the plot
|
413 |
+
normal_fig, normal_ax = plt.subplots(figsize=(10, 6))
|
414 |
+
|
415 |
+
# Plot the distribution
|
416 |
+
normal_ax.plot(_normal_x, _normal_y, 'b-', linewidth=2)
|
417 |
+
|
418 |
+
# stdev intervals
|
419 |
+
normal_ax.fill_between(_normal_x, 0, _normal_y, where=(_normal_x >= one_sigma_left) & (_normal_x <= one_sigma_right),
|
420 |
+
alpha=0.3, color='red', label='68% (±1σ)')
|
421 |
+
normal_ax.fill_between(_normal_x, 0, _normal_y, where=(_normal_x >= two_sigma_left) & (_normal_x <= two_sigma_right),
|
422 |
+
alpha=0.2, color='green', label='95% (±2σ)')
|
423 |
+
normal_ax.fill_between(_normal_x, 0, _normal_y, where=(_normal_x >= three_sigma_left) & (_normal_x <= three_sigma_right),
|
424 |
+
alpha=0.1, color='blue', label='99.7% (±3σ)')
|
425 |
+
|
426 |
+
# vertical lines for the mean and standard deviations
|
427 |
+
normal_ax.axvline(x=normal_mean.amount, color='black', linestyle='-', linewidth=1.5, label='Mean (μ)')
|
428 |
+
normal_ax.axvline(x=one_sigma_left, color='red', linestyle='--', linewidth=1)
|
429 |
+
normal_ax.axvline(x=one_sigma_right, color='red', linestyle='--', linewidth=1)
|
430 |
+
normal_ax.axvline(x=two_sigma_left, color='green', linestyle='--', linewidth=1)
|
431 |
+
normal_ax.axvline(x=two_sigma_right, color='green', linestyle='--', linewidth=1)
|
432 |
+
|
433 |
+
# annotations
|
434 |
+
normal_ax.annotate(f'μ = {normal_mean.amount:.2f}',
|
435 |
+
xy=(normal_mean.amount, max(_normal_y)*0.5),
|
436 |
+
xytext=(normal_mean.amount + 0.5, max(_normal_y)*0.8),
|
437 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
438 |
+
|
439 |
+
normal_ax.annotate(f'σ = {normal_std.amount:.2f}',
|
440 |
+
xy=(one_sigma_right, stats.norm.pdf(one_sigma_right, loc=normal_mean.amount, scale=normal_std.amount)),
|
441 |
+
xytext=(one_sigma_right + 0.5, max(_normal_y)*0.6),
|
442 |
+
arrowprops=dict(facecolor='red', shrink=0.05, width=1))
|
443 |
+
|
444 |
+
# labels and title
|
445 |
+
normal_ax.set_xlabel('Value')
|
446 |
+
normal_ax.set_ylabel('Probability Density')
|
447 |
+
normal_ax.set_title(f'Normal Distribution with μ = {normal_mean.amount:.2f} and σ = {normal_std.amount:.2f}')
|
448 |
+
|
449 |
+
# legend and grid
|
450 |
+
normal_ax.legend()
|
451 |
+
normal_ax.grid(alpha=0.3)
|
452 |
+
|
453 |
+
plt.tight_layout()
|
454 |
+
plt.gca()
|
455 |
+
return (
|
456 |
+
normal_ax,
|
457 |
+
normal_fig,
|
458 |
+
one_sigma_left,
|
459 |
+
one_sigma_right,
|
460 |
+
three_sigma_left,
|
461 |
+
three_sigma_right,
|
462 |
+
two_sigma_left,
|
463 |
+
two_sigma_right,
|
464 |
+
)
|
465 |
+
|
466 |
+
|
467 |
+
@app.cell(hide_code=True)
|
468 |
+
def _(mo):
|
469 |
+
mo.md(
|
470 |
+
r"""
|
471 |
+
/// tip
|
472 |
+
The interactive visualization above demonstrates how standard deviation (σ) affects the shape of a normal distribution:
|
473 |
+
|
474 |
+
- The **red region** covers μ ± 1σ, containing approximately 68% of the probability
|
475 |
+
- The **green region** covers μ ± 2σ, containing approximately 95% of the probability
|
476 |
+
- The **blue region** covers μ ± 3σ, containing approximately 99.7% of the probability
|
477 |
+
|
478 |
+
This is known as the "68-95-99.7 rule" or the "empirical rule" and is a useful heuristic for understanding the spread of data.
|
479 |
+
"""
|
480 |
+
)
|
481 |
+
return
|
482 |
+
|
483 |
+
|
484 |
+
@app.cell(hide_code=True)
|
485 |
+
def _(mo):
|
486 |
+
mo.md(
|
487 |
+
r"""
|
488 |
+
## 🤔 Test Your Understanding
|
489 |
+
|
490 |
+
Choose what you believe are the correct options in the questions below:
|
491 |
+
|
492 |
+
<details>
|
493 |
+
<summary>The variance of a random variable can be negative.</summary>
|
494 |
+
❌ False! Variance is defined as an expected value of squared deviations, and squares are always non-negative.
|
495 |
+
</details>
|
496 |
+
|
497 |
+
<details>
|
498 |
+
<summary>If X and Y are independent random variables, then Var(X + Y) = Var(X) + Var(Y).</summary>
|
499 |
+
✅ True! This is one of the key properties of variance for independent random variables.
|
500 |
+
</details>
|
501 |
+
|
502 |
+
<details>
|
503 |
+
<summary>Multiplying a random variable by 2 multiplies its variance by 2.</summary>
|
504 |
+
❌ False! Multiplying a random variable by a constant a multiplies its variance by a². So multiplying by 2 multiplies variance by 4.
|
505 |
+
</details>
|
506 |
+
|
507 |
+
<details>
|
508 |
+
<summary>Standard deviation is always equal to the square root of variance.</summary>
|
509 |
+
✅ True! By definition, standard deviation σ = √Var(X).
|
510 |
+
</details>
|
511 |
+
|
512 |
+
<details>
|
513 |
+
<summary>If Var(X) = 0, then X must be a constant.</summary>
|
514 |
+
✅ True! Zero variance means there is no spread around the mean, so X can only take one value.
|
515 |
+
</details>
|
516 |
+
"""
|
517 |
+
)
|
518 |
+
return
|
519 |
+
|
520 |
+
|
521 |
+
@app.cell(hide_code=True)
|
522 |
+
def _(mo):
|
523 |
+
mo.md(
|
524 |
+
r"""
|
525 |
+
## Key Takeaways
|
526 |
+
|
527 |
+
Variance gives us a way to measure how spread out a random variable is around its mean. It's like the "uncertainty" in our expectation - a high variance means individual outcomes can differ widely from what we expect on average.
|
528 |
+
|
529 |
+
Standard deviation brings this measure back to the original units, making it easier to interpret. For grades, a standard deviation of 10 points means typical grades fall within about 10 points of the average.
|
530 |
+
|
531 |
+
Variance pops up everywhere - from weather forecasts (how reliable is the predicted temperature?) to financial investments (how risky is this stock?) to quality control (how consistent is our manufacturing process?).
|
532 |
+
|
533 |
+
In our next notebook, we'll explore more properties of random variables and see how they combine to form more complex distributions.
|
534 |
+
"""
|
535 |
+
)
|
536 |
+
return
|
537 |
+
|
538 |
+
|
539 |
+
@app.cell(hide_code=True)
|
540 |
+
def _(mo):
|
541 |
+
mo.md(r"""Appendix (containing helper code):""")
|
542 |
+
return
|
543 |
+
|
544 |
+
|
545 |
+
@app.cell(hide_code=True)
|
546 |
+
def _():
|
547 |
+
import marimo as mo
|
548 |
+
return (mo,)
|
549 |
+
|
550 |
+
|
551 |
+
@app.cell(hide_code=True)
|
552 |
+
def _():
|
553 |
+
import numpy as np
|
554 |
+
import scipy.stats as stats
|
555 |
+
import matplotlib.pyplot as plt
|
556 |
+
from wigglystuff import TangleSlider
|
557 |
+
return TangleSlider, np, plt, stats
|
558 |
+
|
559 |
+
|
560 |
+
@app.cell(hide_code=True)
|
561 |
+
def _(TangleSlider, mo):
|
562 |
+
# Create interactive elements using TangleSlider for a more inline experience
|
563 |
+
true_grade = mo.ui.anywidget(TangleSlider(
|
564 |
+
amount=70,
|
565 |
+
min_value=50,
|
566 |
+
max_value=90,
|
567 |
+
step=5,
|
568 |
+
digits=0,
|
569 |
+
suffix=" points"
|
570 |
+
))
|
571 |
+
|
572 |
+
grader_a_spread = mo.ui.anywidget(TangleSlider(
|
573 |
+
amount=10,
|
574 |
+
min_value=5,
|
575 |
+
max_value=20,
|
576 |
+
step=1,
|
577 |
+
digits=0,
|
578 |
+
suffix=" points"
|
579 |
+
))
|
580 |
+
|
581 |
+
grader_b_spread = mo.ui.anywidget(TangleSlider(
|
582 |
+
amount=2,
|
583 |
+
min_value=1,
|
584 |
+
max_value=5,
|
585 |
+
step=0.5,
|
586 |
+
digits=1,
|
587 |
+
suffix=" points"
|
588 |
+
))
|
589 |
+
|
590 |
+
grader_c_mix = mo.ui.anywidget(TangleSlider(
|
591 |
+
amount=0.2,
|
592 |
+
min_value=0,
|
593 |
+
max_value=1,
|
594 |
+
step=0.05,
|
595 |
+
digits=2,
|
596 |
+
suffix=" proportion"
|
597 |
+
))
|
598 |
+
return grader_a_spread, grader_b_spread, grader_c_mix, true_grade
|
599 |
+
|
600 |
+
|
601 |
+
@app.cell(hide_code=True)
|
602 |
+
def _(grader_a_spread, grader_b_spread, grader_c_mix, mo, true_grade):
|
603 |
+
# Create a grid layout for the interactive controls
|
604 |
+
controls = mo.vstack([
|
605 |
+
mo.md("### Adjust Parameters to See How Variance Changes"),
|
606 |
+
mo.hstack([
|
607 |
+
mo.md("**True grade:** The correct score that should be assigned is "),
|
608 |
+
true_grade,
|
609 |
+
mo.md(" out of 100.")
|
610 |
+
], justify="start"),
|
611 |
+
mo.hstack([
|
612 |
+
mo.md("**Grader A:** Has a wide spread with standard deviation of "),
|
613 |
+
grader_a_spread,
|
614 |
+
mo.md(" points.")
|
615 |
+
], justify="start"),
|
616 |
+
mo.hstack([
|
617 |
+
mo.md("**Grader B:** Has a narrow spread with standard deviation of "),
|
618 |
+
grader_b_spread,
|
619 |
+
mo.md(" points.")
|
620 |
+
], justify="start"),
|
621 |
+
mo.hstack([
|
622 |
+
mo.md("**Grader C:** Has a mixture distribution with "),
|
623 |
+
grader_c_mix,
|
624 |
+
mo.md(" proportion of outliers.")
|
625 |
+
], justify="start"),
|
626 |
+
])
|
627 |
+
return (controls,)
|
628 |
+
|
629 |
+
|
630 |
+
if __name__ == "__main__":
|
631 |
+
app.run()
|
probability/13_bernoulli_distribution.py
ADDED
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.0",
|
6 |
+
# "numpy==2.2.3",
|
7 |
+
# "scipy==1.15.2",
|
8 |
+
# ]
|
9 |
+
# ///
|
10 |
+
|
11 |
+
import marimo
|
12 |
+
|
13 |
+
__generated_with = "0.11.22"
|
14 |
+
app = marimo.App(width="medium", app_title="Bernoulli Distribution")
|
15 |
+
|
16 |
+
|
17 |
+
@app.cell(hide_code=True)
|
18 |
+
def _(mo):
|
19 |
+
mo.md(
|
20 |
+
r"""
|
21 |
+
# Bernoulli Distribution
|
22 |
+
|
23 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/bernoulli/), by Stanford professor Chris Piech._
|
24 |
+
|
25 |
+
## Parametric Random Variables
|
26 |
+
|
27 |
+
There are many classic and commonly-seen random variable abstractions that show up in the world of probability. At this point, we'll learn about several of the most significant parametric discrete distributions.
|
28 |
+
|
29 |
+
When solving problems, if you can recognize that a random variable fits one of these formats, then you can use its pre-derived Probability Mass Function (PMF), expectation, variance, and other properties. Random variables of this sort are called **parametric random variables**. If you can argue that a random variable falls under one of the studied parametric types, you simply need to provide parameters.
|
30 |
+
|
31 |
+
> A good analogy is a `class` in programming. Creating a parametric random variable is very similar to calling a constructor with input parameters.
|
32 |
+
"""
|
33 |
+
)
|
34 |
+
return
|
35 |
+
|
36 |
+
|
37 |
+
@app.cell(hide_code=True)
|
38 |
+
def _(mo):
|
39 |
+
mo.md(
|
40 |
+
r"""
|
41 |
+
## Bernoulli Random Variables
|
42 |
+
|
43 |
+
A **Bernoulli random variable** (also called a boolean or indicator random variable) is the simplest kind of parametric random variable. It can take on two values: 1 and 0.
|
44 |
+
|
45 |
+
It takes on a 1 if an experiment with probability $p$ resulted in success and a 0 otherwise.
|
46 |
+
|
47 |
+
Some example uses include:
|
48 |
+
|
49 |
+
- A coin flip (heads = 1, tails = 0)
|
50 |
+
- A random binary digit
|
51 |
+
- Whether a disk drive crashed
|
52 |
+
- Whether someone likes a Netflix movie
|
53 |
+
|
54 |
+
Here $p$ is the parameter, but different instances of Bernoulli random variables might have different values of $p$.
|
55 |
+
"""
|
56 |
+
)
|
57 |
+
return
|
58 |
+
|
59 |
+
|
60 |
+
@app.cell(hide_code=True)
|
61 |
+
def _(mo):
|
62 |
+
mo.md(
|
63 |
+
r"""
|
64 |
+
## Key Properties of a Bernoulli Random Variable
|
65 |
+
|
66 |
+
If $X$ is declared to be a Bernoulli random variable with parameter $p$, denoted $X \sim \text{Bern}(p)$, it has the following properties:
|
67 |
+
"""
|
68 |
+
)
|
69 |
+
return
|
70 |
+
|
71 |
+
|
72 |
+
@app.cell
|
73 |
+
def _(stats):
|
74 |
+
# Define the Bernoulli distribution function
|
75 |
+
def Bern(p):
|
76 |
+
return stats.bernoulli(p)
|
77 |
+
return (Bern,)
|
78 |
+
|
79 |
+
|
80 |
+
@app.cell(hide_code=True)
|
81 |
+
def _(mo):
|
82 |
+
mo.md(
|
83 |
+
r"""
|
84 |
+
## Bernoulli Distribution Properties
|
85 |
+
|
86 |
+
$\begin{array}{lll}
|
87 |
+
\text{Notation:} & X \sim \text{Bern}(p) \\
|
88 |
+
\text{Description:} & \text{A boolean variable that is 1 with probability } p \\
|
89 |
+
\text{Parameters:} & p, \text{ the probability that } X = 1 \\
|
90 |
+
\text{Support:} & x \text{ is either 0 or 1} \\
|
91 |
+
\text{PMF equation:} & P(X = x) =
|
92 |
+
\begin{cases}
|
93 |
+
p & \text{if }x = 1\\
|
94 |
+
1-p & \text{if }x = 0
|
95 |
+
\end{cases} \\
|
96 |
+
\text{PMF (smooth):} & P(X = x) = p^x(1-p)^{1-x} \\
|
97 |
+
\text{Expectation:} & E[X] = p \\
|
98 |
+
\text{Variance:} & \text{Var}(X) = p(1-p) \\
|
99 |
+
\end{array}$
|
100 |
+
"""
|
101 |
+
)
|
102 |
+
return
|
103 |
+
|
104 |
+
|
105 |
+
@app.cell(hide_code=True)
|
106 |
+
def _(mo, p_slider):
|
107 |
+
# Visualization of the Bernoulli PMF
|
108 |
+
_p = p_slider.value
|
109 |
+
|
110 |
+
# Values for PMF
|
111 |
+
values = [0, 1]
|
112 |
+
probabilities = [1 - _p, _p]
|
113 |
+
|
114 |
+
# Relevant statistics
|
115 |
+
expected_value = _p
|
116 |
+
variance = _p * (1 - _p)
|
117 |
+
|
118 |
+
mo.md(f"""
|
119 |
+
## PMF Graph for Bernoulli($p={_p:.2f}$)
|
120 |
+
|
121 |
+
Parameter $p$: {p_slider}
|
122 |
+
|
123 |
+
Expected value: $E[X] = {expected_value:.2f}$
|
124 |
+
|
125 |
+
Variance: $\\text{{Var}}(X) = {variance:.2f}$
|
126 |
+
""")
|
127 |
+
return expected_value, probabilities, values, variance
|
128 |
+
|
129 |
+
|
130 |
+
@app.cell(hide_code=True)
|
131 |
+
def _(expected_value, p_slider, plt, probabilities, values, variance):
|
132 |
+
# PMF
|
133 |
+
_p = p_slider.value
|
134 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
135 |
+
|
136 |
+
# Bar plot for PMF
|
137 |
+
ax.bar(values, probabilities, width=0.4, color='blue', alpha=0.7)
|
138 |
+
|
139 |
+
ax.set_xlabel('Values that X can take on')
|
140 |
+
ax.set_ylabel('Probability')
|
141 |
+
ax.set_title(f'PMF of Bernoulli Distribution with p = {_p:.2f}')
|
142 |
+
|
143 |
+
# x-axis limit
|
144 |
+
ax.set_xticks([0, 1])
|
145 |
+
ax.set_xlim(-0.5, 1.5)
|
146 |
+
|
147 |
+
# y-axis w/ some padding
|
148 |
+
ax.set_ylim(0, max(probabilities) * 1.1)
|
149 |
+
|
150 |
+
# Add expectation as vertical line
|
151 |
+
ax.axvline(x=expected_value, color='red', linestyle='--',
|
152 |
+
label=f'E[X] = {expected_value:.2f}')
|
153 |
+
|
154 |
+
# Add variance annotation
|
155 |
+
ax.text(0.5, max(probabilities) * 0.8,
|
156 |
+
f'Var(X) = {variance:.3f}',
|
157 |
+
horizontalalignment='center',
|
158 |
+
bbox=dict(facecolor='white', alpha=0.7))
|
159 |
+
|
160 |
+
ax.legend()
|
161 |
+
plt.tight_layout()
|
162 |
+
plt.gca()
|
163 |
+
return ax, fig
|
164 |
+
|
165 |
+
|
166 |
+
@app.cell(hide_code=True)
|
167 |
+
def _(mo):
|
168 |
+
mo.md(
|
169 |
+
r"""
|
170 |
+
## Proof: Expectation of a Bernoulli
|
171 |
+
|
172 |
+
If $X$ is a Bernoulli with parameter $p$, $X \sim \text{Bern}(p)$:
|
173 |
+
|
174 |
+
\begin{align}
|
175 |
+
E[X] &= \sum_x x \cdot (X=x) && \text{Definition of expectation} \\
|
176 |
+
&= 1 \cdot p + 0 \cdot (1-p) &&
|
177 |
+
X \text{ can take on values 0 and 1} \\
|
178 |
+
&= p && \text{Remove the 0 term}
|
179 |
+
\end{align}
|
180 |
+
|
181 |
+
## Proof: Variance of a Bernoulli
|
182 |
+
|
183 |
+
If $X$ is a Bernoulli with parameter $p$, $X \sim \text{Bern}(p)$:
|
184 |
+
|
185 |
+
To compute variance, first compute $E[X^2]$:
|
186 |
+
|
187 |
+
\begin{align}
|
188 |
+
E[X^2]
|
189 |
+
&= \sum_x x^2 \cdot (X=x) &&\text{LOTUS}\\
|
190 |
+
&= 0^2 \cdot (1-p) + 1^2 \cdot p\\
|
191 |
+
&= p
|
192 |
+
\end{align}
|
193 |
+
|
194 |
+
\begin{align}
|
195 |
+
(X)
|
196 |
+
&= E[X^2] - E[X]^2&& \text{Def of variance} \\
|
197 |
+
&= p - p^2 && \text{Substitute }E[X^2]=p, E[X] = p \\
|
198 |
+
&= p (1-p) && \text{Factor out }p
|
199 |
+
\end{align}
|
200 |
+
"""
|
201 |
+
)
|
202 |
+
return
|
203 |
+
|
204 |
+
|
205 |
+
@app.cell(hide_code=True)
|
206 |
+
def _(mo):
|
207 |
+
mo.md(
|
208 |
+
r"""
|
209 |
+
## Indicator Random Variable
|
210 |
+
|
211 |
+
> **Definition**: An indicator variable is a Bernoulli random variable which takes on the value 1 if an **underlying event occurs**, and 0 _otherwise_.
|
212 |
+
|
213 |
+
Indicator random variables are a convenient way to convert the "true/false" outcome of an event into a number. That number may be easier to incorporate into an equation.
|
214 |
+
|
215 |
+
A random variable $I$ is an indicator variable for an event $A$ if $I = 1$ when $A$ occurs and $I = 0$ if $A$ does not occur. Indicator random variables are Bernoulli random variables, with $p = P(A)$. $I_A$ is a common choice of name for an indicator random variable.
|
216 |
+
|
217 |
+
Here are some properties of indicator random variables:
|
218 |
+
|
219 |
+
- $P(I=1)=P(A)$
|
220 |
+
- $E[I]=P(A)$
|
221 |
+
"""
|
222 |
+
)
|
223 |
+
return
|
224 |
+
|
225 |
+
|
226 |
+
@app.cell(hide_code=True)
|
227 |
+
def _(mo):
|
228 |
+
# Simulation of Bernoulli trials
|
229 |
+
mo.md(r"""
|
230 |
+
## Simulation of Bernoulli Trials
|
231 |
+
|
232 |
+
Let's simulate Bernoulli trials to see the law of large numbers in action. We'll flip a biased coin repeatedly and observe how the proportion of successes approaches the true probability $p$.
|
233 |
+
""")
|
234 |
+
|
235 |
+
# UI element for simulation parameters
|
236 |
+
num_trials_slider = mo.ui.slider(10, 10000, value=1000, step=10, label="Number of trials")
|
237 |
+
p_sim_slider = mo.ui.slider(0.01, 0.99, value=0.65, step=0.01, label="Success probability (p)")
|
238 |
+
return num_trials_slider, p_sim_slider
|
239 |
+
|
240 |
+
|
241 |
+
@app.cell(hide_code=True)
|
242 |
+
def _(mo):
|
243 |
+
mo.md(r"""## Simulation""")
|
244 |
+
return
|
245 |
+
|
246 |
+
|
247 |
+
@app.cell(hide_code=True)
|
248 |
+
def _(mo, num_trials_slider, p_sim_slider):
|
249 |
+
mo.hstack([num_trials_slider, p_sim_slider], justify='space-around')
|
250 |
+
return
|
251 |
+
|
252 |
+
|
253 |
+
@app.cell(hide_code=True)
|
254 |
+
def _(np, num_trials_slider, p_sim_slider, plt):
|
255 |
+
# Bernoulli trials
|
256 |
+
_num_trials = num_trials_slider.value
|
257 |
+
p = p_sim_slider.value
|
258 |
+
|
259 |
+
# Random Bernoulli trials
|
260 |
+
trials = np.random.binomial(1, p, size=_num_trials)
|
261 |
+
|
262 |
+
# Cumulative proportion of successes
|
263 |
+
cumulative_mean = np.cumsum(trials) / np.arange(1, _num_trials + 1)
|
264 |
+
|
265 |
+
# Results
|
266 |
+
plt.figure(figsize=(10, 6))
|
267 |
+
plt.plot(range(1, _num_trials + 1), cumulative_mean, label='Proportion of successes')
|
268 |
+
plt.axhline(y=p, color='r', linestyle='--', label=f'True probability (p={p})')
|
269 |
+
|
270 |
+
plt.xscale('log') # Use log scale for better visualization
|
271 |
+
plt.xlabel('Number of trials')
|
272 |
+
plt.ylabel('Proportion of successes')
|
273 |
+
plt.title('Convergence of Sample Proportion to True Probability')
|
274 |
+
plt.legend()
|
275 |
+
plt.grid(True, alpha=0.3)
|
276 |
+
|
277 |
+
# Add annotation
|
278 |
+
plt.annotate('As the number of trials increases,\nthe proportion approaches p',
|
279 |
+
xy=(_num_trials, cumulative_mean[-1]),
|
280 |
+
xytext=(_num_trials/5, p + 0.1),
|
281 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
282 |
+
|
283 |
+
plt.tight_layout()
|
284 |
+
plt.gca()
|
285 |
+
return cumulative_mean, p, trials
|
286 |
+
|
287 |
+
|
288 |
+
@app.cell(hide_code=True)
|
289 |
+
def _(mo, np, trials):
|
290 |
+
# Calculate statistics from the simulation
|
291 |
+
num_successes = np.sum(trials)
|
292 |
+
num_trials = len(trials)
|
293 |
+
proportion = num_successes / num_trials
|
294 |
+
|
295 |
+
# Display the results
|
296 |
+
mo.md(f"""
|
297 |
+
### Simulation Results
|
298 |
+
|
299 |
+
- Number of trials: {num_trials}
|
300 |
+
- Number of successes: {num_successes}
|
301 |
+
- Proportion of successes: {proportion:.4f}
|
302 |
+
|
303 |
+
This demonstrates how the sample proportion approaches the true probability $p$ as the number of trials increases.
|
304 |
+
""")
|
305 |
+
return num_successes, num_trials, proportion
|
306 |
+
|
307 |
+
|
308 |
+
@app.cell(hide_code=True)
|
309 |
+
def _(mo):
|
310 |
+
mo.md(
|
311 |
+
r"""
|
312 |
+
## 🤔 Test Your Understanding
|
313 |
+
|
314 |
+
Pick which of these statements about Bernoulli random variables you think are correct:
|
315 |
+
|
316 |
+
/// details | The variance of a Bernoulli random variable is always less than or equal to 0.25
|
317 |
+
✅ Correct! The variance $p(1-p)$ reaches its maximum value of 0.25 when $p = 0.5$.
|
318 |
+
///
|
319 |
+
|
320 |
+
/// details | The expected value of a Bernoulli random variable must be either 0 or 1
|
321 |
+
❌ Incorrect! The expected value is $p$, which can be any value between 0 and 1.
|
322 |
+
///
|
323 |
+
|
324 |
+
/// details | If $X \sim \text{Bern}(0.3)$ and $Y \sim \text{Bern}(0.7)$, then $X$ and $Y$ have the same variance
|
325 |
+
✅ Correct! $\text{Var}(X) = 0.3 \times 0.7 = 0.21$ and $\text{Var}(Y) = 0.7 \times 0.3 = 0.21$.
|
326 |
+
///
|
327 |
+
|
328 |
+
/// details | Two independent coin flips can be modeled as the sum of two Bernoulli random variables
|
329 |
+
✅ Correct! The sum would follow a Binomial distribution with $n=2$.
|
330 |
+
///
|
331 |
+
"""
|
332 |
+
)
|
333 |
+
return
|
334 |
+
|
335 |
+
|
336 |
+
@app.cell(hide_code=True)
|
337 |
+
def _(mo):
|
338 |
+
mo.md(
|
339 |
+
r"""
|
340 |
+
## Applications of Bernoulli Random Variables
|
341 |
+
|
342 |
+
Bernoulli random variables are used in many real-world scenarios:
|
343 |
+
|
344 |
+
1. **Quality Control**: Testing if a manufactured item is defective (1) or not (0)
|
345 |
+
|
346 |
+
2. **A/B Testing**: Determining if a user clicks (1) or doesn't click (0) on a website button
|
347 |
+
|
348 |
+
3. **Medical Testing**: Checking if a patient tests positive (1) or negative (0) for a disease
|
349 |
+
|
350 |
+
4. **Election Modeling**: Modeling if a particular voter votes for candidate A (1) or not (0)
|
351 |
+
|
352 |
+
5. **Financial Markets**: Modeling if a stock price goes up (1) or down (0) in a simplified model
|
353 |
+
|
354 |
+
Because Bernoulli random variables are parametric, as soon as you declare a random variable to be of type Bernoulli, you automatically know all of its pre-derived properties!
|
355 |
+
"""
|
356 |
+
)
|
357 |
+
return
|
358 |
+
|
359 |
+
|
360 |
+
@app.cell(hide_code=True)
|
361 |
+
def _(mo):
|
362 |
+
mo.md(
|
363 |
+
r"""
|
364 |
+
## Summary
|
365 |
+
|
366 |
+
And that's a wrap on Bernoulli distributions! We've learnt the simplest of all probability distributions — the one that only has two possible outcomes. Flip a coin, check if an email is spam, see if your blind date shows up — these are all Bernoulli trials with success probability $p$.
|
367 |
+
|
368 |
+
The beauty of Bernoulli is in its simplicity: just set $p$ (the probability of success) and you're good to go! The PMF gives us $P(X=1) = p$ and $P(X=0) = 1-p$, while expectation is simply $p$ and variance is $p(1-p)$. Oh, and when you're tracking whether specific events happen or not? That's an indicator random variable — just another Bernoulli in disguise!
|
369 |
+
|
370 |
+
Two key things to remember:
|
371 |
+
|
372 |
+
/// note
|
373 |
+
💡 **Maximum Variance**: A Bernoulli's variance $p(1-p)$ reaches its maximum at $p=0.5$, making a fair coin the most "unpredictable" Bernoulli random variable.
|
374 |
+
|
375 |
+
💡 **Instant Properties**: When you identify a random variable as Bernoulli, you instantly know all its properties—expectation, variance, PMF—without additional calculations.
|
376 |
+
///
|
377 |
+
|
378 |
+
Next up: Binomial distribution—where we'll see what happens when we let Bernoulli trials have a party and add themselves together!
|
379 |
+
"""
|
380 |
+
)
|
381 |
+
return
|
382 |
+
|
383 |
+
|
384 |
+
@app.cell(hide_code=True)
|
385 |
+
def _(mo):
|
386 |
+
mo.md(r"""#### Appendix (containing helper code for the notebook)""")
|
387 |
+
return
|
388 |
+
|
389 |
+
|
390 |
+
@app.cell
|
391 |
+
def _():
|
392 |
+
import marimo as mo
|
393 |
+
return (mo,)
|
394 |
+
|
395 |
+
|
396 |
+
@app.cell(hide_code=True)
|
397 |
+
def _():
|
398 |
+
from marimo import Html
|
399 |
+
return (Html,)
|
400 |
+
|
401 |
+
|
402 |
+
@app.cell(hide_code=True)
|
403 |
+
def _():
|
404 |
+
import numpy as np
|
405 |
+
import matplotlib.pyplot as plt
|
406 |
+
from scipy import stats
|
407 |
+
import math
|
408 |
+
|
409 |
+
# Set style for consistent visualizations
|
410 |
+
plt.style.use('seaborn-v0_8-whitegrid')
|
411 |
+
plt.rcParams['figure.figsize'] = [10, 6]
|
412 |
+
plt.rcParams['font.size'] = 12
|
413 |
+
|
414 |
+
# Set random seed for reproducibility
|
415 |
+
np.random.seed(42)
|
416 |
+
return math, np, plt, stats
|
417 |
+
|
418 |
+
|
419 |
+
@app.cell(hide_code=True)
|
420 |
+
def _(mo):
|
421 |
+
# Create a UI element for the parameter p
|
422 |
+
p_slider = mo.ui.slider(0.01, 0.99, value=0.65, step=0.01, label="Parameter p")
|
423 |
+
return (p_slider,)
|
424 |
+
|
425 |
+
|
426 |
+
if __name__ == "__main__":
|
427 |
+
app.run()
|
probability/14_binomial_distribution.py
ADDED
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.0",
|
6 |
+
# "numpy==2.2.4",
|
7 |
+
# "scipy==1.15.2",
|
8 |
+
# "altair==5.2.0",
|
9 |
+
# "wigglystuff==0.1.10",
|
10 |
+
# "pandas==2.2.3",
|
11 |
+
# ]
|
12 |
+
# ///
|
13 |
+
|
14 |
+
import marimo
|
15 |
+
|
16 |
+
__generated_with = "0.11.24"
|
17 |
+
app = marimo.App(width="medium", app_title="Binomial Distribution")
|
18 |
+
|
19 |
+
|
20 |
+
@app.cell(hide_code=True)
|
21 |
+
def _(mo):
|
22 |
+
mo.md(
|
23 |
+
r"""
|
24 |
+
# Binomial Distribution
|
25 |
+
|
26 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/binomial/), by Stanford professor Chris Piech._
|
27 |
+
|
28 |
+
In this section, we will discuss the binomial distribution. To start, imagine the following example:
|
29 |
+
|
30 |
+
Consider $n$ independent trials of an experiment where each trial is a "success" with probability $p$. Let $X$ be the number of successes in $n$ trials.
|
31 |
+
|
32 |
+
This situation is truly common in the natural world, and as such, there has been a lot of research into such phenomena. Random variables like $X$ are called **binomial random variables**. If you can identify that a process fits this description, you can inherit many already proved properties such as the PMF formula, expectation, and variance!
|
33 |
+
"""
|
34 |
+
)
|
35 |
+
return
|
36 |
+
|
37 |
+
|
38 |
+
@app.cell(hide_code=True)
|
39 |
+
def _(mo):
|
40 |
+
mo.md(
|
41 |
+
r"""
|
42 |
+
## Binomial Random Variable Definition
|
43 |
+
|
44 |
+
$X \sim \text{Bin}(n, p)$ represents a binomial random variable where:
|
45 |
+
|
46 |
+
- $X$ is our random variable (number of successes)
|
47 |
+
- $\text{Bin}$ indicates it follows a binomial distribution
|
48 |
+
- $n$ is the number of trials
|
49 |
+
- $p$ is the probability of success in each trial
|
50 |
+
|
51 |
+
```
|
52 |
+
X ~ Bin(n, p)
|
53 |
+
↑ ↑ ↑
|
54 |
+
| | +-- Probability of
|
55 |
+
| | success on each
|
56 |
+
| | trial
|
57 |
+
| +-- Number of trials
|
58 |
+
|
|
59 |
+
Our random variable
|
60 |
+
is distributed
|
61 |
+
as a Binomial
|
62 |
+
```
|
63 |
+
|
64 |
+
Here are a few examples of binomial random variables:
|
65 |
+
|
66 |
+
- Number of heads in $n$ coin flips
|
67 |
+
- Number of 1's in randomly generated length $n$ bit string
|
68 |
+
- Number of disk drives crashed in 1000 computer cluster, assuming disks crash independently
|
69 |
+
"""
|
70 |
+
)
|
71 |
+
return
|
72 |
+
|
73 |
+
|
74 |
+
@app.cell(hide_code=True)
|
75 |
+
def _(mo):
|
76 |
+
mo.md(
|
77 |
+
r"""
|
78 |
+
## Properties of Binomial Distribution
|
79 |
+
|
80 |
+
| Property | Formula |
|
81 |
+
|----------|---------|
|
82 |
+
| Notation | $X \sim \text{Bin}(n, p)$ |
|
83 |
+
| Description | Number of "successes" in $n$ identical, independent experiments each with probability of success $p$ |
|
84 |
+
| Parameters | $n \in \{0, 1, \dots\}$, the number of experiments<br>$p \in [0, 1]$, the probability that a single experiment gives a "success" |
|
85 |
+
| Support | $x \in \{0, 1, \dots, n\}$ |
|
86 |
+
| PMF equation | $P(X=x) = {n \choose x}p^x(1-p)^{n-x}$ |
|
87 |
+
| Expectation | $E[X] = n \cdot p$ |
|
88 |
+
| Variance | $\text{Var}(X) = n \cdot p \cdot (1-p)$ |
|
89 |
+
|
90 |
+
Let's explore how the binomial distribution changes with different parameters.
|
91 |
+
"""
|
92 |
+
)
|
93 |
+
return
|
94 |
+
|
95 |
+
|
96 |
+
@app.cell(hide_code=True)
|
97 |
+
def _(TangleSlider, mo):
|
98 |
+
# Interactive elements using TangleSlider
|
99 |
+
n_slider = mo.ui.anywidget(TangleSlider(
|
100 |
+
amount=10,
|
101 |
+
min_value=1,
|
102 |
+
max_value=30,
|
103 |
+
step=1,
|
104 |
+
digits=0,
|
105 |
+
suffix=" trials"
|
106 |
+
))
|
107 |
+
|
108 |
+
p_slider = mo.ui.anywidget(TangleSlider(
|
109 |
+
amount=0.5,
|
110 |
+
min_value=0.01,
|
111 |
+
max_value=0.99,
|
112 |
+
step=0.01,
|
113 |
+
digits=2,
|
114 |
+
suffix=" probability"
|
115 |
+
))
|
116 |
+
|
117 |
+
# Grid layout for the interactive controls
|
118 |
+
controls = mo.vstack([
|
119 |
+
mo.md("### Adjust Parameters to See How Binomial Distribution Changes"),
|
120 |
+
mo.hstack([
|
121 |
+
mo.md("**Number of trials (n):** "),
|
122 |
+
n_slider
|
123 |
+
], justify="start"),
|
124 |
+
mo.hstack([
|
125 |
+
mo.md("**Probability of success (p):** "),
|
126 |
+
p_slider
|
127 |
+
], justify="start"),
|
128 |
+
])
|
129 |
+
return controls, n_slider, p_slider
|
130 |
+
|
131 |
+
|
132 |
+
@app.cell(hide_code=True)
|
133 |
+
def _(controls):
|
134 |
+
controls
|
135 |
+
return
|
136 |
+
|
137 |
+
|
138 |
+
@app.cell(hide_code=True)
|
139 |
+
def _(n_slider, np, p_slider, plt, stats):
|
140 |
+
# Parameters from sliders
|
141 |
+
_n = int(n_slider.amount)
|
142 |
+
_p = p_slider.amount
|
143 |
+
|
144 |
+
# Calculate PMF
|
145 |
+
_x = np.arange(0, _n + 1)
|
146 |
+
_pmf = stats.binom.pmf(_x, _n, _p)
|
147 |
+
|
148 |
+
# Relevant stats
|
149 |
+
_mean = _n * _p
|
150 |
+
_variance = _n * _p * (1 - _p)
|
151 |
+
_std_dev = np.sqrt(_variance)
|
152 |
+
|
153 |
+
_fig, _ax = plt.subplots(figsize=(10, 6))
|
154 |
+
|
155 |
+
# Plot PMF as bars
|
156 |
+
_ax.bar(_x, _pmf, color='royalblue', alpha=0.7, label=f'PMF: P(X=k)')
|
157 |
+
|
158 |
+
# Add a line
|
159 |
+
_ax.plot(_x, _pmf, 'ro-', alpha=0.6, label='PMF line')
|
160 |
+
|
161 |
+
# Add vertical lines
|
162 |
+
_ax.axvline(x=_mean, color='green', linestyle='--', linewidth=2,
|
163 |
+
label=f'Mean: {_mean:.2f}')
|
164 |
+
|
165 |
+
# Shade the stdev region
|
166 |
+
_ax.axvspan(_mean - _std_dev, _mean + _std_dev, alpha=0.2, color='green',
|
167 |
+
label=f'±1 Std Dev: {_std_dev:.2f}')
|
168 |
+
|
169 |
+
# Add labels and title
|
170 |
+
_ax.set_xlabel('Number of Successes (k)')
|
171 |
+
_ax.set_ylabel('Probability: P(X=k)')
|
172 |
+
_ax.set_title(f'Binomial Distribution with n={_n}, p={_p:.2f}')
|
173 |
+
|
174 |
+
# Annotations
|
175 |
+
_ax.annotate(f'E[X] = {_mean:.2f}',
|
176 |
+
xy=(_mean, stats.binom.pmf(int(_mean), _n, _p)),
|
177 |
+
xytext=(_mean + 1, max(_pmf) * 0.8),
|
178 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
179 |
+
|
180 |
+
_ax.annotate(f'Var(X) = {_variance:.2f}',
|
181 |
+
xy=(_mean, stats.binom.pmf(int(_mean), _n, _p) / 2),
|
182 |
+
xytext=(_mean + 1, max(_pmf) * 0.6),
|
183 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
184 |
+
|
185 |
+
# Grid and legend
|
186 |
+
_ax.grid(alpha=0.3)
|
187 |
+
_ax.legend()
|
188 |
+
|
189 |
+
plt.tight_layout()
|
190 |
+
plt.gca()
|
191 |
+
return
|
192 |
+
|
193 |
+
|
194 |
+
@app.cell(hide_code=True)
|
195 |
+
def _(mo):
|
196 |
+
mo.md(
|
197 |
+
r"""
|
198 |
+
## Relationship to Bernoulli Random Variables
|
199 |
+
|
200 |
+
One way to think of the binomial is as the sum of $n$ Bernoulli variables. Say that $Y_i$ is an indicator Bernoulli random variable which is 1 if experiment $i$ is a success. Then if $X$ is the total number of successes in $n$ experiments, $X \sim \text{Bin}(n, p)$:
|
201 |
+
|
202 |
+
$$X = \sum_{i=1}^n Y_i$$
|
203 |
+
|
204 |
+
Recall that the outcome of $Y_i$ will be 1 or 0, so one way to think of $X$ is as the sum of those 1s and 0s.
|
205 |
+
"""
|
206 |
+
)
|
207 |
+
return
|
208 |
+
|
209 |
+
|
210 |
+
@app.cell(hide_code=True)
|
211 |
+
def _(mo):
|
212 |
+
mo.md(
|
213 |
+
r"""
|
214 |
+
## Binomial Probability Mass Function (PMF)
|
215 |
+
|
216 |
+
The most important property to know about a binomial is its [Probability Mass Function](https://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/10_probability_mass_function.py):
|
217 |
+
|
218 |
+
$$P(X=k) = {n \choose k}p^k(1-p)^{n-k}$$
|
219 |
+
|
220 |
+
```
|
221 |
+
P(X = k) = (n) p^k(1-p)^(n-k)
|
222 |
+
↑ (k)
|
223 |
+
| ↑
|
224 |
+
| +-- Binomial coefficient:
|
225 |
+
| number of ways to choose
|
226 |
+
| k successes from n trials
|
227 |
+
|
|
228 |
+
Probability that our
|
229 |
+
variable takes on the
|
230 |
+
value k
|
231 |
+
```
|
232 |
+
|
233 |
+
Recall, we derived this formula in Part 1. There is a complete example on the probability of $k$ heads in $n$ coin flips, where each flip is heads with probability $p$.
|
234 |
+
|
235 |
+
To briefly review, if you think of each experiment as being distinct, then there are ${n \choose k}$ ways of permuting $k$ successes from $n$ experiments. For any of the mutually exclusive permutations, the probability of that permutation is $p^k \cdot (1-p)^{n-k}$.
|
236 |
+
|
237 |
+
The name binomial comes from the term ${n \choose k}$ which is formally called the binomial coefficient.
|
238 |
+
"""
|
239 |
+
)
|
240 |
+
return
|
241 |
+
|
242 |
+
|
243 |
+
@app.cell(hide_code=True)
|
244 |
+
def _(mo):
|
245 |
+
mo.md(
|
246 |
+
r"""
|
247 |
+
## Expectation of Binomial
|
248 |
+
|
249 |
+
There is an easy way to calculate the expectation of a binomial and a hard way. The easy way is to leverage the fact that a binomial is the sum of Bernoulli indicator random variables $X = \sum_{i=1}^{n} Y_i$ where $Y_i$ is an indicator of whether the $i$-th experiment was a success: $Y_i \sim \text{Bernoulli}(p)$.
|
250 |
+
|
251 |
+
Since the [expectation of the sum](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/11_expectation.py) of random variables is the sum of expectations, we can add the expectation, $E[Y_i] = p$, of each of the Bernoulli's:
|
252 |
+
|
253 |
+
\begin{align}
|
254 |
+
E[X] &= E\Big[\sum_{i=1}^{n} Y_i\Big] && \text{Since }X = \sum_{i=1}^{n} Y_i \\
|
255 |
+
&= \sum_{i=1}^{n}E[ Y_i] && \text{Expectation of sum} \\
|
256 |
+
&= \sum_{i=1}^{n}p && \text{Expectation of Bernoulli} \\
|
257 |
+
&= n \cdot p && \text{Sum $n$ times}
|
258 |
+
\end{align}
|
259 |
+
|
260 |
+
The hard way is to use the definition of expectation:
|
261 |
+
|
262 |
+
\begin{align}
|
263 |
+
E[X] &= \sum_{i=0}^n i \cdot P(X = i) && \text{Def of expectation} \\
|
264 |
+
&= \sum_{i=0}^n i \cdot {n \choose i} p^i(1-p)^{n-i} && \text{Sub in PMF} \\
|
265 |
+
& \cdots && \text{Many steps later} \\
|
266 |
+
&= n \cdot p
|
267 |
+
\end{align}
|
268 |
+
"""
|
269 |
+
)
|
270 |
+
return
|
271 |
+
|
272 |
+
|
273 |
+
@app.cell(hide_code=True)
|
274 |
+
def _(mo):
|
275 |
+
mo.md(
|
276 |
+
r"""
|
277 |
+
## Binomial Distribution in Python
|
278 |
+
|
279 |
+
As you might expect, you can use binomial distributions in code. The standardized library for binomials is `scipy.stats.binom`.
|
280 |
+
|
281 |
+
One of the most helpful methods that this package provides is a way to calculate the PMF. For example, say $n=5$, $p=0.6$ and you want to find $P(X=2)$, you could use the following code:
|
282 |
+
"""
|
283 |
+
)
|
284 |
+
return
|
285 |
+
|
286 |
+
|
287 |
+
@app.cell
|
288 |
+
def _(stats):
|
289 |
+
# define variables for x, n, and p
|
290 |
+
_n = 5 # Integer value for n
|
291 |
+
_p = 0.6
|
292 |
+
_x = 2
|
293 |
+
|
294 |
+
# use scipy to compute the pmf
|
295 |
+
p_x = stats.binom.pmf(_x, _n, _p)
|
296 |
+
|
297 |
+
# use the probability for future work
|
298 |
+
print(f'P(X = {_x}) = {p_x:.4f}')
|
299 |
+
return (p_x,)
|
300 |
+
|
301 |
+
|
302 |
+
@app.cell(hide_code=True)
|
303 |
+
def _(mo):
|
304 |
+
mo.md(r"""Another particularly helpful function is the ability to generate a random sample from a binomial. For example, say $X$ represents the number of requests to a website. We can draw 100 samples from this distribution using the following code:""")
|
305 |
+
return
|
306 |
+
|
307 |
+
|
308 |
+
@app.cell
|
309 |
+
def _(n, p, stats):
|
310 |
+
n_int = int(n)
|
311 |
+
|
312 |
+
# samples from the binomial distribution
|
313 |
+
samples = stats.binom.rvs(n_int, p, size=100)
|
314 |
+
|
315 |
+
# Print the samples
|
316 |
+
print(samples)
|
317 |
+
return n_int, samples
|
318 |
+
|
319 |
+
|
320 |
+
@app.cell(hide_code=True)
|
321 |
+
def _(n_int, np, p, plt, samples, stats):
|
322 |
+
# Plot histogram of samples
|
323 |
+
plt.figure(figsize=(10, 5))
|
324 |
+
plt.hist(samples, bins=np.arange(-0.5, n_int+1.5, 1), alpha=0.7, color='royalblue',
|
325 |
+
edgecolor='black', density=True)
|
326 |
+
|
327 |
+
# Overlay the PMF
|
328 |
+
x_values = np.arange(0, n_int+1)
|
329 |
+
pmf_values = stats.binom.pmf(x_values, n_int, p)
|
330 |
+
plt.plot(x_values, pmf_values, 'ro-', ms=8, label='Theoretical PMF')
|
331 |
+
|
332 |
+
# Add labels and title
|
333 |
+
plt.xlabel('Number of Successes')
|
334 |
+
plt.ylabel('Relative Frequency / Probability')
|
335 |
+
plt.title(f'Histogram of 100 Samples from Bin({n_int}, {p})')
|
336 |
+
plt.legend()
|
337 |
+
plt.grid(alpha=0.3)
|
338 |
+
|
339 |
+
# Annotate
|
340 |
+
plt.annotate('Sample mean: %.2f' % np.mean(samples),
|
341 |
+
xy=(0.7, 0.9), xycoords='axes fraction',
|
342 |
+
bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3))
|
343 |
+
plt.annotate('Theoretical mean: %.2f' % (n_int*p),
|
344 |
+
xy=(0.7, 0.8), xycoords='axes fraction',
|
345 |
+
bbox=dict(boxstyle='round,pad=0.5', fc='lightgreen', alpha=0.3))
|
346 |
+
|
347 |
+
plt.tight_layout()
|
348 |
+
plt.gca()
|
349 |
+
return pmf_values, x_values
|
350 |
+
|
351 |
+
|
352 |
+
@app.cell(hide_code=True)
|
353 |
+
def _(mo):
|
354 |
+
mo.md(
|
355 |
+
r"""
|
356 |
+
You might be wondering what a random sample is! A random sample is a randomly chosen assignment for our random variable. Above we have 100 such assignments. The probability that value $k$ is chosen is given by the PMF: $P(X=k)$.
|
357 |
+
|
358 |
+
There are also functions for getting the mean, the variance, and more. You can read the [scipy.stats.binom documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.binom.html), especially the list of methods.
|
359 |
+
"""
|
360 |
+
)
|
361 |
+
return
|
362 |
+
|
363 |
+
|
364 |
+
@app.cell(hide_code=True)
|
365 |
+
def _(mo):
|
366 |
+
mo.md(
|
367 |
+
r"""
|
368 |
+
## Interactive Exploration of Binomial vs. Negative Binomial
|
369 |
+
|
370 |
+
The standard binomial distribution is a special case of a broader family of distributions. One related distribution is the negative binomial, which can model count data with overdispersion (where the variance is larger than the mean).
|
371 |
+
|
372 |
+
Below, you can explore how the negative binomial distribution compares to a Poisson distribution (which can be seen as a limiting case of the binomial as $n$ gets large and $p$ gets small, with $np$ held constant).
|
373 |
+
|
374 |
+
Adjust the sliders to see how the parameters affect the distribution:
|
375 |
+
|
376 |
+
*Note: The interactive visualization in this section was inspired by work from [liquidcarbon on GitHub](https://github.com/liquidcarbon).*
|
377 |
+
"""
|
378 |
+
)
|
379 |
+
return
|
380 |
+
|
381 |
+
|
382 |
+
@app.cell(hide_code=True)
|
383 |
+
def _(alpha_slider, chart, equation, mo, mu_slider):
|
384 |
+
mo.vstack(
|
385 |
+
[
|
386 |
+
mo.md(f"## Negative Binomial Distribution (Poisson + Overdispersion)\n{equation}"),
|
387 |
+
mo.hstack([mu_slider, alpha_slider], justify="start"),
|
388 |
+
chart,
|
389 |
+
], justify='space-around'
|
390 |
+
).center()
|
391 |
+
return
|
392 |
+
|
393 |
+
|
394 |
+
@app.cell(hide_code=True)
|
395 |
+
def _(mo):
|
396 |
+
mo.md(
|
397 |
+
r"""
|
398 |
+
## 🤔 Test Your Understanding
|
399 |
+
Pick which of these statements about binomial distributions you think are correct:
|
400 |
+
|
401 |
+
/// details | The variance of a binomial distribution is always equal to its mean
|
402 |
+
❌ Incorrect! The variance is $np(1-p)$ while the mean is $np$. They're only equal when $p=1$ (which is a degenerate case).
|
403 |
+
///
|
404 |
+
|
405 |
+
/// details | If $X \sim \text{Bin}(n, p)$ and $Y \sim \text{Bin}(n, 1-p)$, then $X$ and $Y$ have the same variance
|
406 |
+
✅ Correct! $\text{Var}(X) = np(1-p)$ and $\text{Var}(Y) = n(1-p)p$, which are the same.
|
407 |
+
///
|
408 |
+
|
409 |
+
/// details | As the number of trials increases, the binomial distribution approaches a normal distribution
|
410 |
+
✅ Correct! For large $n$, the binomial distribution can be approximated by a normal distribution with the same mean and variance.
|
411 |
+
///
|
412 |
+
|
413 |
+
/// details | The PMF of a binomial distribution is symmetric when $p = 0.5$
|
414 |
+
✅ Correct! When $p = 0.5$, the PMF is symmetric around $n/2$.
|
415 |
+
///
|
416 |
+
|
417 |
+
/// details | The sum of two independent binomial random variables with the same $p$ is also a binomial random variable
|
418 |
+
✅ Correct! If $X \sim \text{Bin}(n_1, p)$ and $Y \sim \text{Bin}(n_2, p)$ are independent, then $X + Y \sim \text{Bin}(n_1 + n_2, p)$.
|
419 |
+
///
|
420 |
+
|
421 |
+
/// details | The maximum value of the PMF for $\text{Bin}(n,p)$ always occurs at $k = np$
|
422 |
+
❌ Incorrect! The mode (maximum value of PMF) is either $\lfloor (n+1)p \rfloor$ or $\lceil (n+1)p-1 \rceil$ depending on whether $(n+1)p$ is an integer.
|
423 |
+
///
|
424 |
+
"""
|
425 |
+
)
|
426 |
+
return
|
427 |
+
|
428 |
+
|
429 |
+
@app.cell(hide_code=True)
|
430 |
+
def _(mo):
|
431 |
+
mo.md(
|
432 |
+
r"""
|
433 |
+
## Summary
|
434 |
+
|
435 |
+
So we've explored the binomial distribution, and honestly, it's one of the most practical probability distributions you'll encounter. Think about it — anytime you're counting successes in a fixed number of trials (like those coin flips we discussed), this is your go-to distribution.
|
436 |
+
|
437 |
+
I find it fascinating how the expectation is simply $np$. Such a clean, intuitive formula! And remember that neat visualization we saw earlier? When we adjusted the parameters, you could actually see how the distribution shape changes—becoming more symmetric as $n$ increases.
|
438 |
+
|
439 |
+
The key things to take away:
|
440 |
+
|
441 |
+
- The binomial distribution models the number of successes in $n$ independent trials, each with probability $p$ of success
|
442 |
+
|
443 |
+
- Its PMF is given by the formula $P(X=k) = {n \choose k}p^k(1-p)^{n-k}$, which lets us calculate exactly how likely any specific number of successes is
|
444 |
+
|
445 |
+
- The expected value is $E[X] = np$ and the variance is $Var(X) = np(1-p)$
|
446 |
+
|
447 |
+
- It's related to other distributions: it's essentially a sum of Bernoulli random variables, and connects to both the negative binomial and Poisson distributions
|
448 |
+
|
449 |
+
- In Python, the `scipy.stats.binom` module makes working with binomial distributions straightforward—you can generate random samples and calculate probabilities with just a few lines of code
|
450 |
+
|
451 |
+
You'll see the binomial distribution pop up everywhere—from computer science to quality control, epidemiology, and data science. Any time you have scenarios with binary outcomes over multiple trials, this distribution has you covered.
|
452 |
+
"""
|
453 |
+
)
|
454 |
+
return
|
455 |
+
|
456 |
+
|
457 |
+
@app.cell(hide_code=True)
|
458 |
+
def _(mo):
|
459 |
+
mo.md(r"""Appendix code (helper functions, variables, etc.):""")
|
460 |
+
return
|
461 |
+
|
462 |
+
|
463 |
+
@app.cell
|
464 |
+
def _():
|
465 |
+
import marimo as mo
|
466 |
+
return (mo,)
|
467 |
+
|
468 |
+
|
469 |
+
@app.cell(hide_code=True)
|
470 |
+
def _():
|
471 |
+
import numpy as np
|
472 |
+
import matplotlib.pyplot as plt
|
473 |
+
import scipy.stats as stats
|
474 |
+
import pandas as pd
|
475 |
+
import altair as alt
|
476 |
+
from wigglystuff import TangleSlider
|
477 |
+
return TangleSlider, alt, np, pd, plt, stats
|
478 |
+
|
479 |
+
|
480 |
+
@app.cell(hide_code=True)
|
481 |
+
def _(mo):
|
482 |
+
alpha_slider = mo.ui.slider(
|
483 |
+
value=0.1,
|
484 |
+
steps=[0, 0.01, 0.02, 0.03, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 1],
|
485 |
+
label="α (overdispersion)",
|
486 |
+
show_value=True,
|
487 |
+
)
|
488 |
+
mu_slider = mo.ui.slider(
|
489 |
+
value=100, start=1, stop=100, step=1, label="μ (mean)", show_value=True
|
490 |
+
)
|
491 |
+
return alpha_slider, mu_slider
|
492 |
+
|
493 |
+
|
494 |
+
@app.cell(hide_code=True)
|
495 |
+
def _():
|
496 |
+
equation = """
|
497 |
+
$$
|
498 |
+
P(X = k) = \\frac{\\Gamma(k + \\frac{1}{\\alpha})}{\\Gamma(k + 1) \\Gamma(\\frac{1}{\\alpha})} \\left( \\frac{1}{\\mu \\alpha + 1} \\right)^{\\frac{1}{\\alpha}} \\left( \\frac{\\mu \\alpha}{\\mu \\alpha + 1} \\right)^k
|
499 |
+
$$
|
500 |
+
|
501 |
+
$$
|
502 |
+
\\sigma^2 = \\mu + \\alpha \\mu^2
|
503 |
+
$$
|
504 |
+
"""
|
505 |
+
return (equation,)
|
506 |
+
|
507 |
+
|
508 |
+
@app.cell(hide_code=True)
|
509 |
+
def _(alpha_slider, alt, mu_slider, np, pd, stats):
|
510 |
+
mu = mu_slider.value
|
511 |
+
alpha = alpha_slider.value
|
512 |
+
n = 1000 - mu if alpha == 0 else 1 / alpha
|
513 |
+
p = n / (mu + n)
|
514 |
+
x = np.arange(0, mu * 3 + 1, 1)
|
515 |
+
df = pd.DataFrame(
|
516 |
+
{
|
517 |
+
"x": x,
|
518 |
+
"y": stats.nbinom.pmf(x, n, p),
|
519 |
+
"y_poi": stats.nbinom.pmf(x, 1000 - mu, 1 - mu / 1000),
|
520 |
+
}
|
521 |
+
)
|
522 |
+
r1k = stats.nbinom.rvs(n, p, size=1000)
|
523 |
+
df["in 95% CI"] = df["x"].between(*np.percentile(r1k, q=[2.5, 97.5]))
|
524 |
+
base = alt.Chart(df)
|
525 |
+
|
526 |
+
chart_poi = base.mark_bar(
|
527 |
+
fillOpacity=0.25, width=100 / mu, fill="magenta"
|
528 |
+
).encode(
|
529 |
+
x=alt.X("x").scale(domain=(-0.4, x.max() + 0.4), nice=False),
|
530 |
+
y=alt.Y("y_poi").scale(domain=(0, df.y_poi.max() * 1.1)).title(None),
|
531 |
+
)
|
532 |
+
chart_nb = base.mark_bar(fillOpacity=0.75, width=100 / mu).encode(
|
533 |
+
x="x",
|
534 |
+
y="y",
|
535 |
+
fill=alt.Fill("in 95% CI")
|
536 |
+
.scale(domain=[False, True], range=["#aaa", "#7c7"])
|
537 |
+
.legend(orient="bottom-right"),
|
538 |
+
)
|
539 |
+
|
540 |
+
chart = (chart_poi + chart_nb).configure_view(continuousWidth=450)
|
541 |
+
return alpha, base, chart, chart_nb, chart_poi, df, mu, n, p, r1k, x
|
542 |
+
|
543 |
+
|
544 |
+
if __name__ == "__main__":
|
545 |
+
app.run()
|
probability/15_poisson_distribution.py
ADDED
@@ -0,0 +1,805 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.0",
|
6 |
+
# "numpy==2.2.4",
|
7 |
+
# "scipy==1.15.2",
|
8 |
+
# "altair==5.2.0",
|
9 |
+
# "wigglystuff==0.1.10",
|
10 |
+
# "pandas==2.2.3",
|
11 |
+
# ]
|
12 |
+
# ///
|
13 |
+
|
14 |
+
import marimo
|
15 |
+
|
16 |
+
__generated_with = "0.11.25"
|
17 |
+
app = marimo.App(width="medium", app_title="Poisson Distribution")
|
18 |
+
|
19 |
+
|
20 |
+
@app.cell(hide_code=True)
|
21 |
+
def _(mo):
|
22 |
+
mo.md(
|
23 |
+
r"""
|
24 |
+
# Poisson Distribution
|
25 |
+
|
26 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/poisson/), by Stanford professor Chris Piech._
|
27 |
+
|
28 |
+
A Poisson random variable gives the probability of a given number of events in a fixed interval of time (or space). It makes the Poisson assumption that events occur with a known constant mean rate and independently of the time since the last event.
|
29 |
+
"""
|
30 |
+
)
|
31 |
+
return
|
32 |
+
|
33 |
+
|
34 |
+
@app.cell(hide_code=True)
|
35 |
+
def _(mo):
|
36 |
+
mo.md(
|
37 |
+
r"""
|
38 |
+
## Poisson Random Variable Definition
|
39 |
+
|
40 |
+
$X \sim \text{Poisson}(\lambda)$ represents a Poisson random variable where:
|
41 |
+
|
42 |
+
- $X$ is our random variable (number of events)
|
43 |
+
- $\text{Poisson}$ indicates it follows a Poisson distribution
|
44 |
+
- $\lambda$ is the rate parameter (average number of events per time interval)
|
45 |
+
|
46 |
+
```
|
47 |
+
X ~ Poisson(λ)
|
48 |
+
↑ ↑ ↑
|
49 |
+
| | +-- Rate parameter:
|
50 |
+
| | average number of
|
51 |
+
| | events per interval
|
52 |
+
| +-- Indicates Poisson
|
53 |
+
| distribution
|
54 |
+
|
|
55 |
+
Our random variable
|
56 |
+
counting number of events
|
57 |
+
```
|
58 |
+
|
59 |
+
The Poisson distribution is particularly useful when:
|
60 |
+
|
61 |
+
1. Events occur independently of each other
|
62 |
+
2. The average rate of occurrence is constant
|
63 |
+
3. Two events cannot occur at exactly the same instant
|
64 |
+
4. The probability of an event is proportional to the length of the time interval
|
65 |
+
"""
|
66 |
+
)
|
67 |
+
return
|
68 |
+
|
69 |
+
|
70 |
+
@app.cell(hide_code=True)
|
71 |
+
def _(mo):
|
72 |
+
mo.md(
|
73 |
+
r"""
|
74 |
+
## Properties of Poisson Distribution
|
75 |
+
|
76 |
+
| Property | Formula |
|
77 |
+
|----------|---------|
|
78 |
+
| Notation | $X \sim \text{Poisson}(\lambda)$ |
|
79 |
+
| Description | Number of events in a fixed time frame if (a) events occur with a constant mean rate and (b) they occur independently of time since last event |
|
80 |
+
| Parameters | $\lambda \in \mathbb{R}^{+}$, the constant average rate |
|
81 |
+
| Support | $x \in \{0, 1, \dots\}$ |
|
82 |
+
| PMF equation | $P(X=x) = \frac{\lambda^x e^{-\lambda}}{x!}$ |
|
83 |
+
| Expectation | $E[X] = \lambda$ |
|
84 |
+
| Variance | $\text{Var}(X) = \lambda$ |
|
85 |
+
|
86 |
+
Note that unlike many other distributions, the Poisson distribution's mean and variance are equal, both being $\lambda$.
|
87 |
+
|
88 |
+
Let's explore how the Poisson distribution changes with different rate parameters.
|
89 |
+
"""
|
90 |
+
)
|
91 |
+
return
|
92 |
+
|
93 |
+
|
94 |
+
@app.cell(hide_code=True)
|
95 |
+
def _(TangleSlider, mo):
|
96 |
+
# interactive elements using TangleSlider
|
97 |
+
lambda_slider = mo.ui.anywidget(TangleSlider(
|
98 |
+
amount=5,
|
99 |
+
min_value=0.1,
|
100 |
+
max_value=20,
|
101 |
+
step=0.1,
|
102 |
+
digits=1,
|
103 |
+
suffix=" events"
|
104 |
+
))
|
105 |
+
|
106 |
+
# interactive controls
|
107 |
+
_controls = mo.vstack([
|
108 |
+
mo.md("### Adjust the Rate Parameter to See How Poisson Distribution Changes"),
|
109 |
+
mo.hstack([
|
110 |
+
mo.md("**Rate parameter (λ):** "),
|
111 |
+
lambda_slider,
|
112 |
+
mo.md("**events per interval.** Higher values shift the distribution rightward and make it more spread out.")
|
113 |
+
], justify="start"),
|
114 |
+
])
|
115 |
+
_controls
|
116 |
+
return (lambda_slider,)
|
117 |
+
|
118 |
+
|
119 |
+
@app.cell(hide_code=True)
|
120 |
+
def _(lambda_slider, np, plt, stats):
|
121 |
+
def create_poisson_pmf_plot(lambda_value):
|
122 |
+
"""Create a visualization of Poisson PMF with annotations for mean and variance."""
|
123 |
+
# PMF for values
|
124 |
+
max_x = max(20, int(lambda_value * 3)) # Show at least up to 3*lambda
|
125 |
+
x = np.arange(0, max_x + 1)
|
126 |
+
pmf = stats.poisson.pmf(x, lambda_value)
|
127 |
+
|
128 |
+
# Relevant key statistics
|
129 |
+
mean = lambda_value # For Poisson, mean = lambda
|
130 |
+
variance = lambda_value # For Poisson, variance = lambda
|
131 |
+
std_dev = np.sqrt(variance)
|
132 |
+
|
133 |
+
# plot
|
134 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
135 |
+
|
136 |
+
# PMF as bars
|
137 |
+
ax.bar(x, pmf, color='royalblue', alpha=0.7, label=f'PMF: P(X=k)')
|
138 |
+
|
139 |
+
# for the PMF values
|
140 |
+
ax.plot(x, pmf, 'ro-', alpha=0.6, label='PMF line')
|
141 |
+
|
142 |
+
# Vertical lines - mean and key values
|
143 |
+
ax.axvline(x=mean, color='green', linestyle='--', linewidth=2,
|
144 |
+
label=f'Mean: {mean:.2f}')
|
145 |
+
|
146 |
+
# Stdev region
|
147 |
+
ax.axvspan(mean - std_dev, mean + std_dev, alpha=0.2, color='green',
|
148 |
+
label=f'±1 Std Dev: {std_dev:.2f}')
|
149 |
+
|
150 |
+
ax.set_xlabel('Number of Events (k)')
|
151 |
+
ax.set_ylabel('Probability: P(X=k)')
|
152 |
+
ax.set_title(f'Poisson Distribution with λ={lambda_value:.1f}')
|
153 |
+
|
154 |
+
# annotations
|
155 |
+
ax.annotate(f'E[X] = {mean:.2f}',
|
156 |
+
xy=(mean, stats.poisson.pmf(int(mean), lambda_value)),
|
157 |
+
xytext=(mean + 1, max(pmf) * 0.8),
|
158 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
159 |
+
|
160 |
+
ax.annotate(f'Var(X) = {variance:.2f}',
|
161 |
+
xy=(mean, stats.poisson.pmf(int(mean), lambda_value) / 2),
|
162 |
+
xytext=(mean + 1, max(pmf) * 0.6),
|
163 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
164 |
+
|
165 |
+
ax.grid(alpha=0.3)
|
166 |
+
ax.legend()
|
167 |
+
|
168 |
+
plt.tight_layout()
|
169 |
+
return plt.gca()
|
170 |
+
|
171 |
+
# Get parameter from slider and create plot
|
172 |
+
_lambda = lambda_slider.amount
|
173 |
+
create_poisson_pmf_plot(_lambda)
|
174 |
+
return (create_poisson_pmf_plot,)
|
175 |
+
|
176 |
+
|
177 |
+
@app.cell(hide_code=True)
|
178 |
+
def _(mo):
|
179 |
+
mo.md(
|
180 |
+
r"""
|
181 |
+
## Poisson Intuition: Relation to Binomial Distribution
|
182 |
+
|
183 |
+
The Poisson distribution can be derived as a limiting case of the [binomial distribution](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/14_binomial_distribution.py).
|
184 |
+
|
185 |
+
Let's work on a practical example: predicting the number of ride-sharing requests in a specific area over a one-minute interval. From historical data, we know that the average number of requests per minute is $\lambda = 5$.
|
186 |
+
|
187 |
+
We could approximate this using a binomial distribution by dividing our minute into smaller intervals. For example, we can divide a minute into 60 seconds and treat each second as a [Bernoulli trial](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/13_bernoulli_distribution.py) - either there's a request (success) or there isn't (failure).
|
188 |
+
|
189 |
+
Let's visualize this concept:
|
190 |
+
"""
|
191 |
+
)
|
192 |
+
return
|
193 |
+
|
194 |
+
|
195 |
+
@app.cell(hide_code=True)
|
196 |
+
def _(fig_to_image, mo, plt):
|
197 |
+
def create_time_division_visualization():
|
198 |
+
# visualization of dividing a minute into 60 seconds
|
199 |
+
fig, ax = plt.subplots(figsize=(12, 2))
|
200 |
+
|
201 |
+
# Example events hardcoded at 2.75s and 7.12s
|
202 |
+
events = [2.75, 7.12]
|
203 |
+
|
204 |
+
# array of 60 rectangles
|
205 |
+
for i in range(60):
|
206 |
+
color = 'royalblue' if any(i <= e < i+1 for e in events) else 'lightgray'
|
207 |
+
ax.add_patch(plt.Rectangle((i, 0), 0.9, 1, color=color))
|
208 |
+
|
209 |
+
# markers for events
|
210 |
+
for e in events:
|
211 |
+
ax.plot(e, 0.5, 'ro', markersize=10)
|
212 |
+
|
213 |
+
# labels
|
214 |
+
ax.set_xlim(0, 60)
|
215 |
+
ax.set_ylim(0, 1)
|
216 |
+
ax.set_yticks([])
|
217 |
+
ax.set_xticks([0, 15, 30, 45, 60])
|
218 |
+
ax.set_xticklabels(['0s', '15s', '30s', '45s', '60s'])
|
219 |
+
ax.set_xlabel('Time (seconds)')
|
220 |
+
ax.set_title('One Minute Divided into 60 Second Intervals')
|
221 |
+
|
222 |
+
plt.tight_layout()
|
223 |
+
plt.gca()
|
224 |
+
return fig, events, i
|
225 |
+
|
226 |
+
# Create visualization and convert to image
|
227 |
+
_fig, _events, i = create_time_division_visualization()
|
228 |
+
_img = mo.image(fig_to_image(_fig), width="100%")
|
229 |
+
|
230 |
+
# explanation
|
231 |
+
_explanation = mo.md(
|
232 |
+
r"""
|
233 |
+
In this visualization:
|
234 |
+
|
235 |
+
- Each rectangle represents a 1-second interval
|
236 |
+
- Blue rectangles indicate intervals where an event occurred
|
237 |
+
- Red dots show the actual event times (2.75s and 7.12s)
|
238 |
+
|
239 |
+
If we treat this as a binomial experiment with 60 trials (seconds), we can calculate probabilities using the binomial PMF. But there's a problem: what if multiple events occur within the same second? To address this, we can divide our minute into smaller intervals.
|
240 |
+
"""
|
241 |
+
)
|
242 |
+
mo.vstack([_fig, _explanation])
|
243 |
+
return create_time_division_visualization, i
|
244 |
+
|
245 |
+
|
246 |
+
@app.cell(hide_code=True)
|
247 |
+
def _(mo):
|
248 |
+
mo.md(
|
249 |
+
r"""
|
250 |
+
The total number of requests received over the minute can be approximated as the sum of the sixty indicator variables, which conveniently matches the description of a binomial — a sum of Bernoullis.
|
251 |
+
|
252 |
+
Specifically, if we define $X$ to be the number of requests in a minute, $X$ is a binomial with $n=60$ trials. What is the probability, $p$, of a success on a single trial? To make the expectation of $X$ equal the observed historical average $\lambda$, we should choose $p$ so that:
|
253 |
+
|
254 |
+
\begin{align}
|
255 |
+
\lambda &= E[X] && \text{Expectation matches historical average} \\
|
256 |
+
\lambda &= n \cdot p && \text{Expectation of a Binomial is } n \cdot p \\
|
257 |
+
p &= \frac{\lambda}{n} && \text{Solving for $p$}
|
258 |
+
\end{align}
|
259 |
+
|
260 |
+
In this case, since $\lambda=5$ and $n=60$, we should choose $p=\frac{5}{60}=\frac{1}{12}$ and state that $X \sim \text{Bin}(n=60, p=\frac{5}{60})$. Now we can calculate the probability of different numbers of requests using the binomial PMF:
|
261 |
+
|
262 |
+
$P(X = x) = {n \choose x} p^x (1-p)^{n-x}$
|
263 |
+
|
264 |
+
For example:
|
265 |
+
|
266 |
+
\begin{align}
|
267 |
+
P(X=1) &= {60 \choose 1} (5/60)^1 (55/60)^{60-1} \approx 0.0295 \\
|
268 |
+
P(X=2) &= {60 \choose 2} (5/60)^2 (55/60)^{60-2} \approx 0.0790 \\
|
269 |
+
P(X=3) &= {60 \choose 3} (5/60)^3 (55/60)^{60-3} \approx 0.1389
|
270 |
+
\end{align}
|
271 |
+
|
272 |
+
This is a good approximation, but it doesn't account for the possibility of multiple events in a single second. One solution is to divide our minute into even more fine-grained intervals. Let's try 600 deciseconds (tenths of a second):
|
273 |
+
"""
|
274 |
+
)
|
275 |
+
return
|
276 |
+
|
277 |
+
|
278 |
+
@app.cell(hide_code=True)
|
279 |
+
def _(fig_to_image, mo, plt):
|
280 |
+
def create_decisecond_visualization(e_value):
|
281 |
+
# (Just showing the first 100 for clarity)
|
282 |
+
fig, ax = plt.subplots(figsize=(12, 2))
|
283 |
+
|
284 |
+
# Example events at 2.75s and 7.12s (convert to deciseconds)
|
285 |
+
events = [27.5, 71.2]
|
286 |
+
|
287 |
+
for i in range(100):
|
288 |
+
color = 'royalblue' if any(i <= event_val < i + 1 for event_val in events) else 'lightgray'
|
289 |
+
ax.add_patch(plt.Rectangle((i, 0), 0.9, 1, color=color))
|
290 |
+
|
291 |
+
# Markers for events
|
292 |
+
for event in events:
|
293 |
+
if event < 100: # Only show events in our visible range
|
294 |
+
ax.plot(event/10, 0.5, 'ro', markersize=10) # Divide by 10 to convert to deciseconds
|
295 |
+
|
296 |
+
# Add labels
|
297 |
+
ax.set_xlim(0, 100)
|
298 |
+
ax.set_ylim(0, 1)
|
299 |
+
ax.set_yticks([])
|
300 |
+
ax.set_xticks([0, 20, 40, 60, 80, 100])
|
301 |
+
ax.set_xticklabels(['0s', '2s', '4s', '6s', '8s', '10s'])
|
302 |
+
ax.set_xlabel('Time (first 10 seconds shown)')
|
303 |
+
ax.set_title('One Minute Divided into 600 Decisecond Intervals (first 100 shown)')
|
304 |
+
|
305 |
+
plt.tight_layout()
|
306 |
+
plt.gca()
|
307 |
+
return fig
|
308 |
+
|
309 |
+
# Create viz and convert to image
|
310 |
+
_fig = create_decisecond_visualization(e_value=5)
|
311 |
+
_img = mo.image(fig_to_image(_fig), width="100%")
|
312 |
+
|
313 |
+
# Explanation
|
314 |
+
_explanation = mo.md(
|
315 |
+
r"""
|
316 |
+
With $n=600$ and $p=\frac{5}{600}=\frac{1}{120}$, we can recalculate our probabilities:
|
317 |
+
|
318 |
+
\begin{align}
|
319 |
+
P(X=1) &= {600 \choose 1} (5/600)^1 (595/600)^{600-1} \approx 0.0333 \\
|
320 |
+
P(X=2) &= {600 \choose 2} (5/600)^2 (595/600)^{600-2} \approx 0.0837 \\
|
321 |
+
P(X=3) &= {600 \choose 3} (5/600)^3 (595/600)^{600-3} \approx 0.1402
|
322 |
+
\end{align}
|
323 |
+
|
324 |
+
As we make our intervals smaller (increasing $n$), our approximation becomes more accurate.
|
325 |
+
"""
|
326 |
+
)
|
327 |
+
mo.vstack([_fig, _explanation])
|
328 |
+
return (create_decisecond_visualization,)
|
329 |
+
|
330 |
+
|
331 |
+
@app.cell(hide_code=True)
|
332 |
+
def _(mo):
|
333 |
+
mo.md(
|
334 |
+
r"""
|
335 |
+
## The Binomial Distribution in the Limit
|
336 |
+
|
337 |
+
What happens if we continue dividing our time interval into smaller and smaller pieces? Let's explore how the probabilities change as we increase the number of intervals:
|
338 |
+
"""
|
339 |
+
)
|
340 |
+
return
|
341 |
+
|
342 |
+
|
343 |
+
@app.cell(hide_code=True)
|
344 |
+
def _(mo):
|
345 |
+
intervals_slider = mo.ui.slider(
|
346 |
+
start = 60,
|
347 |
+
stop = 10000,
|
348 |
+
step=100,
|
349 |
+
value=600,
|
350 |
+
label="Number of intervals to divide a minute")
|
351 |
+
return (intervals_slider,)
|
352 |
+
|
353 |
+
|
354 |
+
@app.cell(hide_code=True)
|
355 |
+
def _(intervals_slider):
|
356 |
+
intervals_slider
|
357 |
+
return
|
358 |
+
|
359 |
+
|
360 |
+
@app.cell(hide_code=True)
|
361 |
+
def _(intervals_slider, np, pd, plt, stats):
|
362 |
+
def create_comparison_plot(n, lambda_value):
|
363 |
+
# Calculate probability
|
364 |
+
p = lambda_value / n
|
365 |
+
|
366 |
+
# Binomial probabilities
|
367 |
+
x_values = np.arange(0, 15)
|
368 |
+
binom_pmf = stats.binom.pmf(x_values, n, p)
|
369 |
+
|
370 |
+
# True Poisson probabilities
|
371 |
+
poisson_pmf = stats.poisson.pmf(x_values, lambda_value)
|
372 |
+
|
373 |
+
# DF for comparison
|
374 |
+
df = pd.DataFrame({
|
375 |
+
'Events': x_values,
|
376 |
+
f'Binomial(n={n}, p={p:.6f})': binom_pmf,
|
377 |
+
f'Poisson(λ=5)': poisson_pmf,
|
378 |
+
'Difference': np.abs(binom_pmf - poisson_pmf)
|
379 |
+
})
|
380 |
+
|
381 |
+
# Plot both PMFs
|
382 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
383 |
+
|
384 |
+
# Bar plot for the binomial
|
385 |
+
ax.bar(x_values - 0.2, binom_pmf, width=0.4, alpha=0.7,
|
386 |
+
color='royalblue', label=f'Binomial(n={n}, p={p:.6f})')
|
387 |
+
|
388 |
+
# Bar plot for the Poisson
|
389 |
+
ax.bar(x_values + 0.2, poisson_pmf, width=0.4, alpha=0.7,
|
390 |
+
color='crimson', label='Poisson(λ=5)')
|
391 |
+
|
392 |
+
# Labels and title
|
393 |
+
ax.set_xlabel('Number of Events (k)')
|
394 |
+
ax.set_ylabel('Probability')
|
395 |
+
ax.set_title(f'Comparison of Binomial and Poisson PMFs with n={n}')
|
396 |
+
ax.legend()
|
397 |
+
ax.set_xticks(x_values)
|
398 |
+
ax.grid(alpha=0.3)
|
399 |
+
|
400 |
+
plt.tight_layout()
|
401 |
+
return df, fig, n, p
|
402 |
+
|
403 |
+
# Number of intervals from the slider
|
404 |
+
n = intervals_slider.value
|
405 |
+
_lambda = 5 # Fixed lambda for our example
|
406 |
+
|
407 |
+
# Cromparison plot
|
408 |
+
df, fig, n, p = create_comparison_plot(n, _lambda)
|
409 |
+
return create_comparison_plot, df, fig, n, p
|
410 |
+
|
411 |
+
|
412 |
+
@app.cell(hide_code=True)
|
413 |
+
def _(df, fig, fig_to_image, mo, n, p):
|
414 |
+
# table of values
|
415 |
+
_styled_df = df.style.format({
|
416 |
+
f'Binomial(n={n}, p={p:.6f})': '{:.6f}',
|
417 |
+
f'Poisson(λ=5)': '{:.6f}',
|
418 |
+
'Difference': '{:.6f}'
|
419 |
+
})
|
420 |
+
|
421 |
+
# Calculate the max absolute difference
|
422 |
+
_max_diff = df['Difference'].max()
|
423 |
+
|
424 |
+
# output
|
425 |
+
_chart = mo.image(fig_to_image(fig), width="100%")
|
426 |
+
_explanation = mo.md(f"**Maximum absolute difference between distributions: {_max_diff:.6f}**")
|
427 |
+
_table = mo.ui.table(df)
|
428 |
+
|
429 |
+
mo.vstack([_chart, _explanation, _table])
|
430 |
+
return
|
431 |
+
|
432 |
+
|
433 |
+
@app.cell(hide_code=True)
|
434 |
+
def _(mo):
|
435 |
+
mo.md(
|
436 |
+
r"""
|
437 |
+
As you can see from the interactive comparison above, as the number of intervals increases, the binomial distribution approaches the Poisson distribution! This is not a coincidence - the Poisson distribution is actually the limiting case of the binomial distribution when:
|
438 |
+
|
439 |
+
- The number of trials $n$ approaches infinity
|
440 |
+
- The probability of success $p$ approaches zero
|
441 |
+
- The product $np = \lambda$ remains constant
|
442 |
+
|
443 |
+
This relationship is why the Poisson distribution is so useful - it's easier to work with than a binomial with a very large number of trials and a very small probability of success.
|
444 |
+
|
445 |
+
## Derivation of the Poisson PMF
|
446 |
+
|
447 |
+
Let's derive the Poisson PMF by taking the limit of the binomial PMF as $n \to \infty$. We start with:
|
448 |
+
|
449 |
+
$P(X=x) = \lim_{n \rightarrow \infty} {n \choose x} (\lambda / n)^x(1-\lambda/n)^{n-x}$
|
450 |
+
|
451 |
+
While this expression looks intimidating, it simplifies nicely:
|
452 |
+
|
453 |
+
\begin{align}
|
454 |
+
P(X=x)
|
455 |
+
&= \lim_{n \rightarrow \infty} {n \choose x} (\lambda / n)^x(1-\lambda/n)^{n-x}
|
456 |
+
&& \text{Start: binomial in the limit}\\
|
457 |
+
&= \lim_{n \rightarrow \infty}
|
458 |
+
{n \choose x} \cdot
|
459 |
+
\frac{\lambda^x}{n^x} \cdot
|
460 |
+
\frac{(1-\lambda/n)^{n}}{(1-\lambda/n)^{x}}
|
461 |
+
&& \text{Expanding the power terms} \\
|
462 |
+
&= \lim_{n \rightarrow \infty}
|
463 |
+
\frac{n!}{(n-x)!x!} \cdot
|
464 |
+
\frac{\lambda^x}{n^x} \cdot
|
465 |
+
\frac{(1-\lambda/n)^{n}}{(1-\lambda/n)^{x}}
|
466 |
+
&& \text{Expanding the binomial term} \\
|
467 |
+
&= \lim_{n \rightarrow \infty}
|
468 |
+
\frac{n!}{(n-x)!x!} \cdot
|
469 |
+
\frac{\lambda^x}{n^x} \cdot
|
470 |
+
\frac{e^{-\lambda}}{(1-\lambda/n)^{x}}
|
471 |
+
&& \text{Using limit rule } \lim_{n \rightarrow \infty}(1-\lambda/n)^{n} = e^{-\lambda}\\
|
472 |
+
&= \lim_{n \rightarrow \infty}
|
473 |
+
\frac{n!}{(n-x)!x!} \cdot
|
474 |
+
\frac{\lambda^x}{n^x} \cdot
|
475 |
+
\frac{e^{-\lambda}}{1}
|
476 |
+
&& \text{As } n \to \infty \text{, } \lambda/n \to 0\\
|
477 |
+
&= \lim_{n \rightarrow \infty}
|
478 |
+
\frac{n!}{(n-x)!} \cdot
|
479 |
+
\frac{1}{x!} \cdot
|
480 |
+
\frac{\lambda^x}{n^x} \cdot
|
481 |
+
e^{-\lambda}
|
482 |
+
&& \text{Rearranging terms}\\
|
483 |
+
&= \lim_{n \rightarrow \infty}
|
484 |
+
\frac{n^x}{1} \cdot
|
485 |
+
\frac{1}{x!} \cdot
|
486 |
+
\frac{\lambda^x}{n^x} \cdot
|
487 |
+
e^{-\lambda}
|
488 |
+
&& \text{As } n \to \infty \text{, } \frac{n!}{(n-x)!} \approx n^x\\
|
489 |
+
&= \lim_{n \rightarrow \infty}
|
490 |
+
\frac{\lambda^x}{x!} \cdot
|
491 |
+
e^{-\lambda}
|
492 |
+
&& \text{Canceling } n^x\\
|
493 |
+
&=
|
494 |
+
\frac{\lambda^x \cdot e^{-\lambda}}{x!}
|
495 |
+
&& \text{Simplifying}\\
|
496 |
+
\end{align}
|
497 |
+
|
498 |
+
This gives us our elegant Poisson PMF formula: $P(X=x) = \frac{\lambda^x \cdot e^{-\lambda}}{x!}$
|
499 |
+
"""
|
500 |
+
)
|
501 |
+
return
|
502 |
+
|
503 |
+
|
504 |
+
@app.cell(hide_code=True)
|
505 |
+
def _(mo):
|
506 |
+
mo.md(
|
507 |
+
r"""
|
508 |
+
## Poisson Distribution in Python
|
509 |
+
|
510 |
+
Python's `scipy.stats` module provides functions to work with the Poisson distribution. Let's see how to calculate probabilities and generate random samples.
|
511 |
+
|
512 |
+
First, let's calculate some probabilities for our ride-sharing example with $\lambda = 5$:
|
513 |
+
"""
|
514 |
+
)
|
515 |
+
return
|
516 |
+
|
517 |
+
|
518 |
+
@app.cell
|
519 |
+
def _(stats):
|
520 |
+
_lambda = 5
|
521 |
+
|
522 |
+
# Calculate probabilities for X = 1, 2, 3
|
523 |
+
p_1 = stats.poisson.pmf(1, _lambda)
|
524 |
+
p_2 = stats.poisson.pmf(2, _lambda)
|
525 |
+
p_3 = stats.poisson.pmf(3, _lambda)
|
526 |
+
|
527 |
+
print(f"P(X=1) = {p_1:.5f}")
|
528 |
+
print(f"P(X=2) = {p_2:.5f}")
|
529 |
+
print(f"P(X=3) = {p_3:.5f}")
|
530 |
+
|
531 |
+
# Calculate cumulative probability P(X ≤ 3)
|
532 |
+
p_leq_3 = stats.poisson.cdf(3, _lambda)
|
533 |
+
print(f"P(X≤3) = {p_leq_3:.5f}")
|
534 |
+
|
535 |
+
# Calculate probability P(X > 10)
|
536 |
+
p_gt_10 = 1 - stats.poisson.cdf(10, _lambda)
|
537 |
+
print(f"P(X>10) = {p_gt_10:.5f}")
|
538 |
+
return p_1, p_2, p_3, p_gt_10, p_leq_3
|
539 |
+
|
540 |
+
|
541 |
+
@app.cell(hide_code=True)
|
542 |
+
def _(mo):
|
543 |
+
mo.md(r"""We can also generate random samples from a Poisson distribution and visualize their distribution:""")
|
544 |
+
return
|
545 |
+
|
546 |
+
|
547 |
+
@app.cell(hide_code=True)
|
548 |
+
def _(np, plt, stats):
|
549 |
+
def create_samples_plot(lambda_value, sample_size=1000):
|
550 |
+
# Random samples
|
551 |
+
samples = stats.poisson.rvs(lambda_value, size=sample_size)
|
552 |
+
|
553 |
+
# theoretical PMF
|
554 |
+
x_values = np.arange(0, max(samples) + 1)
|
555 |
+
pmf_values = stats.poisson.pmf(x_values, lambda_value)
|
556 |
+
|
557 |
+
# histograms to compare
|
558 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
559 |
+
|
560 |
+
# samples as a histogram
|
561 |
+
ax.hist(samples, bins=np.arange(-0.5, max(samples) + 1.5, 1),
|
562 |
+
alpha=0.7, density=True, label='Random Samples')
|
563 |
+
|
564 |
+
# theoretical PMF
|
565 |
+
ax.plot(x_values, pmf_values, 'ro-', label='Theoretical PMF')
|
566 |
+
|
567 |
+
# labels and title
|
568 |
+
ax.set_xlabel('Number of Events')
|
569 |
+
ax.set_ylabel('Relative Frequency / Probability')
|
570 |
+
ax.set_title(f'1000 Random Samples from Poisson(λ={lambda_value})')
|
571 |
+
ax.legend()
|
572 |
+
ax.grid(alpha=0.3)
|
573 |
+
|
574 |
+
# annotations
|
575 |
+
ax.annotate(f'Sample Mean: {np.mean(samples):.2f}',
|
576 |
+
xy=(0.7, 0.9), xycoords='axes fraction',
|
577 |
+
bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3))
|
578 |
+
ax.annotate(f'Theoretical Mean: {lambda_value:.2f}',
|
579 |
+
xy=(0.7, 0.8), xycoords='axes fraction',
|
580 |
+
bbox=dict(boxstyle='round,pad=0.5', fc='lightgreen', alpha=0.3))
|
581 |
+
|
582 |
+
plt.tight_layout()
|
583 |
+
return plt.gca()
|
584 |
+
|
585 |
+
# Use a lambda value of 5 for this example
|
586 |
+
_lambda = 5
|
587 |
+
create_samples_plot(_lambda)
|
588 |
+
return (create_samples_plot,)
|
589 |
+
|
590 |
+
|
591 |
+
@app.cell(hide_code=True)
|
592 |
+
def _(mo):
|
593 |
+
mo.md(
|
594 |
+
r"""
|
595 |
+
## Changing Time Frames
|
596 |
+
|
597 |
+
One important property of the Poisson distribution is that the rate parameter $\lambda$ scales linearly with the time interval. If events occur at a rate of $\lambda$ per unit time, then over a period of $t$ units, the rate parameter becomes $\lambda \cdot t$.
|
598 |
+
|
599 |
+
For example, if a website receives an average of 5 requests per minute, what is the distribution of requests over a 20-minute period?
|
600 |
+
|
601 |
+
The rate parameter for the 20-minute period would be $\lambda = 5 \cdot 20 = 100$ requests.
|
602 |
+
"""
|
603 |
+
)
|
604 |
+
return
|
605 |
+
|
606 |
+
|
607 |
+
@app.cell(hide_code=True)
|
608 |
+
def _(mo):
|
609 |
+
rate_slider = mo.ui.slider(
|
610 |
+
start = 0.1,
|
611 |
+
stop = 10,
|
612 |
+
step=0.1,
|
613 |
+
value=5,
|
614 |
+
label="Rate per unit time (λ)"
|
615 |
+
)
|
616 |
+
|
617 |
+
time_slider = mo.ui.slider(
|
618 |
+
start = 1,
|
619 |
+
stop = 60,
|
620 |
+
step=1,
|
621 |
+
value=20,
|
622 |
+
label="Time period (t units)"
|
623 |
+
)
|
624 |
+
|
625 |
+
controls = mo.vstack([
|
626 |
+
mo.md("### Adjust Parameters to See How Time Scaling Works"),
|
627 |
+
mo.hstack([rate_slider, time_slider], justify="space-between")
|
628 |
+
])
|
629 |
+
return controls, rate_slider, time_slider
|
630 |
+
|
631 |
+
|
632 |
+
@app.cell
|
633 |
+
def _(controls):
|
634 |
+
controls.center()
|
635 |
+
return
|
636 |
+
|
637 |
+
|
638 |
+
@app.cell(hide_code=True)
|
639 |
+
def _(mo, np, plt, rate_slider, stats, time_slider):
|
640 |
+
def create_time_scaling_plot(rate, time_period):
|
641 |
+
# scaled rate parameter
|
642 |
+
lambda_value = rate * time_period
|
643 |
+
|
644 |
+
# PMF for values
|
645 |
+
max_x = max(30, int(lambda_value * 1.5))
|
646 |
+
x = np.arange(0, max_x + 1)
|
647 |
+
pmf = stats.poisson.pmf(x, lambda_value)
|
648 |
+
|
649 |
+
# plot
|
650 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
651 |
+
|
652 |
+
# PMF as bars
|
653 |
+
ax.bar(x, pmf, color='royalblue', alpha=0.7,
|
654 |
+
label=f'PMF: Poisson(λ={lambda_value:.1f})')
|
655 |
+
|
656 |
+
# vertical line for mean
|
657 |
+
ax.axvline(x=lambda_value, color='red', linestyle='--', linewidth=2,
|
658 |
+
label=f'Mean = {lambda_value:.1f}')
|
659 |
+
|
660 |
+
# labels and title
|
661 |
+
ax.set_xlabel('Number of Events')
|
662 |
+
ax.set_ylabel('Probability')
|
663 |
+
ax.set_title(f'Poisson Distribution Over {time_period} Units (Rate = {rate}/unit)')
|
664 |
+
|
665 |
+
# better visualization if lambda is large
|
666 |
+
if lambda_value > 10:
|
667 |
+
ax.set_xlim(lambda_value - 4*np.sqrt(lambda_value),
|
668 |
+
lambda_value + 4*np.sqrt(lambda_value))
|
669 |
+
|
670 |
+
ax.legend()
|
671 |
+
ax.grid(alpha=0.3)
|
672 |
+
|
673 |
+
plt.tight_layout()
|
674 |
+
|
675 |
+
# Create relevant info markdown
|
676 |
+
info_text = f"""
|
677 |
+
When the rate is **{rate}** events per unit time and we observe for **{time_period}** units:
|
678 |
+
|
679 |
+
- The expected number of events is **{lambda_value:.1f}**
|
680 |
+
- The variance is also **{lambda_value:.1f}**
|
681 |
+
- The standard deviation is **{np.sqrt(lambda_value):.2f}**
|
682 |
+
- P(X=0) = {stats.poisson.pmf(0, lambda_value):.4f} (probability of no events)
|
683 |
+
- P(X≥10) = {1 - stats.poisson.cdf(9, lambda_value):.4f} (probability of 10 or more events)
|
684 |
+
"""
|
685 |
+
|
686 |
+
return plt.gca(), info_text
|
687 |
+
|
688 |
+
# parameters from sliders
|
689 |
+
_rate = rate_slider.value
|
690 |
+
_time = time_slider.value
|
691 |
+
|
692 |
+
# store
|
693 |
+
_plot, _info_text = create_time_scaling_plot(_rate, _time)
|
694 |
+
|
695 |
+
# Display info as markdown
|
696 |
+
info = mo.md(_info_text)
|
697 |
+
|
698 |
+
mo.vstack([_plot, info], justify="center")
|
699 |
+
return create_time_scaling_plot, info
|
700 |
+
|
701 |
+
|
702 |
+
@app.cell(hide_code=True)
|
703 |
+
def _(mo):
|
704 |
+
mo.md(
|
705 |
+
r"""
|
706 |
+
## 🤔 Test Your Understanding
|
707 |
+
Pick which of these statements about Poisson distributions you think are correct:
|
708 |
+
|
709 |
+
/// details | The variance of a Poisson distribution is always equal to its mean
|
710 |
+
✅ Correct! For a Poisson distribution with parameter $\lambda$, both the mean and variance equal $\lambda$.
|
711 |
+
///
|
712 |
+
|
713 |
+
/// details | The Poisson distribution can be used to model the number of successes in a fixed number of trials
|
714 |
+
❌ Incorrect! That's the binomial distribution. The Poisson distribution models the number of events in a fixed interval of time or space, not a fixed number of trials.
|
715 |
+
///
|
716 |
+
|
717 |
+
/// details | If $X \sim \text{Poisson}(\lambda_1)$ and $Y \sim \text{Poisson}(\lambda_2)$ are independent, then $X + Y \sim \text{Poisson}(\lambda_1 + \lambda_2)$
|
718 |
+
✅ Correct! The sum of independent Poisson random variables is also a Poisson random variable with parameter equal to the sum of the individual parameters.
|
719 |
+
///
|
720 |
+
|
721 |
+
/// details | As $\lambda$ increases, the Poisson distribution approaches a normal distribution
|
722 |
+
✅ Correct! For large values of $\lambda$ (generally $\lambda > 10$), the Poisson distribution is approximately normal with mean $\lambda$ and variance $\lambda$.
|
723 |
+
///
|
724 |
+
|
725 |
+
/// details | The probability of zero events in a Poisson process is always less than the probability of one event
|
726 |
+
❌ Incorrect! For $\lambda < 1$, the probability of zero events ($e^{-\lambda}$) is actually greater than the probability of one event ($\lambda e^{-\lambda}$).
|
727 |
+
///
|
728 |
+
|
729 |
+
/// details | The Poisson distribution has a single parameter $\lambda$, which always equals the average number of events per time period
|
730 |
+
✅ Correct! The parameter $\lambda$ represents the average rate of events, and it uniquely defines the distribution.
|
731 |
+
///
|
732 |
+
"""
|
733 |
+
)
|
734 |
+
return
|
735 |
+
|
736 |
+
|
737 |
+
@app.cell(hide_code=True)
|
738 |
+
def _(mo):
|
739 |
+
mo.md(
|
740 |
+
r"""
|
741 |
+
## Summary
|
742 |
+
|
743 |
+
The Poisson distribution is one of those incredibly useful tools that shows up all over the place. I've always found it fascinating how such a simple formula can model so many real-world phenomena - from website traffic to radioactive decay.
|
744 |
+
|
745 |
+
What makes the Poisson really cool is that it emerges naturally as we try to model rare events occurring over a continuous interval. Remember that visualization where we kept dividing time into smaller and smaller chunks? As we showed, when you take a binomial distribution and let the number of trials approach infinity while keeping the expected value constant, you end up with the elegant Poisson formula.
|
746 |
+
|
747 |
+
The key things to remember about the Poisson distribution:
|
748 |
+
|
749 |
+
- It models the number of events occurring in a fixed interval of time or space, assuming events happen at a constant average rate and independently of each other
|
750 |
+
|
751 |
+
- Its PMF is given by the elegantly simple formula $P(X=k) = \frac{\lambda^k e^{-\lambda}}{k!}$
|
752 |
+
|
753 |
+
- Both the mean and variance equal the parameter $\lambda$, which represents the average number of events per interval
|
754 |
+
|
755 |
+
- It's related to the binomial distribution as a limiting case when $n \to \infty$, $p \to 0$, and $np = \lambda$ remains constant
|
756 |
+
|
757 |
+
- The rate parameter scales linearly with the length of the interval - if events occur at rate $\lambda$ per unit time, then over $t$ units, the parameter becomes $\lambda t$
|
758 |
+
|
759 |
+
From modeling website traffic and customer arrivals to defects in manufacturing and radioactive decay, the Poisson distribution provides a powerful and mathematically elegant way to understand random occurrences in our world.
|
760 |
+
"""
|
761 |
+
)
|
762 |
+
return
|
763 |
+
|
764 |
+
|
765 |
+
@app.cell(hide_code=True)
|
766 |
+
def _(mo):
|
767 |
+
mo.md(r"""Appendix code (helper functions, variables, etc.):""")
|
768 |
+
return
|
769 |
+
|
770 |
+
|
771 |
+
@app.cell
|
772 |
+
def _():
|
773 |
+
import marimo as mo
|
774 |
+
return (mo,)
|
775 |
+
|
776 |
+
|
777 |
+
@app.cell(hide_code=True)
|
778 |
+
def _():
|
779 |
+
import numpy as np
|
780 |
+
import matplotlib.pyplot as plt
|
781 |
+
import scipy.stats as stats
|
782 |
+
import pandas as pd
|
783 |
+
import altair as alt
|
784 |
+
from wigglystuff import TangleSlider
|
785 |
+
return TangleSlider, alt, np, pd, plt, stats
|
786 |
+
|
787 |
+
|
788 |
+
@app.cell(hide_code=True)
|
789 |
+
def _():
|
790 |
+
import io
|
791 |
+
import base64
|
792 |
+
from matplotlib.figure import Figure
|
793 |
+
|
794 |
+
# Helper function to convert mpl figure to an image format mo.image can hopefully handle
|
795 |
+
def fig_to_image(fig):
|
796 |
+
buf = io.BytesIO()
|
797 |
+
fig.savefig(buf, format='png')
|
798 |
+
buf.seek(0)
|
799 |
+
data = f"data:image/png;base64,{base64.b64encode(buf.read()).decode('utf-8')}"
|
800 |
+
return data
|
801 |
+
return Figure, base64, fig_to_image, io
|
802 |
+
|
803 |
+
|
804 |
+
if __name__ == "__main__":
|
805 |
+
app.run()
|
probability/16_continuous_distribution.py
ADDED
@@ -0,0 +1,979 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.11"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "altair==5.5.0",
|
6 |
+
# "matplotlib==3.10.1",
|
7 |
+
# "numpy==2.2.4",
|
8 |
+
# "scipy==1.15.2",
|
9 |
+
# "sympy==1.13.3",
|
10 |
+
# "wigglystuff==0.1.10",
|
11 |
+
# "polars==1.26.0",
|
12 |
+
# ]
|
13 |
+
# ///
|
14 |
+
|
15 |
+
import marimo
|
16 |
+
|
17 |
+
__generated_with = "0.11.26"
|
18 |
+
app = marimo.App(width="medium")
|
19 |
+
|
20 |
+
|
21 |
+
@app.cell(hide_code=True)
|
22 |
+
def _(mo):
|
23 |
+
mo.md(
|
24 |
+
r"""
|
25 |
+
# Continuous Distributions
|
26 |
+
|
27 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/continuous/), by Stanford professor Chris Piech._
|
28 |
+
|
29 |
+
So far, all the random variables we've explored have been discrete, taking on only specific values (usually integers). Now we'll move into the world of **continuous random variables**, which can take on any real number value. Continuous random variables are used to model measurements with arbitrary precision like height, weight, time, and many natural phenomena.
|
30 |
+
"""
|
31 |
+
)
|
32 |
+
return
|
33 |
+
|
34 |
+
|
35 |
+
@app.cell(hide_code=True)
|
36 |
+
def _(mo):
|
37 |
+
mo.md(
|
38 |
+
r"""
|
39 |
+
## From Discrete to Continuous
|
40 |
+
|
41 |
+
To make the transition from discrete to continuous random variables, let's start with a thought experiment:
|
42 |
+
|
43 |
+
> Imagine you're running to catch a bus. You know you'll arrive at 2:15pm, but you don't know exactly when the bus will arrive. You want to model the bus arrival time (in minutes past 2pm) as a random variable $T$ so you can calculate the probability that you'll wait more than five minutes: $P(15 < T < 20)$.
|
44 |
+
|
45 |
+
This immediately highlights a key difference from discrete distributions. For discrete distributions, we described the probability that a random variable takes on exact values. But this doesn't make sense for continuous values like time.
|
46 |
+
|
47 |
+
For example:
|
48 |
+
|
49 |
+
- What's the probability the bus arrives at exactly 2:17pm and 12.12333911102389234 seconds?
|
50 |
+
- What's the probability of a child being born weighing exactly 3.523112342234 kilograms?
|
51 |
+
|
52 |
+
These questions don't have meaningful answers because real-world measurements can have infinite precision. The probability of a continuous random variable taking on any specific exact value is actually zero!
|
53 |
+
|
54 |
+
### Visualizing the Transition
|
55 |
+
|
56 |
+
Let's visualize this transition from discrete to continuous:
|
57 |
+
"""
|
58 |
+
)
|
59 |
+
return
|
60 |
+
|
61 |
+
|
62 |
+
@app.cell(hide_code=True)
|
63 |
+
def _(fig_to_image, mo, np, plt):
|
64 |
+
def create_discretization_plot():
|
65 |
+
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
|
66 |
+
|
67 |
+
# values from 0 to 30 minutes
|
68 |
+
x = np.linspace(0, 30, 1000)
|
69 |
+
|
70 |
+
# Triangular distribution peaked at 15 minutes)
|
71 |
+
y = np.where(x <= 15, x/15, (30-x)/15)
|
72 |
+
y = y / np.trapezoid(y, x) # Normalize
|
73 |
+
|
74 |
+
# 5-minute chunks (first plot)
|
75 |
+
bins = np.arange(0, 31, 5)
|
76 |
+
hist, _ = np.histogram(x, bins=bins, weights=y)
|
77 |
+
width = bins[1] - bins[0]
|
78 |
+
axs[0].bar(bins[:-1], hist * width, width=width, alpha=0.7,
|
79 |
+
color='royalblue', edgecolor='black')
|
80 |
+
axs[0].set_xlim(0, 30)
|
81 |
+
axs[0].set_title('5-Minute Intervals')
|
82 |
+
axs[0].set_xlabel('Minutes past 2pm')
|
83 |
+
axs[0].set_ylabel('Probability')
|
84 |
+
|
85 |
+
# 15-20 minute range more prominent
|
86 |
+
axs[0].bar([15], hist[3] * width, width=width, alpha=0.7,
|
87 |
+
color='darkorange', edgecolor='black')
|
88 |
+
|
89 |
+
# 2.5-minute chunks (second plot)
|
90 |
+
bins = np.arange(0, 31, 2.5)
|
91 |
+
hist, _ = np.histogram(x, bins=bins, weights=y)
|
92 |
+
width = bins[1] - bins[0]
|
93 |
+
axs[1].bar(bins[:-1], hist * width, width=width, alpha=0.7,
|
94 |
+
color='royalblue', edgecolor='black')
|
95 |
+
axs[1].set_xlim(0, 30)
|
96 |
+
axs[1].set_title('2.5-Minute Intervals')
|
97 |
+
axs[1].set_xlabel('Minutes past 2pm')
|
98 |
+
|
99 |
+
# Make 15-20 minute range more prominent
|
100 |
+
highlight_indices = [6, 7]
|
101 |
+
for idx in highlight_indices:
|
102 |
+
axs[1].bar([bins[idx]], hist[idx] * width, width=width, alpha=0.7,
|
103 |
+
color='darkorange', edgecolor='black')
|
104 |
+
|
105 |
+
# Continuous distribution (third plot)
|
106 |
+
axs[2].plot(x, y, 'royalblue', linewidth=2)
|
107 |
+
axs[2].set_xlim(0, 30)
|
108 |
+
axs[2].set_title('Continuous Distribution')
|
109 |
+
axs[2].set_xlabel('Minutes past 2pm')
|
110 |
+
axs[2].set_ylabel('Probability Density')
|
111 |
+
|
112 |
+
# Highlight the AUC between 15 and 20
|
113 |
+
mask = (x >= 15) & (x <= 20)
|
114 |
+
axs[2].fill_between(x[mask], y[mask], color='darkorange', alpha=0.7)
|
115 |
+
|
116 |
+
# Mark 15-20 minute interval
|
117 |
+
for ax in axs:
|
118 |
+
ax.axvline(x=15, color='red', linestyle='--', alpha=0.5)
|
119 |
+
ax.axvline(x=20, color='red', linestyle='--', alpha=0.5)
|
120 |
+
ax.set_xticks([0, 5, 10, 15, 20, 25, 30])
|
121 |
+
ax.grid(alpha=0.3)
|
122 |
+
|
123 |
+
plt.tight_layout()
|
124 |
+
plt.gca()
|
125 |
+
return fig
|
126 |
+
|
127 |
+
# Plot creation & conversion
|
128 |
+
_fig = create_discretization_plot()
|
129 |
+
_img = mo.image(fig_to_image(_fig), width="100%")
|
130 |
+
|
131 |
+
_explanation = mo.md(
|
132 |
+
r"""
|
133 |
+
The figure above illustrates our transition from discrete to continuous thinking:
|
134 |
+
|
135 |
+
- **Left**: Time divided into 5-minute chunks, where the probability of the bus arriving between 15-20 minutes (highlighted in orange) is a single value.
|
136 |
+
- **Center**: Time divided into finer 2.5-minute chunks, where the 15-20 minute range consists of two chunks.
|
137 |
+
- **Right**: In the limit, we get a continuous probability density function where the probability is the area under the curve between 15 and 20 minutes.
|
138 |
+
|
139 |
+
As we make our chunks smaller and smaller, we eventually arrive at a smooth function that gives us the probability density at each point.
|
140 |
+
"""
|
141 |
+
)
|
142 |
+
|
143 |
+
mo.vstack([_img, _explanation])
|
144 |
+
return (create_discretization_plot,)
|
145 |
+
|
146 |
+
|
147 |
+
@app.cell(hide_code=True)
|
148 |
+
def _(mo):
|
149 |
+
mo.md(
|
150 |
+
r"""
|
151 |
+
## Probability Density Functions
|
152 |
+
|
153 |
+
In the world of discrete random variables, we used **Probability Mass Functions (PMFs)** to describe the probability of a random variable taking on specific values. In the continuous world, we need a different approach.
|
154 |
+
|
155 |
+
For continuous random variables, we use a **Probability Density Function (PDF)** which defines the relative likelihood that a random variable takes on a particular value. We traditionally denote the PDF with the symbol $f$ and write it as:
|
156 |
+
|
157 |
+
$$f(X=x) \quad \text{or simply} \quad f(x)$$
|
158 |
+
|
159 |
+
Where the lowercase $x$ implies that we're talking about the relative likelihood of a continuous random variable which is the uppercase $X$.
|
160 |
+
|
161 |
+
### Key Properties of PDFs
|
162 |
+
|
163 |
+
A **Probability Density Function (PDF)** $f(x)$ for a continuous random variable $X$ has these key properties:
|
164 |
+
|
165 |
+
1. The probability that $X$ takes a value in the interval $[a, b]$ is:
|
166 |
+
|
167 |
+
$$P(a \leq X \leq b) = \int_a^b f(x) \, dx$$
|
168 |
+
|
169 |
+
2. The PDF must be non-negative everywhere:
|
170 |
+
|
171 |
+
$$f(x) \geq 0 \text{ for all } x$$
|
172 |
+
|
173 |
+
3. The total probability must sum to 1:
|
174 |
+
|
175 |
+
$$\int_{-\infty}^{\infty} f(x) \, dx = 1$$
|
176 |
+
|
177 |
+
4. The probability that $X$ takes any specific exact value is 0:
|
178 |
+
|
179 |
+
$$P(X = a) = \int_a^a f(x) \, dx = 0$$
|
180 |
+
|
181 |
+
This last property highlights a key difference from discrete distributions: the probability of a continuous random variable taking on an exact value is always 0. Probabilities only make sense when talking about ranges of values.
|
182 |
+
|
183 |
+
### Caution: Density ≠ Probability
|
184 |
+
|
185 |
+
A common misconception is to think of $f(x)$ as a probability. It is instead a **probability density**, representing probability per unit of $x$. The values of $f(x)$ can actually exceed 1, as long as the total area under the curve equals 1.
|
186 |
+
|
187 |
+
The interpretation of $f(x)$ is only meaningful when:
|
188 |
+
|
189 |
+
1. We integrate over a range to get a probability, or
|
190 |
+
2. We compare densities at different points to determine relative likelihoods.
|
191 |
+
"""
|
192 |
+
)
|
193 |
+
return
|
194 |
+
|
195 |
+
|
196 |
+
@app.cell(hide_code=True)
|
197 |
+
def _(TangleSlider, mo):
|
198 |
+
# Create sliders for a and b
|
199 |
+
a_slider = mo.ui.anywidget(TangleSlider(
|
200 |
+
amount=1,
|
201 |
+
min_value=0,
|
202 |
+
max_value=5,
|
203 |
+
step=0.1,
|
204 |
+
digits=1
|
205 |
+
))
|
206 |
+
|
207 |
+
b_slider = mo.ui.anywidget(TangleSlider(
|
208 |
+
amount=3,
|
209 |
+
min_value=0,
|
210 |
+
max_value=5,
|
211 |
+
step=0.1,
|
212 |
+
digits=1
|
213 |
+
))
|
214 |
+
|
215 |
+
# Distribution selector
|
216 |
+
distribution_radio = mo.ui.radio(
|
217 |
+
options=["uniform", "triangular", "exponential"],
|
218 |
+
value="uniform",
|
219 |
+
label="Distribution Type"
|
220 |
+
)
|
221 |
+
|
222 |
+
# Controls layout
|
223 |
+
_controls = mo.vstack([
|
224 |
+
mo.md("### Visualizing Probability as Area Under the PDF Curve"),
|
225 |
+
mo.md("Adjust sliders to change the interval $[a, b]$ and see how the probability changes:"),
|
226 |
+
mo.hstack([
|
227 |
+
mo.md("Lower bound (a):"),
|
228 |
+
a_slider,
|
229 |
+
mo.md("Upper bound (b):"),
|
230 |
+
b_slider
|
231 |
+
], justify="start"),
|
232 |
+
distribution_radio
|
233 |
+
])
|
234 |
+
_controls
|
235 |
+
return a_slider, b_slider, distribution_radio
|
236 |
+
|
237 |
+
|
238 |
+
@app.cell(hide_code=True)
|
239 |
+
def _(
|
240 |
+
a_slider,
|
241 |
+
b_slider,
|
242 |
+
create_pdf_visualization,
|
243 |
+
distribution_radio,
|
244 |
+
fig_to_image,
|
245 |
+
mo,
|
246 |
+
):
|
247 |
+
a = a_slider.amount
|
248 |
+
b = b_slider.amount
|
249 |
+
distribution = distribution_radio.value
|
250 |
+
|
251 |
+
# Ensure a < b
|
252 |
+
if a > b:
|
253 |
+
a, b = b, a
|
254 |
+
|
255 |
+
# visualization
|
256 |
+
_fig, _probability = create_pdf_visualization(a, b, distribution)
|
257 |
+
|
258 |
+
# Display visualization
|
259 |
+
_img = mo.image(fig_to_image(_fig), width="100%")
|
260 |
+
|
261 |
+
# Add appropriate explanation
|
262 |
+
if distribution == "uniform":
|
263 |
+
_explanation = mo.md(
|
264 |
+
f"""
|
265 |
+
In the **uniform distribution**, all values between 0 and 5 are equally likely.
|
266 |
+
The probability density is constant at 0.2 (which is 1/5, ensuring the total area is 1).
|
267 |
+
For a uniform distribution, the probability that $X$ is in the interval $[{a:.1f}, {b:.1f}]$
|
268 |
+
is simply proportional to the width of the interval: $P({a:.1f} \leq X \leq {b:.1f}) = {_probability:.4f}$
|
269 |
+
Note that while the PDF has a constant value of 0.2, this is not a probability but a density!
|
270 |
+
"""
|
271 |
+
)
|
272 |
+
elif distribution == "triangular":
|
273 |
+
_explanation = mo.md(
|
274 |
+
f"""
|
275 |
+
In this **triangular distribution**, the probability density increases linearly from 0 to 2.5,
|
276 |
+
then decreases linearly from 2.5 to 5.
|
277 |
+
The distribution's peak is at x = 2.5, where the value is highest.
|
278 |
+
The orange shaded area representing $P({a:.1f} \leq X \leq {b:.1f}) = {_probability:.4f}$
|
279 |
+
is calculated by integrating the PDF over the interval.
|
280 |
+
"""
|
281 |
+
)
|
282 |
+
else:
|
283 |
+
_explanation = mo.md(
|
284 |
+
f"""
|
285 |
+
The **exponential distribution** (with λ = 0.5) models the time between events in a Poisson process.
|
286 |
+
Unlike the uniform and triangular distributions, the exponential distribution has infinite support
|
287 |
+
(extends from 0 to infinity). The probability density decreases exponentially as x increases.
|
288 |
+
The orange shaded area representing $P({a:.1f} \leq X \leq {b:.1f}) = {_probability:.4f}$
|
289 |
+
is calculated by integrating $f(x) = 0.5e^{{-0.5x}}$ over the interval.
|
290 |
+
"""
|
291 |
+
)
|
292 |
+
mo.vstack([_img, _explanation])
|
293 |
+
return a, b, distribution
|
294 |
+
|
295 |
+
|
296 |
+
@app.cell(hide_code=True)
|
297 |
+
def _(mo):
|
298 |
+
mo.md(
|
299 |
+
r"""
|
300 |
+
## Cumulative Distribution Function
|
301 |
+
|
302 |
+
Since working with PDFs requires solving integrals to find probabilities, we often use the **Cumulative Distribution Function (CDF)** as a more convenient tool.
|
303 |
+
|
304 |
+
The CDF $F(x)$ for a continuous random variable $X$ is defined as:
|
305 |
+
|
306 |
+
$$F(x) = P(X \leq x) = \int_{-\infty}^{x} f(t)\,dt$$
|
307 |
+
|
308 |
+
where $f(t)$ is the PDF of $X$.
|
309 |
+
|
310 |
+
### Properties of CDFs
|
311 |
+
|
312 |
+
A CDF $F(x)$ has these key properties:
|
313 |
+
|
314 |
+
1. $F(x)$ is always non-decreasing: if $a < b$, then $F(a) \leq F(b)$
|
315 |
+
2. $\lim_{x \to -\infty} F(x) = 0$ and $\lim_{x \to \infty} F(x) = 1$
|
316 |
+
3. $F(x)$ is right-continuous: $\lim_{h \to 0^+} F(x+h) = F(x)$
|
317 |
+
|
318 |
+
### Using the CDF to Calculate Probabilities
|
319 |
+
|
320 |
+
The CDF is extremely useful because it allows us to calculate various probabilities without having to perform integrals each time:
|
321 |
+
|
322 |
+
| Probability Query | Solution | Explanation |
|
323 |
+
|-------------------|----------|-------------|
|
324 |
+
| $P(X < a)$ | $F(a)$ | Definition of the CDF |
|
325 |
+
| $P(X \leq a)$ | $F(a)$ | For continuous distributions, $P(X = a) = 0$ |
|
326 |
+
| $P(X > a)$ | $1 - F(a)$ | Since $P(X \leq a) + P(X > a) = 1$ |
|
327 |
+
| $P(a < X < b)$ | $F(b) - F(a)$ | Since $F(a) + P(a < X < b) = F(b)$ |
|
328 |
+
| $P(a \leq X \leq b)$ | $F(b) - F(a)$ | Since $P(X = a) = P(X = b) = 0$ |
|
329 |
+
|
330 |
+
For discrete random variables, the CDF is also defined but it's less commonly used:
|
331 |
+
|
332 |
+
$$F_X(a) = \sum_{i \leq a} P(X = i)$$
|
333 |
+
|
334 |
+
The CDF for discrete distributions is a step function, increasing at each point in the support of the random variable.
|
335 |
+
"""
|
336 |
+
)
|
337 |
+
return
|
338 |
+
|
339 |
+
|
340 |
+
@app.cell(hide_code=True)
|
341 |
+
def _(fig_to_image, mo, np, plt):
|
342 |
+
def create_pdf_cdf_comparison():
|
343 |
+
fig, axs = plt.subplots(3, 2, figsize=(12, 10))
|
344 |
+
|
345 |
+
# x-values
|
346 |
+
x = np.linspace(-1, 6, 1000)
|
347 |
+
|
348 |
+
# 1. Uniform Distribution
|
349 |
+
# PDF
|
350 |
+
pdf_uniform = np.where((x >= 0) & (x <= 5), 0.2, 0)
|
351 |
+
axs[0, 0].plot(x, pdf_uniform, 'b-', linewidth=2)
|
352 |
+
axs[0, 0].set_title('Uniform PDF')
|
353 |
+
axs[0, 0].set_ylabel('Density')
|
354 |
+
axs[0, 0].grid(alpha=0.3)
|
355 |
+
|
356 |
+
# CDF
|
357 |
+
cdf_uniform = np.zeros_like(x)
|
358 |
+
for i, val in enumerate(x):
|
359 |
+
if val < 0:
|
360 |
+
cdf_uniform[i] = 0
|
361 |
+
elif val > 5:
|
362 |
+
cdf_uniform[i] = 1
|
363 |
+
else:
|
364 |
+
cdf_uniform[i] = val / 5
|
365 |
+
|
366 |
+
axs[0, 1].plot(x, cdf_uniform, 'r-', linewidth=2)
|
367 |
+
axs[0, 1].set_title('Uniform CDF')
|
368 |
+
axs[0, 1].set_ylabel('Probability')
|
369 |
+
axs[0, 1].grid(alpha=0.3)
|
370 |
+
|
371 |
+
# 2. Triangular Distribution
|
372 |
+
# PDF
|
373 |
+
pdf_triangular = np.where(x <= 2.5, x/6.25, (5-x)/6.25)
|
374 |
+
pdf_triangular = np.where((x < 0) | (x > 5), 0, pdf_triangular)
|
375 |
+
|
376 |
+
axs[1, 0].plot(x, pdf_triangular, 'b-', linewidth=2)
|
377 |
+
axs[1, 0].set_title('Triangular PDF')
|
378 |
+
axs[1, 0].set_ylabel('Density')
|
379 |
+
axs[1, 0].grid(alpha=0.3)
|
380 |
+
|
381 |
+
# CDF
|
382 |
+
cdf_triangular = np.zeros_like(x)
|
383 |
+
for i, val in enumerate(x):
|
384 |
+
if val <= 0:
|
385 |
+
cdf_triangular[i] = 0
|
386 |
+
elif val >= 5:
|
387 |
+
cdf_triangular[i] = 1
|
388 |
+
else:
|
389 |
+
# For x ≤ 2.5: CDF = x²/(2 *6 .25)
|
390 |
+
# For x > 2.5: CDF = 1 - (5 - x)²/(2 * 6.25)
|
391 |
+
if val <= 2.5:
|
392 |
+
cdf_triangular[i] = (val**2) / (2 * 6.25)
|
393 |
+
else:
|
394 |
+
cdf_triangular[i] = 1 - ((5 - val)**2) / (2 * 6.25)
|
395 |
+
|
396 |
+
axs[1, 1].plot(x, cdf_triangular, 'r-', linewidth=2)
|
397 |
+
axs[1, 1].set_title('Triangular CDF')
|
398 |
+
axs[1, 1].set_ylabel('Probability')
|
399 |
+
axs[1, 1].grid(alpha=0.3)
|
400 |
+
|
401 |
+
# 3. Exponential Distribution
|
402 |
+
# PDF
|
403 |
+
lambda_param = 0.5
|
404 |
+
pdf_exponential = np.where(x >= 0, lambda_param * np.exp(-lambda_param * x), 0)
|
405 |
+
|
406 |
+
axs[2, 0].plot(x, pdf_exponential, 'b-', linewidth=2)
|
407 |
+
axs[2, 0].set_title('Exponential PDF (λ=0.5)')
|
408 |
+
axs[2, 0].set_xlabel('x')
|
409 |
+
axs[2, 0].set_ylabel('Density')
|
410 |
+
axs[2, 0].grid(alpha=0.3)
|
411 |
+
|
412 |
+
# CDF
|
413 |
+
cdf_exponential = np.where(x < 0, 0, 1 - np.exp(-lambda_param * x))
|
414 |
+
|
415 |
+
axs[2, 1].plot(x, cdf_exponential, 'r-', linewidth=2)
|
416 |
+
axs[2, 1].set_title('Exponential CDF (λ=0.5)')
|
417 |
+
axs[2, 1].set_xlabel('x')
|
418 |
+
axs[2, 1].set_ylabel('Probability')
|
419 |
+
axs[2, 1].grid(alpha=0.3)
|
420 |
+
|
421 |
+
# Common x-limits
|
422 |
+
for ax in axs.flatten():
|
423 |
+
ax.set_xlim(-0.5, 5.5)
|
424 |
+
if ax in axs[:, 0]: # PDF plots
|
425 |
+
ax.set_ylim(-0.05, max(0.5, max(pdf_triangular)*1.1))
|
426 |
+
else: # CDF plots
|
427 |
+
ax.set_ylim(-0.05, 1.05)
|
428 |
+
|
429 |
+
plt.tight_layout()
|
430 |
+
plt.gca()
|
431 |
+
return fig
|
432 |
+
|
433 |
+
# Create visualization
|
434 |
+
_fig = create_pdf_cdf_comparison()
|
435 |
+
_img = mo.image(fig_to_image(_fig), width="100%")
|
436 |
+
|
437 |
+
_explanation = mo.md(
|
438 |
+
r"""
|
439 |
+
The figure above compares the Probability Density Functions (PDFs) on the left with their corresponding Cumulative Distribution Functions (CDFs) on the right for three common distributions:
|
440 |
+
|
441 |
+
1. **Uniform Distribution**:
|
442 |
+
|
443 |
+
- PDF: Constant value (0.2) across the support range [0, 5]
|
444 |
+
- CDF: Linear increase from 0 to 1 across the support range
|
445 |
+
|
446 |
+
2. **Triangular Distribution**:
|
447 |
+
|
448 |
+
- PDF: Linearly increases then decreases, forming a triangle shape
|
449 |
+
- CDF: Increases quadratically up to the peak, then approaches 1 quadratically
|
450 |
+
|
451 |
+
3. **Exponential Distribution**:
|
452 |
+
|
453 |
+
- PDF: Starts at λ=0.5 and decreases exponentially
|
454 |
+
- CDF: Starts at 0 and approaches 1 exponentially (never quite reaching 1)
|
455 |
+
|
456 |
+
/// NOTE
|
457 |
+
The common properties of all CDFs:
|
458 |
+
|
459 |
+
- They are non-decreasing functions
|
460 |
+
- They start at 0 (for x = -∞) and approach or reach 1 (for x = ∞)
|
461 |
+
- The slope of the CDF at any point equals the PDF value at that point
|
462 |
+
"""
|
463 |
+
)
|
464 |
+
|
465 |
+
mo.vstack([_img, _explanation])
|
466 |
+
return (create_pdf_cdf_comparison,)
|
467 |
+
|
468 |
+
|
469 |
+
@app.cell(hide_code=True)
|
470 |
+
def _(mo):
|
471 |
+
mo.md(
|
472 |
+
r"""
|
473 |
+
## Solving for Constants in PDFs
|
474 |
+
|
475 |
+
Many PDFs contain a constant that needs to be determined to ensure the total probability equals 1. Let's work through an example to understand how to solve for these constants.
|
476 |
+
|
477 |
+
### Example: Finding the Constant $C$
|
478 |
+
|
479 |
+
Let $X$ be a continuous random variable with PDF:
|
480 |
+
|
481 |
+
$$f(x) = \begin{cases}
|
482 |
+
C(4x - 2x^2) & \text{when } 0 < x < 2 \\
|
483 |
+
0 & \text{otherwise}
|
484 |
+
\end{cases}$$
|
485 |
+
|
486 |
+
In this function, $C$ is a constant we need to determine. Since we know the PDF must integrate to 1:
|
487 |
+
|
488 |
+
\begin{align}
|
489 |
+
&\int_0^2 C(4x - 2x^2) \, dx = 1 \\
|
490 |
+
&C\left(2x^2 - \frac{2x^3}{3}\right)\bigg|_0^2 = 1 \\
|
491 |
+
&C\left[\left(8 - \frac{16}{3}\right) - 0 \right] = 1 \\
|
492 |
+
&C\left(\frac{24 - 16}{3}\right) = 1 \\
|
493 |
+
&C\left(\frac{8}{3}\right) = 1 \\
|
494 |
+
&C = \frac{3}{8}
|
495 |
+
\end{align}
|
496 |
+
|
497 |
+
Now that we know $C = \frac{3}{8}$, we can compute probabilities. For example, what is $P(X > 1)$?
|
498 |
+
|
499 |
+
\begin{align}
|
500 |
+
P(X > 1)
|
501 |
+
&= \int_1^{\infty}f(x) \, dx \\
|
502 |
+
&= \int_1^2 \frac{3}{8}(4x - 2x^2) \, dx \\
|
503 |
+
&= \frac{3}{8}\left(2x^2 - \frac{2x^3}{3}\right)\bigg|_1^2 \\
|
504 |
+
&= \frac{3}{8}\left[\left(8 - \frac{16}{3}\right) - \left(2 - \frac{2}{3}\right)\right] \\
|
505 |
+
&= \frac{3}{8}\left[\left(8 - \frac{16}{3}\right) - \left(\frac{6 - 2}{3}\right)\right] \\
|
506 |
+
&= \frac{3}{8}\left[\left(\frac{24 - 16}{3}\right) - \left(\frac{4}{3}\right)\right] \\
|
507 |
+
&= \frac{3}{8}\left[\left(\frac{8}{3}\right) - \left(\frac{4}{3}\right)\right] \\
|
508 |
+
&= \frac{3}{8} \cdot \frac{4}{3} \\
|
509 |
+
&= \frac{1}{2}
|
510 |
+
\end{align}
|
511 |
+
|
512 |
+
Let's visualize this distribution and verify our results:
|
513 |
+
"""
|
514 |
+
)
|
515 |
+
return
|
516 |
+
|
517 |
+
|
518 |
+
@app.cell(hide_code=True)
|
519 |
+
def _(
|
520 |
+
create_example_pdf_visualization,
|
521 |
+
fig_to_image,
|
522 |
+
mo,
|
523 |
+
symbolic_calculation,
|
524 |
+
):
|
525 |
+
# Create visualization
|
526 |
+
_fig = create_example_pdf_visualization()
|
527 |
+
_img = mo.image(fig_to_image(_fig), width="100%")
|
528 |
+
|
529 |
+
# Symbolic calculation
|
530 |
+
_sympy_verification = mo.md(symbolic_calculation())
|
531 |
+
|
532 |
+
_explanation = mo.md(
|
533 |
+
r"""
|
534 |
+
The figure above shows:
|
535 |
+
|
536 |
+
1. **Left**: The PDF $f(x) = \frac{3}{8}(4x - 2x^2)$ for $0 < x < 2$, with the area representing P(X > 1) shaded in orange.
|
537 |
+
2. **Right**: The corresponding CDF, showing F(1) = 0.5 and thus P(X > 1) = 1 - F(1) = 0.5.
|
538 |
+
|
539 |
+
Notice how we:
|
540 |
+
|
541 |
+
1. First determined the constant C = 3/8 by ensuring the total area under the PDF equals 1
|
542 |
+
2. Used this value to calculate specific probabilities like P(X > 1)
|
543 |
+
3. Verified our results both graphically and symbolically
|
544 |
+
"""
|
545 |
+
)
|
546 |
+
mo.vstack([_img, _sympy_verification, _explanation])
|
547 |
+
return
|
548 |
+
|
549 |
+
|
550 |
+
@app.cell(hide_code=True)
|
551 |
+
def _(mo):
|
552 |
+
mo.md(
|
553 |
+
r"""
|
554 |
+
## Expectation and Variance of Continuous Random Variables
|
555 |
+
|
556 |
+
Just as with discrete random variables, we can calculate the expectation and variance of continuous random variables. The main difference is that we use integrals instead of sums.
|
557 |
+
|
558 |
+
### Expectation (Mean)
|
559 |
+
|
560 |
+
For a continuous random variable $X$ with PDF $f(x)$, the expectation is:
|
561 |
+
|
562 |
+
$$E[X] = \int_{-\infty}^{\infty} x \cdot f(x) \, dx$$
|
563 |
+
|
564 |
+
More generally, for any function $g(X)$:
|
565 |
+
|
566 |
+
$$E[g(X)] = \int_{-\infty}^{\infty} g(x) \cdot f(x) \, dx$$
|
567 |
+
|
568 |
+
### Variance
|
569 |
+
|
570 |
+
The variance is defined the same way as for discrete random variables:
|
571 |
+
|
572 |
+
$$\text{Var}(X) = E[(X - \mu)^2] = E[X^2] - (E[X])^2$$
|
573 |
+
|
574 |
+
where $\mu = E[X]$ is the mean of $X$.
|
575 |
+
|
576 |
+
To calculate $E[X^2]$, we use:
|
577 |
+
|
578 |
+
$$E[X^2] = \int_{-\infty}^{\infty} x^2 \cdot f(x) \, dx$$
|
579 |
+
|
580 |
+
### Properties
|
581 |
+
|
582 |
+
The following properties hold for both continuous and discrete random variables:
|
583 |
+
|
584 |
+
1. $E[aX + b] = aE[X] + b$ for constants $a$ and $b$
|
585 |
+
2. $\text{Var}(aX + b) = a^2 \text{Var}(X)$ for constants $a$ and $b$
|
586 |
+
|
587 |
+
Let's calculate the expectation and variance for our example PDF:
|
588 |
+
"""
|
589 |
+
)
|
590 |
+
return
|
591 |
+
|
592 |
+
|
593 |
+
@app.cell(hide_code=True)
|
594 |
+
def _(fig_to_image, mo, np, plt, sympy):
|
595 |
+
# Symbolic calculation of expectation and variance
|
596 |
+
def symbolic_stats_calc():
|
597 |
+
x = sympy.symbols('x')
|
598 |
+
C = sympy.Rational(3, 8)
|
599 |
+
|
600 |
+
# Define the PDF
|
601 |
+
pdf_expr = C * (4*x - 2*x**2)
|
602 |
+
|
603 |
+
# Calculate expectation
|
604 |
+
E_X = sympy.integrate(x * pdf_expr, (x, 0, 2))
|
605 |
+
|
606 |
+
# Calculate E[X²]
|
607 |
+
E_X2 = sympy.integrate(x**2 * pdf_expr, (x, 0, 2))
|
608 |
+
|
609 |
+
# Calculate variance
|
610 |
+
Var_X = E_X2 - E_X**2
|
611 |
+
|
612 |
+
# Calculate standard deviation
|
613 |
+
Std_X = sympy.sqrt(Var_X)
|
614 |
+
|
615 |
+
return E_X, E_X2, Var_X, Std_X
|
616 |
+
|
617 |
+
# Get symbolic results
|
618 |
+
E_X, E_X2, Var_X, Std_X = symbolic_stats_calc()
|
619 |
+
|
620 |
+
# Numerical values for plotting
|
621 |
+
E_X_val = float(E_X)
|
622 |
+
Var_X_val = float(Var_X)
|
623 |
+
Std_X_val = float(Std_X)
|
624 |
+
|
625 |
+
def create_expectation_variance_vis():
|
626 |
+
"""Create visualization showing mean and variance for the example PDF."""
|
627 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
628 |
+
|
629 |
+
# x-values
|
630 |
+
x = np.linspace(-0.5, 2.5, 1000)
|
631 |
+
|
632 |
+
# PDF function
|
633 |
+
C = 3/8
|
634 |
+
pdf = np.where((x > 0) & (x < 2), C * (4*x - 2*x**2), 0)
|
635 |
+
|
636 |
+
# Plot the PDF
|
637 |
+
ax.plot(x, pdf, 'b-', linewidth=2, label='PDF')
|
638 |
+
ax.fill_between(x, pdf, where=(x > 0) & (x < 2), alpha=0.3, color='blue')
|
639 |
+
|
640 |
+
# Mark the mean
|
641 |
+
ax.axvline(x=E_X_val, color='r', linestyle='--', linewidth=2,
|
642 |
+
label=f'Mean (E[X] = {E_X_val:.3f})')
|
643 |
+
|
644 |
+
# Mark the standard deviation range
|
645 |
+
ax.axvspan(E_X_val - Std_X_val, E_X_val + Std_X_val, alpha=0.2, color='green',
|
646 |
+
label=f'±1 Std Dev ({Std_X_val:.3f})')
|
647 |
+
|
648 |
+
# Add labels and title
|
649 |
+
ax.set_xlabel('x')
|
650 |
+
ax.set_ylabel('Probability Density')
|
651 |
+
ax.set_title('PDF with Mean and Variance')
|
652 |
+
ax.legend()
|
653 |
+
ax.grid(alpha=0.3)
|
654 |
+
|
655 |
+
# Set x-limits
|
656 |
+
ax.set_xlim(-0.25, 2.25)
|
657 |
+
|
658 |
+
plt.tight_layout()
|
659 |
+
return fig
|
660 |
+
|
661 |
+
# Create the visualization
|
662 |
+
_fig = create_expectation_variance_vis()
|
663 |
+
_img = mo.image(fig_to_image(_fig), width="100%")
|
664 |
+
|
665 |
+
# Detailed calculations for our example
|
666 |
+
_calculations = mo.md(
|
667 |
+
f"""
|
668 |
+
### Calculating Expectation and Variance for Our Example
|
669 |
+
|
670 |
+
Let's calculate the expectation and variance for the PDF:
|
671 |
+
|
672 |
+
$$f(x) = \\begin{{cases}}
|
673 |
+
\\frac{{3}}{{8}}(4x - 2x^2) & \\text{{when }} 0 < x < 2 \\\\
|
674 |
+
0 & \\text{{otherwise}}
|
675 |
+
\\end{{cases}}$$
|
676 |
+
|
677 |
+
#### Expectation Calculation
|
678 |
+
|
679 |
+
$$E[X] = \\int_{{-\\infty}}^{{\\infty}} x \\cdot f(x) \\, dx = \\int_0^2 x \\cdot \\frac{{3}}{{8}}(4x - 2x^2) \\, dx$$
|
680 |
+
|
681 |
+
$$E[X] = \\frac{{3}}{{8}} \\int_0^2 (4x^2 - 2x^3) \\, dx = \\frac{{3}}{{8}} \\left[ \\frac{{4x^3}}{{3}} - \\frac{{2x^4}}{{4}} \\right]_0^2$$
|
682 |
+
|
683 |
+
$$E[X] = \\frac{{3}}{{8}} \\left[ \\frac{{4 \\cdot 2^3}}{{3}} - \\frac{{2 \\cdot 2^4}}{{4}} - 0 \\right] = \\frac{{3}}{{8}} \\left[ \\frac{{32}}{{3}} - 4 \\right]$$
|
684 |
+
|
685 |
+
$$E[X] = \\frac{{3}}{{8}} \\cdot \\frac{{32 - 12}}{{3}} = \\frac{{3}}{{8}} \\cdot \\frac{{20}}{{3}} = \\frac{{20}}{{8}} = {E_X}$$
|
686 |
+
|
687 |
+
#### Variance Calculation
|
688 |
+
|
689 |
+
First, we need $E[X^2]$:
|
690 |
+
|
691 |
+
$$E[X^2] = \\int_{{-\\infty}}^{{\\infty}} x^2 \\cdot f(x) \\, dx = \\int_0^2 x^2 \\cdot \\frac{{3}}{{8}}(4x - 2x^2) \\, dx$$
|
692 |
+
|
693 |
+
$$E[X^2] = \\frac{{3}}{{8}} \\int_0^2 (4x^3 - 2x^4) \\, dx = \\frac{{3}}{{8}} \\left[ \\frac{{4x^4}}{{4}} - \\frac{{2x^5}}{{5}} \\right]_0^2$$
|
694 |
+
|
695 |
+
$$E[X^2] = \\frac{{3}}{{8}} \\left[ 4 - \\frac{{2 \\cdot 32}}{{5}} - 0 \\right] = \\frac{{3}}{{8}} \\left[ 4 - \\frac{{64}}{{5}} \\right]$$
|
696 |
+
|
697 |
+
$$E[X^2] = \\frac{{3}}{{8}} \\cdot \\frac{{20 - 64/5}}{{1}} = {E_X2}$$
|
698 |
+
|
699 |
+
Now we can calculate the variance:
|
700 |
+
|
701 |
+
$$\\text{{Var}}(X) = E[X^2] - (E[X])^2 = {E_X2} - ({E_X})^2 = {Var_X}$$
|
702 |
+
|
703 |
+
Therefore, the standard deviation is $\\sqrt{{\\text{{Var}}(X)}} = {Std_X}$.
|
704 |
+
"""
|
705 |
+
)
|
706 |
+
mo.vstack([_img, _calculations])
|
707 |
+
return (
|
708 |
+
E_X,
|
709 |
+
E_X2,
|
710 |
+
E_X_val,
|
711 |
+
Std_X,
|
712 |
+
Std_X_val,
|
713 |
+
Var_X,
|
714 |
+
Var_X_val,
|
715 |
+
create_expectation_variance_vis,
|
716 |
+
symbolic_stats_calc,
|
717 |
+
)
|
718 |
+
|
719 |
+
|
720 |
+
@app.cell(hide_code=True)
|
721 |
+
def _(mo):
|
722 |
+
mo.md(
|
723 |
+
r"""
|
724 |
+
## 🤔 Test Your Understanding
|
725 |
+
|
726 |
+
Select which of these statements about continuous distributions you think are correct:
|
727 |
+
|
728 |
+
/// details | The PDF of a continuous random variable can have values greater than 1
|
729 |
+
✅ Correct! Since the PDF represents density (not probability), it can exceed 1 as long as the total area under the curve equals 1.
|
730 |
+
///
|
731 |
+
|
732 |
+
/// details | For a continuous distribution, $P(X = a) > 0$ for any value $a$ in the support
|
733 |
+
❌ Incorrect! For continuous random variables, the probability of the random variable taking any specific exact value is always 0. That is, $P(X = a) = 0$ for any value $a$.
|
734 |
+
///
|
735 |
+
|
736 |
+
/// details | The area under a PDF curve between $a$ and $b$ equals the probability $P(a \leq X \leq b)$
|
737 |
+
✅ Correct! The area under the PDF curve over an interval gives the probability that the random variable falls within that interval.
|
738 |
+
///
|
739 |
+
|
740 |
+
/// details | The CDF function $F(x)$ is always equal to $\int_{-\infty}^{x} f(t) \, dt$
|
741 |
+
✅ Correct! The CDF at point $x$ is the integral of the PDF from negative infinity to $x$.
|
742 |
+
///
|
743 |
+
|
744 |
+
/// details | For a continuous random variable, $F(x)$ ranges from 0 to the maximum value in the support of the random variable
|
745 |
+
❌ Incorrect! The CDF $F(x)$ ranges from 0 to 1, representing probabilities. It approaches 1 (not the maximum value in the support) as $x$ approaches infinity.
|
746 |
+
///
|
747 |
+
|
748 |
+
/// details | To calculate the variance of a continuous random variable, we use the formula $\text{Var}(X) = E[X^2] - (E[X])^2$
|
749 |
+
✅ Correct! This formula applies to both discrete and continuous random variables.
|
750 |
+
///
|
751 |
+
"""
|
752 |
+
)
|
753 |
+
return
|
754 |
+
|
755 |
+
|
756 |
+
@app.cell(hide_code=True)
|
757 |
+
def _(mo):
|
758 |
+
mo.md(
|
759 |
+
r"""
|
760 |
+
## Summary
|
761 |
+
|
762 |
+
Moving from discrete to continuous thinking is a big conceptual leap, but it opens up powerful ways to model real-world phenomena.
|
763 |
+
|
764 |
+
In this notebook, we've seen how continuous random variables let us model quantities that can take any real value. Instead of dealing with probabilities at specific points (which are actually zero!), we work with probability density functions (PDFs) and find probabilities by calculating areas under curves.
|
765 |
+
|
766 |
+
Some key points to remember:
|
767 |
+
|
768 |
+
• PDFs give us relative likelihood, not actual probabilities - that's why they can exceed 1
|
769 |
+
• The probability between two points is the area under the PDF curve
|
770 |
+
• CDFs offer a convenient shortcut to find probabilities without integrating
|
771 |
+
• Expectation and variance work similarly to discrete variables, just with integrals instead of sums
|
772 |
+
• Constants in PDFs are determined by ensuring the total probability equals 1
|
773 |
+
|
774 |
+
This foundation will serve you well as we explore specific continuous distributions like normal, exponential, and beta in future notebooks. These distributions are the workhorses of probability theory and statistics, appearing everywhere from quality control to financial modeling.
|
775 |
+
|
776 |
+
One final thought: continuous distributions are beautiful mathematical objects, but remember they're just models. Real-world data is often discrete at some level, but continuous distributions provide elegant approximations that make calculations more tractable.
|
777 |
+
"""
|
778 |
+
)
|
779 |
+
return
|
780 |
+
|
781 |
+
|
782 |
+
@app.cell
|
783 |
+
def _(mo):
|
784 |
+
mo.md(r"""Appendix code (helper functions, variables, etc.):""")
|
785 |
+
return
|
786 |
+
|
787 |
+
|
788 |
+
@app.cell
|
789 |
+
def _():
|
790 |
+
import marimo as mo
|
791 |
+
return (mo,)
|
792 |
+
|
793 |
+
|
794 |
+
@app.cell(hide_code=True)
|
795 |
+
def _():
|
796 |
+
import numpy as np
|
797 |
+
import matplotlib.pyplot as plt
|
798 |
+
import scipy.stats as stats
|
799 |
+
import sympy
|
800 |
+
from scipy import integrate as scipy
|
801 |
+
import polars as pl
|
802 |
+
import altair as alt
|
803 |
+
from wigglystuff import TangleSlider
|
804 |
+
return TangleSlider, alt, np, pl, plt, scipy, stats, sympy
|
805 |
+
|
806 |
+
|
807 |
+
@app.cell(hide_code=True)
|
808 |
+
def _():
|
809 |
+
import io
|
810 |
+
import base64
|
811 |
+
from matplotlib.figure import Figure
|
812 |
+
|
813 |
+
# Helper function to convert mpl figure to an image format mo.image can handle
|
814 |
+
def fig_to_image(fig):
|
815 |
+
buf = io.BytesIO()
|
816 |
+
fig.savefig(buf, format='png')
|
817 |
+
buf.seek(0)
|
818 |
+
data = f"data:image/png;base64,{base64.b64encode(buf.read()).decode('utf-8')}"
|
819 |
+
return data
|
820 |
+
return Figure, base64, fig_to_image, io
|
821 |
+
|
822 |
+
|
823 |
+
@app.cell(hide_code=True)
|
824 |
+
def _(np, plt):
|
825 |
+
def create_pdf_visualization(a, b, distribution='uniform'):
|
826 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
827 |
+
|
828 |
+
# x-values
|
829 |
+
x = np.linspace(-0.5, 5.5, 1000)
|
830 |
+
|
831 |
+
# Various PDFs to visualize
|
832 |
+
if distribution == 'uniform':
|
833 |
+
# Uniform distribution from 0 to 5
|
834 |
+
y = np.where((x >= 0) & (x <= 5), 0.2, 0)
|
835 |
+
title = f"Uniform PDF from 0 to 5"
|
836 |
+
|
837 |
+
elif distribution == 'triangular':
|
838 |
+
# Triangular distribution peaked at 2.5
|
839 |
+
y = np.where(x <= 2.5, x/6.25, (5-x)/6.25) # peak at 2.5
|
840 |
+
y = np.where((x < 0) | (x > 5), 0, y)
|
841 |
+
title = f"Triangular PDF from 0 to 5"
|
842 |
+
|
843 |
+
elif distribution == 'exponential':
|
844 |
+
lambda_param = 0.5
|
845 |
+
y = np.where(x >= 0, lambda_param * np.exp(-lambda_param * x), 0)
|
846 |
+
title = f"Exponential PDF with λ = {lambda_param}"
|
847 |
+
|
848 |
+
# Plot PDF
|
849 |
+
ax.plot(x, y, 'b-', linewidth=2, label='PDF $f(x)$')
|
850 |
+
|
851 |
+
# Shade the area for the probability P(a ≤ X ≤ b)
|
852 |
+
mask = (x >= a) & (x <= b)
|
853 |
+
ax.fill_between(x[mask], y[mask], color='orange', alpha=0.5)
|
854 |
+
|
855 |
+
# Calculate the probability
|
856 |
+
dx = x[1] - x[0]
|
857 |
+
probability = np.sum(y[mask]) * dx
|
858 |
+
|
859 |
+
# vertical lines at a and b
|
860 |
+
ax.axvline(x=a, color='r', linestyle='--', alpha=0.7,
|
861 |
+
label=f'a = {a:.1f}')
|
862 |
+
ax.axvline(x=b, color='g', linestyle='--', alpha=0.7,
|
863 |
+
label=f'b = {b:.1f}')
|
864 |
+
|
865 |
+
# horizontal line at y=0
|
866 |
+
ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)
|
867 |
+
|
868 |
+
# labels and title
|
869 |
+
ax.set_xlabel('x')
|
870 |
+
ax.set_ylabel('Probability Density $f(x)$')
|
871 |
+
ax.set_title(title)
|
872 |
+
ax.legend(loc='upper right')
|
873 |
+
|
874 |
+
# relevant annotations
|
875 |
+
ax.annotate(f'$P({a:.1f} \leq X \leq {b:.1f}) = {probability:.4f}$',
|
876 |
+
xy=(0.5, 0.9), xycoords='axes fraction',
|
877 |
+
bbox=dict(boxstyle='round,pad=0.5', facecolor='white', alpha=0.8),
|
878 |
+
horizontalalignment='center', fontsize=12)
|
879 |
+
|
880 |
+
plt.grid(alpha=0.3)
|
881 |
+
plt.tight_layout()
|
882 |
+
plt.gca()
|
883 |
+
return fig, probability
|
884 |
+
return (create_pdf_visualization,)
|
885 |
+
|
886 |
+
|
887 |
+
@app.cell(hide_code=True)
|
888 |
+
def _(np, plt, sympy):
|
889 |
+
def create_example_pdf_visualization():
|
890 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
891 |
+
|
892 |
+
# x-values
|
893 |
+
x = np.linspace(-0.5, 2.5, 1000)
|
894 |
+
|
895 |
+
# PDF function
|
896 |
+
C = 3/8
|
897 |
+
pdf = np.where((x > 0) & (x < 2), C * (4*x - 2*x**2), 0)
|
898 |
+
|
899 |
+
# CDF
|
900 |
+
cdf = np.zeros_like(x)
|
901 |
+
for i, val in enumerate(x):
|
902 |
+
if val <= 0:
|
903 |
+
cdf[i] = 0
|
904 |
+
elif val >= 2:
|
905 |
+
cdf[i] = 1
|
906 |
+
else:
|
907 |
+
# Analytical form: C*(2x^2 - 2x^3/3)
|
908 |
+
cdf[i] = C * (2*val**2 - (2*val**3)/3)
|
909 |
+
|
910 |
+
# PDF Plot
|
911 |
+
ax1.plot(x, pdf, 'b-', linewidth=2)
|
912 |
+
ax1.set_title('PDF: $f(x) = \\frac{3}{8}(4x - 2x^2)$ for $0 < x < 2$')
|
913 |
+
ax1.set_xlabel('x')
|
914 |
+
ax1.set_ylabel('Probability Density')
|
915 |
+
ax1.grid(alpha=0.3)
|
916 |
+
|
917 |
+
# Highlight the area for P(X > 1)
|
918 |
+
mask = (x > 1) & (x < 2)
|
919 |
+
ax1.fill_between(x[mask], pdf[mask], color='orange', alpha=0.5,
|
920 |
+
label='P(X > 1) = 0.5')
|
921 |
+
|
922 |
+
# Add vertical line at x=1
|
923 |
+
ax1.axvline(x=1, color='r', linestyle='--', alpha=0.7)
|
924 |
+
ax1.legend()
|
925 |
+
|
926 |
+
# CDF Plot
|
927 |
+
ax2.plot(x, cdf, 'r-', linewidth=2)
|
928 |
+
ax2.set_title('CDF: $F(x)$ for the Example Distribution')
|
929 |
+
ax2.set_xlabel('x')
|
930 |
+
ax2.set_ylabel('Cumulative Probability')
|
931 |
+
ax2.grid(alpha=0.3)
|
932 |
+
|
933 |
+
# Mark appropriate (F(1) & F(2)) points)
|
934 |
+
ax2.plot(1, cdf[np.abs(x-1).argmin()], 'ro', markersize=8)
|
935 |
+
ax2.plot(2, cdf[np.abs(x-2).argmin()], 'ro', markersize=8)
|
936 |
+
|
937 |
+
# annotations
|
938 |
+
F_1 = C * (2*1**2 - (2*1**3)/3) # F(1)
|
939 |
+
ax2.annotate(f'F(1) = {F_1:.3f}', xy=(1, F_1), xytext=(1.1, 0.4),
|
940 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
941 |
+
|
942 |
+
ax2.annotate(f'F(2) = 1', xy=(2, 1), xytext=(1.7, 0.8),
|
943 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
944 |
+
|
945 |
+
ax2.annotate(f'P(X > 1) = 1 - F(1) = {1-F_1:.3f}', xy=(1.5, 0.7),
|
946 |
+
bbox=dict(boxstyle='round,pad=0.5', facecolor='orange', alpha=0.2))
|
947 |
+
|
948 |
+
# common x-limits
|
949 |
+
for ax in [ax1, ax2]:
|
950 |
+
ax.set_xlim(-0.25, 2.25)
|
951 |
+
|
952 |
+
plt.tight_layout()
|
953 |
+
plt.gca()
|
954 |
+
return fig
|
955 |
+
|
956 |
+
def symbolic_calculation():
|
957 |
+
x = sympy.symbols('x')
|
958 |
+
C = sympy.Rational(3, 8)
|
959 |
+
|
960 |
+
# PDF defn
|
961 |
+
pdf_expr = C * (4*x - 2*x**2)
|
962 |
+
|
963 |
+
# Verify PDF integrates to 1
|
964 |
+
total_prob = sympy.integrate(pdf_expr, (x, 0, 2))
|
965 |
+
|
966 |
+
# Calculate P(X > 1)
|
967 |
+
prob_gt_1 = sympy.integrate(pdf_expr, (x, 1, 2))
|
968 |
+
|
969 |
+
return f"""Symbolic calculation verification:
|
970 |
+
|
971 |
+
1. Total probability: ∫₀² {C}(4x - 2x²) dx = {total_prob}
|
972 |
+
2. P(X > 1): ∫₁² {C}(4x - 2x²) dx = {prob_gt_1}
|
973 |
+
"""
|
974 |
+
|
975 |
+
return create_example_pdf_visualization, symbolic_calculation
|
976 |
+
|
977 |
+
|
978 |
+
if __name__ == "__main__":
|
979 |
+
app.run()
|
probability/17_normal_distribution.py
ADDED
@@ -0,0 +1,1127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.1",
|
6 |
+
# "scipy==1.15.2",
|
7 |
+
# "wigglystuff==0.1.10",
|
8 |
+
# "numpy==2.2.4",
|
9 |
+
# ]
|
10 |
+
# ///
|
11 |
+
|
12 |
+
import marimo
|
13 |
+
|
14 |
+
__generated_with = "0.11.26"
|
15 |
+
app = marimo.App(width="medium", app_title="Normal Distribution")
|
16 |
+
|
17 |
+
|
18 |
+
@app.cell(hide_code=True)
|
19 |
+
def _(mo):
|
20 |
+
mo.md(
|
21 |
+
r"""
|
22 |
+
# Normal Distribution
|
23 |
+
|
24 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/normal/), by Stanford professor Chris Piech._
|
25 |
+
|
26 |
+
The Normal (also known as Gaussian) distribution is one of the most important probability distributions in statistics and data science. It's characterized by a symmetric bell-shaped curve and is fully defined by two parameters: mean (μ) and variance (σ²).
|
27 |
+
"""
|
28 |
+
)
|
29 |
+
return
|
30 |
+
|
31 |
+
|
32 |
+
@app.cell(hide_code=True)
|
33 |
+
def _(mo):
|
34 |
+
mo.md(
|
35 |
+
r"""
|
36 |
+
## Normal Random Variable Definition
|
37 |
+
|
38 |
+
The Normal (or Gaussian) random variable is denoted as:
|
39 |
+
|
40 |
+
$$X \sim \mathcal{N}(\mu, \sigma^2)$$
|
41 |
+
|
42 |
+
Where:
|
43 |
+
|
44 |
+
- $X$ is our random variable
|
45 |
+
- $\mathcal{N}$ indicates it follows a Normal distribution
|
46 |
+
- $\mu$ is the mean parameter
|
47 |
+
- $\sigma^2$ is the variance parameter (sometimes written as $\sigma$ for standard deviation)
|
48 |
+
|
49 |
+
```
|
50 |
+
X ~ N(μ, σ²)
|
51 |
+
↑ ↑ ↑ ↑
|
52 |
+
| | | +-- Variance (spread)
|
53 |
+
| | | of the distribution
|
54 |
+
| | +-- Mean (center)
|
55 |
+
| | of the distribution
|
56 |
+
| +-- Indicates Normal
|
57 |
+
| distribution
|
58 |
+
|
|
59 |
+
Our random variable
|
60 |
+
```
|
61 |
+
|
62 |
+
The Normal distribution is particularly important for many reasons:
|
63 |
+
|
64 |
+
1. It arises naturally from the sum of independent random variables (Central Limit Theorem)
|
65 |
+
2. It appears frequently in natural phenomena
|
66 |
+
3. It is the maximum entropy distribution given a fixed mean and variance
|
67 |
+
4. It simplifies many mathematical calculations in statistics and probability
|
68 |
+
"""
|
69 |
+
)
|
70 |
+
return
|
71 |
+
|
72 |
+
|
73 |
+
@app.cell(hide_code=True)
|
74 |
+
def _(mo):
|
75 |
+
mo.md(
|
76 |
+
r"""
|
77 |
+
## Properties of Normal Distribution
|
78 |
+
|
79 |
+
| Property | Formula |
|
80 |
+
|----------|---------|
|
81 |
+
| Notation | $X \sim \mathcal{N}(\mu, \sigma^2)$ |
|
82 |
+
| Description | A common, naturally occurring distribution |
|
83 |
+
| Parameters | $\mu \in \mathbb{R}$, the mean<br>$\sigma^2 \in \mathbb{R}^+$, the variance |
|
84 |
+
| Support | $x \in \mathbb{R}$ |
|
85 |
+
| PDF equation | $f(x) = \frac{1}{\sigma\sqrt{2\pi}}e^{-\frac{1}{2}(\frac{x-\mu}{\sigma})^2}$ |
|
86 |
+
| CDF equation | $F(x) = \Phi(\frac{x-\mu}{\sigma})$ where $\Phi$ is the CDF of the standard normal |
|
87 |
+
| Expectation | $E[X] = \mu$ |
|
88 |
+
| Variance | $\text{Var}(X) = \sigma^2$ |
|
89 |
+
|
90 |
+
The PDF (Probability Density Function) reaches its maximum value at $x = \mu$, where the exponent becomes zero and $e^0 = 1$.
|
91 |
+
"""
|
92 |
+
)
|
93 |
+
return
|
94 |
+
|
95 |
+
|
96 |
+
@app.cell(hide_code=True)
|
97 |
+
def _(mean_slider, mo, std_slider):
|
98 |
+
mo.md(
|
99 |
+
f"""
|
100 |
+
The figure below shows a comparison between:
|
101 |
+
|
102 |
+
- The **Standard Normal Distribution** (purple curve): N(0, 1)
|
103 |
+
- A **Normal Distribution** with the parameters you selected (blue curve)
|
104 |
+
|
105 |
+
Adjust the mean (μ) {mean_slider} and standard deviation (σ) {std_slider} below to see how the normal distribution changes shape.
|
106 |
+
|
107 |
+
"""
|
108 |
+
)
|
109 |
+
return
|
110 |
+
|
111 |
+
|
112 |
+
@app.cell(hide_code=True)
|
113 |
+
def _(
|
114 |
+
create_distribution_comparison,
|
115 |
+
fig_to_image,
|
116 |
+
mean_slider,
|
117 |
+
mo,
|
118 |
+
std_slider,
|
119 |
+
):
|
120 |
+
# values from the sliders
|
121 |
+
current_mu = mean_slider.amount
|
122 |
+
current_sigma = std_slider.amount
|
123 |
+
|
124 |
+
# Create plot
|
125 |
+
comparison_fig = create_distribution_comparison(current_mu, current_sigma)
|
126 |
+
|
127 |
+
# Call, convert and display
|
128 |
+
comp_image = mo.image(fig_to_image(comparison_fig), width="100%")
|
129 |
+
comp_image
|
130 |
+
return comp_image, comparison_fig, current_mu, current_sigma
|
131 |
+
|
132 |
+
|
133 |
+
@app.cell(hide_code=True)
|
134 |
+
def _(mean_slider, mo, std_slider):
|
135 |
+
mo.md(
|
136 |
+
f"""
|
137 |
+
## Interactive Normal Distribution Visualization
|
138 |
+
|
139 |
+
The shape of a normal distribution is determined by two key parameters:
|
140 |
+
|
141 |
+
- The **mean (μ):** {mean_slider} controls the center of the distribution.
|
142 |
+
|
143 |
+
- The **standard deviation (σ):** {std_slider} controls the spread (width) of the distribution.
|
144 |
+
|
145 |
+
Try adjusting these parameters to see how they affect the shape of the distribution below:
|
146 |
+
|
147 |
+
"""
|
148 |
+
)
|
149 |
+
return
|
150 |
+
|
151 |
+
|
152 |
+
@app.cell(hide_code=True)
|
153 |
+
def _(create_normal_pdf_plot, fig_to_image, mean_slider, mo, std_slider):
|
154 |
+
# value from widgets
|
155 |
+
_current_mu = mean_slider.amount
|
156 |
+
_current_sigma = std_slider.amount
|
157 |
+
|
158 |
+
# Create visualization
|
159 |
+
pdf_fig = create_normal_pdf_plot(_current_mu, _current_sigma)
|
160 |
+
|
161 |
+
# Display plot
|
162 |
+
pdf_image = mo.image(fig_to_image(pdf_fig), width="100%")
|
163 |
+
|
164 |
+
pdf_explanation = mo.md(
|
165 |
+
r"""
|
166 |
+
**Understanding the Normal Distribution Visualization:**
|
167 |
+
|
168 |
+
- **PDF (top)**: The probability density function shows the relative likelihood of different values.
|
169 |
+
The highest point occurs at the mean (μ).
|
170 |
+
|
171 |
+
- **Shaded regions**: The green shaded areas represent:
|
172 |
+
- μ ± 1σ: Contains approximately 68.3% of the probability
|
173 |
+
- μ ± 2σ: Contains approximately 95.5% of the probability
|
174 |
+
- μ ± 3σ: Contains approximately 99.7% of the probability (the "68-95-99.7 rule")
|
175 |
+
|
176 |
+
- **CDF (bottom)**: The cumulative distribution function shows the probability that X is less than or equal to a given value.
|
177 |
+
- At x = μ, the CDF equals 0.5 (50% probability)
|
178 |
+
- At x = μ + σ, the CDF equals approximately 0.84 (84% probability)
|
179 |
+
- At x = μ - σ, the CDF equals approximately 0.16 (16% probability)
|
180 |
+
"""
|
181 |
+
)
|
182 |
+
|
183 |
+
mo.vstack([pdf_image, pdf_explanation])
|
184 |
+
return pdf_explanation, pdf_fig, pdf_image
|
185 |
+
|
186 |
+
|
187 |
+
@app.cell(hide_code=True)
|
188 |
+
def _(mo):
|
189 |
+
mo.md(
|
190 |
+
r"""
|
191 |
+
## Standard Normal Distribution
|
192 |
+
|
193 |
+
The **Standard Normal Distribution** is a special case of the normal distribution where $\mu = 0$ and $\sigma = 1$. We denote it as:
|
194 |
+
|
195 |
+
$$Z \sim \mathcal{N}(0, 1)$$
|
196 |
+
|
197 |
+
This distribution is particularly important because:
|
198 |
+
|
199 |
+
1. Any normal distribution can be transformed into the standard normal
|
200 |
+
2. Statistical tables and calculations often use the standard normal as a reference
|
201 |
+
|
202 |
+
### Standardizing a Normal Random Variable
|
203 |
+
|
204 |
+
For any normal random variable $X \sim \mathcal{N}(\mu, \sigma^2)$, we can transform it to the standard normal $Z$ using:
|
205 |
+
|
206 |
+
$$Z = \frac{X - \mu}{\sigma}$$
|
207 |
+
|
208 |
+
Let's see the mathematical derivation:
|
209 |
+
|
210 |
+
\begin{align*}
|
211 |
+
W &= \frac{X -\mu}{\sigma} && \text{Subtract by $\mu$ and divide by $\sigma$} \\
|
212 |
+
&= \frac{1}{\sigma}X - \frac{\mu}{\sigma} && \text{Use algebra to rewrite the equation}\\
|
213 |
+
&= aX + b && \text{Linear transform where $a = \frac{1}{\sigma}$, $b = -\frac{\mu}{\sigma}$}\\
|
214 |
+
&\sim \mathcal{N}(a\mu + b, a^2\sigma^2) && \text{The linear transform of a Normal is another Normal}\\
|
215 |
+
&\sim \mathcal{N}\left(\frac{\mu}{\sigma} - \frac{\mu}{\sigma}, \frac{\sigma^2}{\sigma^2}\right) && \text{Substitute values for $a$ and $b$}\\
|
216 |
+
&\sim \mathcal{N}(0, 1) && \text{The standard normal}
|
217 |
+
\end{align*}
|
218 |
+
|
219 |
+
This transformation is the foundation for many statistical tests and probability calculations.
|
220 |
+
"""
|
221 |
+
)
|
222 |
+
return
|
223 |
+
|
224 |
+
|
225 |
+
@app.cell(hide_code=True)
|
226 |
+
def _(create_standardization_plot, fig_to_image, mo):
|
227 |
+
# Create and display visualization
|
228 |
+
stand_fig = create_standardization_plot()
|
229 |
+
|
230 |
+
# Display
|
231 |
+
stand_image = mo.image(fig_to_image(stand_fig), width="100%")
|
232 |
+
|
233 |
+
stand_explanation = mo.md(
|
234 |
+
r"""
|
235 |
+
**Standardizing a Normal Distribution: A Two-Step Process**
|
236 |
+
|
237 |
+
The visualization above shows the process of transforming any normal distribution to the standard normal:
|
238 |
+
|
239 |
+
1. **Shift the distribution** (left plot): First, we subtract the mean (μ) from X, centering the distribution at 0.
|
240 |
+
|
241 |
+
2. **Scale the distribution** (right plot): Next, we divide by the standard deviation (σ), which adjusts the spread to match the standard normal.
|
242 |
+
|
243 |
+
The resulting standard normal distribution Z ~ N(0,1) has a mean of 0 and a variance of 1.
|
244 |
+
|
245 |
+
This transformation allows us to use standardized tables and calculations for any normal distribution.
|
246 |
+
"""
|
247 |
+
)
|
248 |
+
|
249 |
+
mo.vstack([stand_image, stand_explanation])
|
250 |
+
return stand_explanation, stand_fig, stand_image
|
251 |
+
|
252 |
+
|
253 |
+
@app.cell(hide_code=True)
|
254 |
+
def _(mo):
|
255 |
+
mo.md(
|
256 |
+
r"""
|
257 |
+
## Linear Transformations of Normal Variables
|
258 |
+
|
259 |
+
One useful property of the normal distribution is that linear transformations of normal random variables remain normal.
|
260 |
+
|
261 |
+
If $X \sim \mathcal{N}(\mu, \sigma^2)$ and $Y = aX + b$ (where $a$ and $b$ are constants), then:
|
262 |
+
|
263 |
+
$$Y \sim \mathcal{N}(a\mu + b, a^2\sigma^2)$$
|
264 |
+
|
265 |
+
This means:
|
266 |
+
|
267 |
+
- The mean is transformed by $a\mu + b$
|
268 |
+
- The variance is transformed by $a^2\sigma^2$
|
269 |
+
|
270 |
+
This property is extremely useful in statistics and probability calculations, as it allows us to easily determine the _distribution_ of transformed variables.
|
271 |
+
"""
|
272 |
+
)
|
273 |
+
return
|
274 |
+
|
275 |
+
|
276 |
+
@app.cell(hide_code=True)
|
277 |
+
def _(mo):
|
278 |
+
mo.md(
|
279 |
+
r"""
|
280 |
+
## Calculating Probabilities with the Normal CDF
|
281 |
+
|
282 |
+
Unlike many other distributions, the normal distribution does not have a closed-form expression for its CDF. However, we can use the standard normal CDF (denoted as $\Phi$) to calculate probabilities.
|
283 |
+
|
284 |
+
For any normal random variable $X \sim \mathcal{N}(\mu, \sigma^2)$, the CDF is:
|
285 |
+
|
286 |
+
$$F_X(x) = P(X \leq x) = \Phi\left(\frac{x - \mu}{\sigma}\right)$$
|
287 |
+
|
288 |
+
Where $\Phi$ is the CDF of the standard normal distribution.
|
289 |
+
|
290 |
+
### Derivation
|
291 |
+
|
292 |
+
\begin{align*}
|
293 |
+
F_X(x) &= P(X \leq x) \\
|
294 |
+
&= P\left(\frac{X - \mu}{\sigma} \leq \frac{x - \mu}{\sigma}\right) \\
|
295 |
+
&= P\left(Z \leq \frac{x - \mu}{\sigma}\right) \\
|
296 |
+
&= \Phi\left(\frac{x - \mu}{\sigma}\right)
|
297 |
+
\end{align*}
|
298 |
+
|
299 |
+
Let's look at some examples of calculating probabilities with normal distributions.
|
300 |
+
"""
|
301 |
+
)
|
302 |
+
return
|
303 |
+
|
304 |
+
|
305 |
+
@app.cell(hide_code=True)
|
306 |
+
def _(mo):
|
307 |
+
mo.md("""## Examples of Normal Distributions""")
|
308 |
+
return
|
309 |
+
|
310 |
+
|
311 |
+
@app.cell(hide_code=True)
|
312 |
+
def _(create_probability_example, fig_to_image, mo):
|
313 |
+
# Create visualization
|
314 |
+
default_mu = 3
|
315 |
+
default_sigma = 4
|
316 |
+
default_query = 0
|
317 |
+
|
318 |
+
prob_fig, prob_value, ex_z_score = create_probability_example(default_mu, default_sigma, default_query)
|
319 |
+
|
320 |
+
# Display
|
321 |
+
prob_image = mo.image(fig_to_image(prob_fig), width="100%")
|
322 |
+
|
323 |
+
prob_explanation = mo.md(
|
324 |
+
f"""
|
325 |
+
**Example: Let X ~ N(3, 16), what is P(X > 0)?**
|
326 |
+
|
327 |
+
To solve this probability question:
|
328 |
+
|
329 |
+
1. First, we standardize the query value:
|
330 |
+
Z = (x - μ) / σ = (0 - 3) / 4 = -0.75
|
331 |
+
|
332 |
+
2. Then we calculate using the standard normal CDF:
|
333 |
+
P(X > 0) = P(Z > -0.75) = 1 - P(Z ≤ -0.75) = 1 - Φ(-0.75)
|
334 |
+
|
335 |
+
3. Because the standard normal is symmetric:
|
336 |
+
1 - Φ(-0.75) = Φ(0.75) = {prob_value:.3f}
|
337 |
+
|
338 |
+
The shaded orange area in the graph represents this probability of approximately {prob_value:.3f}.
|
339 |
+
"""
|
340 |
+
)
|
341 |
+
|
342 |
+
mo.vstack([prob_image, prob_explanation])
|
343 |
+
return (
|
344 |
+
default_mu,
|
345 |
+
default_query,
|
346 |
+
default_sigma,
|
347 |
+
ex_z_score,
|
348 |
+
prob_explanation,
|
349 |
+
prob_fig,
|
350 |
+
prob_image,
|
351 |
+
prob_value,
|
352 |
+
)
|
353 |
+
|
354 |
+
|
355 |
+
@app.cell(hide_code=True)
|
356 |
+
def _(create_range_probability_example, fig_to_image, mo, stats):
|
357 |
+
# Create visualization
|
358 |
+
default_range_mu = 3
|
359 |
+
default_range_sigma = 4
|
360 |
+
default_range_lower = 2
|
361 |
+
default_range_upper = 5
|
362 |
+
|
363 |
+
range_fig, range_prob, range_z_lower, range_z_upper = create_range_probability_example(
|
364 |
+
default_range_mu, default_range_sigma, default_range_lower, default_range_upper)
|
365 |
+
|
366 |
+
# Display
|
367 |
+
range_image = mo.image(fig_to_image(range_fig), width="100%")
|
368 |
+
|
369 |
+
range_explanation = mo.md(
|
370 |
+
f"""
|
371 |
+
**Example: Let X ~ N(3, 16), what is P(2 < X < 5)?**
|
372 |
+
|
373 |
+
To solve this range probability question:
|
374 |
+
|
375 |
+
1. First, we standardize both bounds:
|
376 |
+
Z_lower = (lower - μ) / σ = (2 - 3) / 4 = -0.25
|
377 |
+
Z_upper = (upper - μ) / σ = (5 - 3) / 4 = 0.5
|
378 |
+
|
379 |
+
2. Then we calculate using the standard normal CDF:
|
380 |
+
P(2 < X < 5) = P(-0.25 < Z < 0.5)
|
381 |
+
= Φ(0.5) - Φ(-0.25)
|
382 |
+
= Φ(0.5) - (1 - Φ(0.25))
|
383 |
+
= Φ(0.5) + Φ(0.25) - 1
|
384 |
+
|
385 |
+
3. Computing these values:
|
386 |
+
= {stats.norm.cdf(0.5):.3f} + {stats.norm.cdf(0.25):.3f} - 1
|
387 |
+
= {range_prob:.3f}
|
388 |
+
|
389 |
+
The shaded orange area in the graph represents this probability of approximately {range_prob:.3f}.
|
390 |
+
"""
|
391 |
+
)
|
392 |
+
|
393 |
+
mo.vstack([range_image, range_explanation])
|
394 |
+
return (
|
395 |
+
default_range_lower,
|
396 |
+
default_range_mu,
|
397 |
+
default_range_sigma,
|
398 |
+
default_range_upper,
|
399 |
+
range_explanation,
|
400 |
+
range_fig,
|
401 |
+
range_image,
|
402 |
+
range_prob,
|
403 |
+
range_z_lower,
|
404 |
+
range_z_upper,
|
405 |
+
)
|
406 |
+
|
407 |
+
|
408 |
+
@app.cell(hide_code=True)
|
409 |
+
def _(create_voltage_example_visualization, fig_to_image, mo):
|
410 |
+
# Create visualization
|
411 |
+
voltage_fig, voltage_error_prob = create_voltage_example_visualization()
|
412 |
+
|
413 |
+
# Display
|
414 |
+
voltage_image = mo.image(fig_to_image(voltage_fig), width="100%")
|
415 |
+
|
416 |
+
voltage_explanation = mo.md(
|
417 |
+
r"""
|
418 |
+
**Example: Signal Transmission with Noise**
|
419 |
+
|
420 |
+
In this example, we're sending digital signals over a wire:
|
421 |
+
|
422 |
+
- We send voltage 2 to represent a binary "1"
|
423 |
+
- We send voltage -2 to represent a binary "0"
|
424 |
+
|
425 |
+
The received signal R is the sum of the transmitted voltage (X) and random noise (Y):
|
426 |
+
R = X + Y, where Y ~ N(0, 1)
|
427 |
+
|
428 |
+
When decoding, we use a threshold of 0.5:
|
429 |
+
|
430 |
+
- If R ≥ 0.5, we interpret it as "1"
|
431 |
+
- If R < 0.5, we interpret it as "0"
|
432 |
+
|
433 |
+
Let's calculate the probability of error when sending a "1" (voltage = 2):
|
434 |
+
|
435 |
+
\begin{align*}
|
436 |
+
P(\text{Error when sending "1"}) &= P(X + Y < 0.5) \\
|
437 |
+
&= P(2 + Y < 0.5) \\
|
438 |
+
&= P(Y < -1.5) \\
|
439 |
+
&= \Phi(-1.5) \\
|
440 |
+
&\approx 0.067
|
441 |
+
\end{align*}
|
442 |
+
|
443 |
+
Therefore, the probability of incorrectly decoding a transmitted "1" as "0" is approximately 6.7%.
|
444 |
+
|
445 |
+
The orange shaded area in the plot represents this error probability.
|
446 |
+
"""
|
447 |
+
)
|
448 |
+
|
449 |
+
mo.vstack([voltage_image, voltage_explanation])
|
450 |
+
return voltage_error_prob, voltage_explanation, voltage_fig, voltage_image
|
451 |
+
|
452 |
+
|
453 |
+
@app.cell(hide_code=True)
|
454 |
+
def emirical_rule(mo):
|
455 |
+
mo.md(
|
456 |
+
r"""
|
457 |
+
## The 68-95-99.7 Rule (Empirical Rule)
|
458 |
+
|
459 |
+
One of the most useful properties of the normal distribution is the "[68-95-99.7 rule](https://en.wikipedia.org/wiki/68-95-99.7_rule)," which states that:
|
460 |
+
|
461 |
+
- Approximately 68% of the data falls within 1 standard deviation of the mean
|
462 |
+
- Approximately 95% of the data falls within 2 standard deviations of the mean
|
463 |
+
- Approximately 99.7% of the data falls within 3 standard deviations of the mean
|
464 |
+
|
465 |
+
Let's verify this with a calculation for the 68% rule:
|
466 |
+
|
467 |
+
\begin{align}
|
468 |
+
P(\mu - \sigma < X < \mu + \sigma)
|
469 |
+
&= P(X < \mu + \sigma) - P(X < \mu - \sigma) \\
|
470 |
+
&= \Phi\left(\frac{(\mu + \sigma)-\mu}{\sigma}\right) - \Phi\left(\frac{(\mu - \sigma)-\mu}{\sigma}\right) \\
|
471 |
+
&= \Phi\left(\frac{\sigma}{\sigma}\right) - \Phi\left(\frac{-\sigma}{\sigma}\right) \\
|
472 |
+
&= \Phi(1) - \Phi(-1) \\
|
473 |
+
&\approx 0.8413 - 0.1587 \\
|
474 |
+
&\approx 0.6826 \approx 68.3\%
|
475 |
+
\end{align}
|
476 |
+
|
477 |
+
This calculation works for any normal distribution, regardless of the values of $\mu$ and $\sigma$!
|
478 |
+
"""
|
479 |
+
)
|
480 |
+
return
|
481 |
+
|
482 |
+
|
483 |
+
@app.cell(hide_code=True)
|
484 |
+
def _(mo):
|
485 |
+
mo.md(r"""The Cumulative Distribution Function (CDF) gives the probability that a random variable is less than or equal to a specific value. Use the interactive calculator below to compute CDF values for a normal distribution.""")
|
486 |
+
return
|
487 |
+
|
488 |
+
|
489 |
+
@app.cell(hide_code=True)
|
490 |
+
def _(mo, mu_slider, sigma_slider, x_slider):
|
491 |
+
mo.md(
|
492 |
+
f"""
|
493 |
+
## Interactive Normal CDF Calculator
|
494 |
+
|
495 |
+
Use the sliders below to explore different probability calculations:
|
496 |
+
|
497 |
+
**Query value (x):** {x_slider} — The value at which to evaluate F(x) = P(X ≤ x)
|
498 |
+
|
499 |
+
**Mean (μ):** {mu_slider} — The center of the distribution
|
500 |
+
|
501 |
+
**Standard deviation (σ):** {sigma_slider} — The spread of the distribution (larger σ means more spread)
|
502 |
+
"""
|
503 |
+
)
|
504 |
+
return
|
505 |
+
|
506 |
+
|
507 |
+
@app.cell(hide_code=True)
|
508 |
+
def _(
|
509 |
+
create_cdf_calculator_plot,
|
510 |
+
fig_to_image,
|
511 |
+
mo,
|
512 |
+
mu_slider,
|
513 |
+
sigma_slider,
|
514 |
+
x_slider,
|
515 |
+
):
|
516 |
+
# Values from widgets
|
517 |
+
calc_x = x_slider.amount
|
518 |
+
calc_mu = mu_slider.amount
|
519 |
+
calc_sigma = sigma_slider.amount
|
520 |
+
|
521 |
+
# Create visualization
|
522 |
+
calc_fig, cdf_value = create_cdf_calculator_plot(calc_x, calc_mu, calc_sigma)
|
523 |
+
|
524 |
+
# Standardized z-score
|
525 |
+
calc_z_score = (calc_x - calc_mu) / calc_sigma
|
526 |
+
|
527 |
+
# Display
|
528 |
+
calc_image = mo.image(fig_to_image(calc_fig), width="100%")
|
529 |
+
|
530 |
+
calc_result = mo.md(
|
531 |
+
f"""
|
532 |
+
### Results:
|
533 |
+
|
534 |
+
For a Normal distribution with parameters μ = {calc_mu:.1f} and σ = {calc_sigma:.1f}:
|
535 |
+
|
536 |
+
- The value x = {calc_x:.1f} corresponds to a z-score of z = {calc_z_score:.3f}
|
537 |
+
- The CDF value F({calc_x:.1f}) = P(X ≤ {calc_x:.1f}) = {cdf_value:.3f}
|
538 |
+
- This means the probability that X is less than or equal to {calc_x:.1f} is {cdf_value*100:.1f}%
|
539 |
+
|
540 |
+
**Computing this in Python:**
|
541 |
+
```python
|
542 |
+
from scipy import stats
|
543 |
+
|
544 |
+
# Using the one-line method
|
545 |
+
p = stats.norm.cdf({calc_x:.1f}, {calc_mu:.1f}, {calc_sigma:.1f})
|
546 |
+
|
547 |
+
# OR using the two-line method
|
548 |
+
X = stats.norm({calc_mu:.1f}, {calc_sigma:.1f})
|
549 |
+
p = X.cdf({calc_x:.1f})
|
550 |
+
```
|
551 |
+
|
552 |
+
**Note:** In SciPy's `stats.norm`, the second parameter is the standard deviation (σ), not the variance (σ²).
|
553 |
+
"""
|
554 |
+
)
|
555 |
+
|
556 |
+
mo.vstack([calc_image, calc_result])
|
557 |
+
return (
|
558 |
+
calc_fig,
|
559 |
+
calc_image,
|
560 |
+
calc_mu,
|
561 |
+
calc_result,
|
562 |
+
calc_sigma,
|
563 |
+
calc_x,
|
564 |
+
calc_z_score,
|
565 |
+
cdf_value,
|
566 |
+
)
|
567 |
+
|
568 |
+
|
569 |
+
@app.cell(hide_code=True)
|
570 |
+
def _(mo):
|
571 |
+
mo.md(
|
572 |
+
r"""
|
573 |
+
## 🤔 Test Your Understanding
|
574 |
+
|
575 |
+
Test your knowledge with these true/false questions about normal distributions:
|
576 |
+
|
577 |
+
/// details | For a normal random variable X ~ N(μ, σ²), the probability that X takes on exactly the value μ is highest among all possible values.
|
578 |
+
|
579 |
+
**✅ True**
|
580 |
+
|
581 |
+
While the PDF is indeed highest at x = μ, making this the most likely value in terms of density, remember that for continuous random variables, the probability of any exact value is zero. The statement refers to the density function being maximized at the mean.
|
582 |
+
///
|
583 |
+
|
584 |
+
/// details | The probability that a normal random variable X equals any specific exact value (e.g., P(X = 3)) is always zero.
|
585 |
+
|
586 |
+
**✅ True**
|
587 |
+
|
588 |
+
For continuous random variables including the normal, the probability of any exact value is zero. Probabilities only make sense for ranges of values, which is why we integrate the PDF over intervals.
|
589 |
+
///
|
590 |
+
|
591 |
+
/// details | If X ~ N(μ, σ²), then aX + b ~ N(aμ + b, a²σ²) for any constants a and b.
|
592 |
+
|
593 |
+
**✅ True**
|
594 |
+
|
595 |
+
Linear transformations of normal random variables remain normal, with the given transformation of the parameters. This is a key property that makes normal distributions particularly useful.
|
596 |
+
///
|
597 |
+
|
598 |
+
/// details | If X ~ N(5, 9) and Y ~ N(3, 4) are independent, then X + Y ~ N(8, 5).
|
599 |
+
|
600 |
+
**❌ False**
|
601 |
+
|
602 |
+
While the mean of the sum is indeed the sum of the means (5 + 3 = 8), the variance of the sum is the sum of the variances (9 + 4 = 13), not 5. The correct distribution would be X + Y ~ N(8, 13).
|
603 |
+
///
|
604 |
+
"""
|
605 |
+
)
|
606 |
+
return
|
607 |
+
|
608 |
+
|
609 |
+
@app.cell(hide_code=True)
|
610 |
+
def _(mo):
|
611 |
+
mo.md(
|
612 |
+
r"""
|
613 |
+
## Summary
|
614 |
+
|
615 |
+
We've taken a tour of Normal distributions; probably the most famous probability distribution you'll encounter in statistics. It's that nice bell-shaped curve that shows up everywhere from heights/ weights to memes to measurement errors & stock returns.
|
616 |
+
|
617 |
+
The Normal distribution isn't just pretty — it's incredibly practical. With just two parameters (mean and standard deviation), you can describe complex phenomena and make powerful predictions. Plus, thanks to the Central Limit Theorem, many random processes naturally converge to this distribution, which is why it's so prevalent.
|
618 |
+
|
619 |
+
**What we covered:**
|
620 |
+
|
621 |
+
- The mathematical definition and key properties of Normal random variables
|
622 |
+
|
623 |
+
- How to transform any Normal distribution to the standard Normal
|
624 |
+
|
625 |
+
- Calculating probabilities using the CDF (no more looking up values in those tiny tables in the back of textbooks or Clark's table!)
|
626 |
+
|
627 |
+
Whether you're analyzing data, designing experiments, or building ML models, the concepts we explored provide a solid foundation for working with this fundamental distribution.
|
628 |
+
"""
|
629 |
+
)
|
630 |
+
return
|
631 |
+
|
632 |
+
|
633 |
+
@app.cell(hide_code=True)
|
634 |
+
def _(mo):
|
635 |
+
mo.md(r"""Appendix (helper code and functions)""")
|
636 |
+
return
|
637 |
+
|
638 |
+
|
639 |
+
@app.cell
|
640 |
+
def _():
|
641 |
+
import marimo as mo
|
642 |
+
return (mo,)
|
643 |
+
|
644 |
+
|
645 |
+
@app.cell(hide_code=True)
|
646 |
+
def _():
|
647 |
+
from wigglystuff import TangleSlider
|
648 |
+
return (TangleSlider,)
|
649 |
+
|
650 |
+
|
651 |
+
@app.cell(hide_code=True)
|
652 |
+
def _(np, plt, stats):
|
653 |
+
def create_normal_pdf_plot(mu, sigma):
|
654 |
+
|
655 |
+
# Range for x values (show μ ± 4σ)
|
656 |
+
x = np.linspace(mu - 4*sigma, mu + 4*sigma, 1000)
|
657 |
+
pdf = stats.norm.pdf(x, mu, sigma)
|
658 |
+
|
659 |
+
# Calculate CDF values
|
660 |
+
cdf = stats.norm.cdf(x, mu, sigma)
|
661 |
+
|
662 |
+
# Create plot with two subplots for (PDF and CDF)
|
663 |
+
pdf_fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))
|
664 |
+
|
665 |
+
# PDF plot
|
666 |
+
ax1.plot(x, pdf, color='royalblue', linewidth=2, label='PDF')
|
667 |
+
ax1.fill_between(x, pdf, color='royalblue', alpha=0.2)
|
668 |
+
|
669 |
+
# Vertical line at mean
|
670 |
+
ax1.axvline(x=mu, color='red', linestyle='--', linewidth=1.5,
|
671 |
+
label=f'Mean: μ = {mu:.1f}')
|
672 |
+
|
673 |
+
# Stdev regions
|
674 |
+
for i in range(1, 4):
|
675 |
+
alpha = 0.1 if i > 1 else 0.2
|
676 |
+
percentage = 100*stats.norm.cdf(i) - 100*stats.norm.cdf(-i)
|
677 |
+
label = f'μ ± {i}σ: {percentage:.1f}%' if i == 1 else None
|
678 |
+
ax1.axvspan(mu - i*sigma, mu + i*sigma, alpha=alpha, color='green',
|
679 |
+
label=label)
|
680 |
+
|
681 |
+
# Annotations
|
682 |
+
ax1.annotate(f'μ = {mu:.1f}', xy=(mu, max(pdf)*0.15), xytext=(mu+0.5*sigma, max(pdf)*0.4),
|
683 |
+
arrowprops=dict(facecolor='black', width=1, shrink=0.05))
|
684 |
+
|
685 |
+
ax1.annotate(f'σ = {sigma:.1f}',
|
686 |
+
xy=(mu+sigma, stats.norm.pdf(mu+sigma, mu, sigma)),
|
687 |
+
xytext=(mu+1.5*sigma, stats.norm.pdf(mu+sigma, mu, sigma)*1.5),
|
688 |
+
arrowprops=dict(facecolor='black', width=1, shrink=0.05))
|
689 |
+
|
690 |
+
# some styling
|
691 |
+
ax1.set_title(f'Normal Distribution PDF: N({mu:.1f}, {sigma:.1f}²)')
|
692 |
+
ax1.set_xlabel('x')
|
693 |
+
ax1.set_ylabel('Probability Density: f(x)')
|
694 |
+
ax1.legend(loc='upper right')
|
695 |
+
ax1.grid(alpha=0.3)
|
696 |
+
|
697 |
+
# CDF plot
|
698 |
+
ax2.plot(x, cdf, color='darkorange', linewidth=2, label='CDF')
|
699 |
+
|
700 |
+
# key CDF values mark
|
701 |
+
key_points = [
|
702 |
+
(mu-sigma, stats.norm.cdf(mu-sigma, mu, sigma), "16%"),
|
703 |
+
(mu, 0.5, "50%"),
|
704 |
+
(mu+sigma, stats.norm.cdf(mu+sigma, mu, sigma), "84%")
|
705 |
+
]
|
706 |
+
|
707 |
+
for point, value, label in key_points:
|
708 |
+
ax2.plot(point, value, 'ro')
|
709 |
+
ax2.annotate(f'{label}',
|
710 |
+
xy=(point, value),
|
711 |
+
xytext=(point+0.2*sigma, value-0.1),
|
712 |
+
arrowprops=dict(facecolor='black', width=1, shrink=0.05))
|
713 |
+
|
714 |
+
# CDF styling
|
715 |
+
ax2.set_title(f'Normal Distribution CDF: N({mu:.1f}, {sigma:.1f}²)')
|
716 |
+
ax2.set_xlabel('x')
|
717 |
+
ax2.set_ylabel('Cumulative Probability: F(x)')
|
718 |
+
ax2.grid(alpha=0.3)
|
719 |
+
|
720 |
+
plt.tight_layout()
|
721 |
+
return pdf_fig
|
722 |
+
return (create_normal_pdf_plot,)
|
723 |
+
|
724 |
+
|
725 |
+
@app.cell(hide_code=True)
|
726 |
+
def _(base64, io):
|
727 |
+
from matplotlib.figure import Figure
|
728 |
+
|
729 |
+
# convert matplotlib figures to images (helper code)
|
730 |
+
def fig_to_image(fig):
|
731 |
+
buf = io.BytesIO()
|
732 |
+
fig.savefig(buf, format='png', bbox_inches='tight')
|
733 |
+
buf.seek(0)
|
734 |
+
img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
|
735 |
+
return f"data:image/png;base64,{img_str}"
|
736 |
+
return Figure, fig_to_image
|
737 |
+
|
738 |
+
|
739 |
+
@app.cell(hide_code=True)
|
740 |
+
def _():
|
741 |
+
# Import libraries
|
742 |
+
import numpy as np
|
743 |
+
import matplotlib.pyplot as plt
|
744 |
+
from scipy import stats
|
745 |
+
import io
|
746 |
+
import base64
|
747 |
+
return base64, io, np, plt, stats
|
748 |
+
|
749 |
+
|
750 |
+
@app.cell(hide_code=True)
|
751 |
+
def _(TangleSlider, mo):
|
752 |
+
mean_slider = mo.ui.anywidget(TangleSlider(
|
753 |
+
amount=0,
|
754 |
+
min_value=-5,
|
755 |
+
max_value=5,
|
756 |
+
step=0.1,
|
757 |
+
digits=1
|
758 |
+
))
|
759 |
+
|
760 |
+
std_slider = mo.ui.anywidget(TangleSlider(
|
761 |
+
amount=1,
|
762 |
+
min_value=0.1,
|
763 |
+
max_value=3,
|
764 |
+
step=0.1,
|
765 |
+
digits=1
|
766 |
+
))
|
767 |
+
return mean_slider, std_slider
|
768 |
+
|
769 |
+
|
770 |
+
@app.cell(hide_code=True)
|
771 |
+
def _(TangleSlider, mo):
|
772 |
+
x_slider = mo.ui.anywidget(TangleSlider(
|
773 |
+
amount=0,
|
774 |
+
min_value=-5,
|
775 |
+
max_value=5,
|
776 |
+
step=0.1,
|
777 |
+
digits=1
|
778 |
+
))
|
779 |
+
|
780 |
+
mu_slider = mo.ui.anywidget(TangleSlider(
|
781 |
+
amount=0,
|
782 |
+
min_value=-5,
|
783 |
+
max_value=5,
|
784 |
+
step=0.1,
|
785 |
+
digits=1
|
786 |
+
))
|
787 |
+
|
788 |
+
sigma_slider = mo.ui.anywidget(TangleSlider(
|
789 |
+
amount=1,
|
790 |
+
min_value=0.1,
|
791 |
+
max_value=3,
|
792 |
+
step=0.1,
|
793 |
+
digits=1
|
794 |
+
))
|
795 |
+
return mu_slider, sigma_slider, x_slider
|
796 |
+
|
797 |
+
|
798 |
+
@app.cell(hide_code=True)
|
799 |
+
def _(np, plt, stats):
|
800 |
+
def create_distribution_comparison(mu=5, sigma=6):
|
801 |
+
|
802 |
+
# Create figure and axis
|
803 |
+
comparison_fig, ax = plt.subplots(figsize=(10, 6))
|
804 |
+
|
805 |
+
# X range for plotting
|
806 |
+
x = np.linspace(-10, 20, 1000)
|
807 |
+
|
808 |
+
# Standard normal
|
809 |
+
std_normal = stats.norm.pdf(x, 0, 1)
|
810 |
+
|
811 |
+
# Our example normal
|
812 |
+
example_normal = stats.norm.pdf(x, mu, sigma)
|
813 |
+
|
814 |
+
# Plot both distributions
|
815 |
+
ax.plot(x, std_normal, 'darkviolet', linewidth=2, label='Standard Normal')
|
816 |
+
ax.plot(x, example_normal, 'blue', linewidth=2, label=f'X ~ N({mu}, {sigma}²)')
|
817 |
+
|
818 |
+
# format the plot
|
819 |
+
ax.set_xlim(-10, 20)
|
820 |
+
ax.set_ylim(0, 0.45)
|
821 |
+
ax.set_xlabel('x')
|
822 |
+
ax.set_ylabel('Probability Density')
|
823 |
+
ax.grid(True, alpha=0.3)
|
824 |
+
ax.legend()
|
825 |
+
|
826 |
+
# Decorative text box for parameters
|
827 |
+
props = dict(boxstyle='round', facecolor='white', alpha=0.9)
|
828 |
+
textstr = '\n'.join((
|
829 |
+
r'Normal (aka Gaussian) Random Variable',
|
830 |
+
r'',
|
831 |
+
f'Parameter $\mu$: {mu}',
|
832 |
+
f'Parameter $\sigma$: {sigma}'
|
833 |
+
))
|
834 |
+
ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
|
835 |
+
verticalalignment='top', bbox=props)
|
836 |
+
|
837 |
+
return comparison_fig
|
838 |
+
return (create_distribution_comparison,)
|
839 |
+
|
840 |
+
|
841 |
+
@app.cell(hide_code=True)
|
842 |
+
def _(np, plt, stats):
|
843 |
+
def create_voltage_example_visualization():
|
844 |
+
|
845 |
+
# Create data for plotting
|
846 |
+
x = np.linspace(-4, 4, 1000)
|
847 |
+
|
848 |
+
# Signal without noise (X = 2)
|
849 |
+
signal_value = 2
|
850 |
+
|
851 |
+
# Noise distribution (Y ~ N(0, 1))
|
852 |
+
noise_pdf = stats.norm.pdf(x, 0, 1)
|
853 |
+
|
854 |
+
# Signal + Noise distribution (R = X + Y ~ N(2, 1))
|
855 |
+
received_pdf = stats.norm.pdf(x, signal_value, 1)
|
856 |
+
|
857 |
+
# Create figure
|
858 |
+
voltage_fig, ax = plt.subplots(figsize=(10, 6))
|
859 |
+
|
860 |
+
# Plot the noise distribution
|
861 |
+
ax.plot(x, noise_pdf, 'blue', linewidth=1.5, alpha=0.6,
|
862 |
+
label='Noise: Y ~ N(0, 1)')
|
863 |
+
|
864 |
+
# received signal distribution
|
865 |
+
ax.plot(x, received_pdf, 'red', linewidth=2,
|
866 |
+
label=f'Received: R ~ N({signal_value}, 1)')
|
867 |
+
|
868 |
+
# vertical line at the decision boundary (0.5)
|
869 |
+
threshold = 0.5
|
870 |
+
ax.axvline(x=threshold, color='green', linestyle='--', linewidth=2,
|
871 |
+
label=f'Decision threshold: {threshold}')
|
872 |
+
|
873 |
+
# Shade the error region
|
874 |
+
mask = x < threshold
|
875 |
+
error_prob = stats.norm.cdf(threshold, signal_value, 1)
|
876 |
+
ax.fill_between(x[mask], received_pdf[mask], color='darkorange', alpha=0.5,
|
877 |
+
label=f'Error probability: {error_prob:.3f}')
|
878 |
+
|
879 |
+
# Styling
|
880 |
+
ax.set_title('Voltage Transmission Example: Probability of Error')
|
881 |
+
ax.set_xlabel('Voltage')
|
882 |
+
ax.set_ylabel('Probability Density')
|
883 |
+
ax.legend(loc='upper left')
|
884 |
+
ax.grid(alpha=0.3)
|
885 |
+
|
886 |
+
# Add explanatory annotations
|
887 |
+
ax.text(1.5, 0.1, 'When sending "1" (voltage=2),\nthis area represents\nthe error probability',
|
888 |
+
bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1))
|
889 |
+
|
890 |
+
plt.tight_layout()
|
891 |
+
plt.gca()
|
892 |
+
return voltage_fig, error_prob
|
893 |
+
return (create_voltage_example_visualization,)
|
894 |
+
|
895 |
+
|
896 |
+
@app.cell(hide_code=True)
|
897 |
+
def _(np, plt, stats):
|
898 |
+
def create_cdf_calculator_plot(calc_x, calc_mu, calc_sigma):
|
899 |
+
|
900 |
+
# Data range for plotting
|
901 |
+
x_range = np.linspace(calc_mu - 4*calc_sigma, calc_mu + 4*calc_sigma, 1000)
|
902 |
+
pdf = stats.norm.pdf(x_range, calc_mu, calc_sigma)
|
903 |
+
cdf = stats.norm.cdf(x_range, calc_mu, calc_sigma)
|
904 |
+
|
905 |
+
# Calculate the CDF at x
|
906 |
+
cdf_at_x = stats.norm.cdf(calc_x, calc_mu, calc_sigma)
|
907 |
+
|
908 |
+
# Create figure with two subplots
|
909 |
+
calc_fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))
|
910 |
+
|
911 |
+
# Plot PDF on top subplot
|
912 |
+
ax1.plot(x_range, pdf, color='royalblue', linewidth=2, label='PDF')
|
913 |
+
|
914 |
+
# area shade for P(X ≤ x)
|
915 |
+
mask = x_range <= calc_x
|
916 |
+
ax1.fill_between(x_range[mask], pdf[mask], color='darkorange', alpha=0.6)
|
917 |
+
|
918 |
+
# Vertical line at x
|
919 |
+
ax1.axvline(x=calc_x, color='red', linestyle='--', linewidth=1.5)
|
920 |
+
|
921 |
+
# PDF labels and styling
|
922 |
+
ax1.set_title(f'Normal PDF with Area P(X ≤ {calc_x:.1f}) Highlighted')
|
923 |
+
ax1.set_xlabel('x')
|
924 |
+
ax1.set_ylabel('Probability Density')
|
925 |
+
ax1.annotate(f'x = {calc_x:.1f}', xy=(calc_x, 0), xytext=(calc_x, -0.01),
|
926 |
+
horizontalalignment='center', color='red')
|
927 |
+
ax1.grid(alpha=0.3)
|
928 |
+
|
929 |
+
# CDF on bottom subplot
|
930 |
+
ax2.plot(x_range, cdf, color='green', linewidth=2, label='CDF')
|
931 |
+
|
932 |
+
# Mark the point (x, CDF(x))
|
933 |
+
ax2.plot(calc_x, cdf_at_x, 'ro', markersize=8)
|
934 |
+
|
935 |
+
# CDF labels and styling
|
936 |
+
ax2.set_title(f'Normal CDF: F({calc_x:.1f}) = {cdf_at_x:.3f}')
|
937 |
+
ax2.set_xlabel('x')
|
938 |
+
ax2.set_ylabel('Cumulative Probability')
|
939 |
+
ax2.annotate(f'F({calc_x:.1f}) = {cdf_at_x:.3f}',
|
940 |
+
xy=(calc_x, cdf_at_x),
|
941 |
+
xytext=(calc_x + 0.5*calc_sigma, cdf_at_x - 0.1),
|
942 |
+
arrowprops=dict(facecolor='black', width=1, shrink=0.05),
|
943 |
+
bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1))
|
944 |
+
ax2.grid(alpha=0.3)
|
945 |
+
|
946 |
+
plt.tight_layout()
|
947 |
+
plt.gca()
|
948 |
+
return calc_fig, cdf_at_x
|
949 |
+
return (create_cdf_calculator_plot,)
|
950 |
+
|
951 |
+
|
952 |
+
@app.cell(hide_code=True)
|
953 |
+
def _(np, plt, stats):
|
954 |
+
def create_standardization_plot():
|
955 |
+
|
956 |
+
x = np.linspace(-6, 6, 1000)
|
957 |
+
|
958 |
+
# Original distribution N(2, 1.5²)
|
959 |
+
mu_original, sigma_original = 2, 1.5
|
960 |
+
pdf_original = stats.norm.pdf(x, mu_original, sigma_original)
|
961 |
+
|
962 |
+
# shifted distribution N(0, 1.5²)
|
963 |
+
mu_shifted, sigma_shifted = 0, 1.5
|
964 |
+
pdf_shifted = stats.norm.pdf(x, mu_shifted, sigma_shifted)
|
965 |
+
|
966 |
+
# Standard normal N(0, 1)
|
967 |
+
mu_standard, sigma_standard = 0, 1
|
968 |
+
pdf_standard = stats.norm.pdf(x, mu_standard, sigma_standard)
|
969 |
+
|
970 |
+
# Create visualization
|
971 |
+
stand_fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
972 |
+
|
973 |
+
# Plot on left: Original and shifted distributions
|
974 |
+
ax1.plot(x, pdf_original, 'royalblue', linewidth=2,
|
975 |
+
label=f'Original: N({mu_original}, {sigma_original}²)')
|
976 |
+
ax1.plot(x, pdf_shifted, 'darkorange', linewidth=2,
|
977 |
+
label=f'Shifted: N({mu_shifted}, {sigma_shifted}²)')
|
978 |
+
|
979 |
+
# Add arrow to show the shift
|
980 |
+
shift_x1, shift_y1 = mu_original, stats.norm.pdf(mu_original, mu_original, sigma_original)*0.6
|
981 |
+
shift_x2, shift_y2 = mu_shifted, stats.norm.pdf(mu_shifted, mu_shifted, sigma_shifted)*0.6
|
982 |
+
ax1.annotate('', xy=(shift_x2, shift_y2), xytext=(shift_x1, shift_y1),
|
983 |
+
arrowprops=dict(facecolor='black', width=1.5, shrink=0.05))
|
984 |
+
ax1.text(0.8, 0.28, 'Subtract μ', transform=ax1.transAxes)
|
985 |
+
|
986 |
+
# Plot on right: Shifted and standard normal
|
987 |
+
ax2.plot(x, pdf_shifted, 'darkorange', linewidth=2,
|
988 |
+
label=f'Shifted: N({mu_shifted}, {sigma_shifted}²)')
|
989 |
+
ax2.plot(x, pdf_standard, 'green', linewidth=2,
|
990 |
+
label=f'Standard: N({mu_standard}, {sigma_standard}²)')
|
991 |
+
|
992 |
+
# Add arrow to show the scaling
|
993 |
+
scale_x1, scale_y1 = 2*sigma_shifted, stats.norm.pdf(2*sigma_shifted, mu_shifted, sigma_shifted)*0.8
|
994 |
+
scale_x2, scale_y2 = 2*sigma_standard, stats.norm.pdf(2*sigma_standard, mu_standard, sigma_standard)*0.8
|
995 |
+
ax2.annotate('', xy=(scale_x2, scale_y2), xytext=(scale_x1, scale_y1),
|
996 |
+
arrowprops=dict(facecolor='black', width=1.5, shrink=0.05))
|
997 |
+
ax2.text(0.75, 0.5, 'Divide by σ', transform=ax2.transAxes)
|
998 |
+
|
999 |
+
# some styling
|
1000 |
+
for ax in (ax1, ax2):
|
1001 |
+
ax.set_xlabel('x')
|
1002 |
+
ax.set_ylabel('Probability Density')
|
1003 |
+
ax.grid(alpha=0.3)
|
1004 |
+
ax.legend()
|
1005 |
+
|
1006 |
+
ax1.set_title('Step 1: Shift the Distribution')
|
1007 |
+
ax2.set_title('Step 2: Scale the Distribution')
|
1008 |
+
|
1009 |
+
plt.tight_layout()
|
1010 |
+
plt.gca()
|
1011 |
+
return stand_fig
|
1012 |
+
return (create_standardization_plot,)
|
1013 |
+
|
1014 |
+
|
1015 |
+
@app.cell(hide_code=True)
|
1016 |
+
def _(np, plt, stats):
|
1017 |
+
def create_probability_example(example_mu=3, example_sigma=4, example_query=0):
|
1018 |
+
|
1019 |
+
# Create data range
|
1020 |
+
x = np.linspace(example_mu - 4*example_sigma, example_mu + 4*example_sigma, 1000)
|
1021 |
+
pdf = stats.norm.pdf(x, example_mu, example_sigma)
|
1022 |
+
|
1023 |
+
# probability calc
|
1024 |
+
prob_value = 1 - stats.norm.cdf(example_query, example_mu, example_sigma)
|
1025 |
+
ex_z_score = (example_query - example_mu) / example_sigma
|
1026 |
+
|
1027 |
+
# Create visualization
|
1028 |
+
prob_fig, ax = plt.subplots(figsize=(10, 6))
|
1029 |
+
|
1030 |
+
# Plot PDF
|
1031 |
+
ax.plot(x, pdf, 'royalblue', linewidth=2)
|
1032 |
+
|
1033 |
+
# area shading representing the probability
|
1034 |
+
mask = x >= example_query
|
1035 |
+
ax.fill_between(x[mask], pdf[mask], color='darkorange', alpha=0.6)
|
1036 |
+
|
1037 |
+
# Add vertical line at query point
|
1038 |
+
ax.axvline(x=example_query, color='red', linestyle='--', linewidth=1.5)
|
1039 |
+
|
1040 |
+
# Annotations
|
1041 |
+
ax.annotate(f'x = {example_query}', xy=(example_query, 0), xytext=(example_query, -0.005),
|
1042 |
+
horizontalalignment='center')
|
1043 |
+
|
1044 |
+
ax.annotate(f'P(X > {example_query}) = {prob_value:.3f}',
|
1045 |
+
xy=(example_query + example_sigma, 0.015),
|
1046 |
+
xytext=(example_query + 1.5*example_sigma, 0.02),
|
1047 |
+
arrowprops=dict(facecolor='black', width=1, shrink=0.05),
|
1048 |
+
bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1))
|
1049 |
+
|
1050 |
+
# Standard normal calculation annotation
|
1051 |
+
ax.annotate(f'= P(Z > {ex_z_score:.3f}) = {prob_value:.3f}',
|
1052 |
+
xy=(example_query - example_sigma, 0.01),
|
1053 |
+
xytext=(example_query - 2*example_sigma, 0.015),
|
1054 |
+
arrowprops=dict(facecolor='black', width=1, shrink=0.05),
|
1055 |
+
bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1))
|
1056 |
+
|
1057 |
+
# some styling
|
1058 |
+
ax.set_title(f'Example: P(X > {example_query}) where X ~ N({example_mu}, {example_sigma}²)')
|
1059 |
+
ax.set_xlabel('x')
|
1060 |
+
ax.set_ylabel('Probability Density')
|
1061 |
+
ax.grid(alpha=0.3)
|
1062 |
+
|
1063 |
+
plt.tight_layout()
|
1064 |
+
plt.gca()
|
1065 |
+
return prob_fig, prob_value, ex_z_score
|
1066 |
+
return (create_probability_example,)
|
1067 |
+
|
1068 |
+
|
1069 |
+
@app.cell(hide_code=True)
|
1070 |
+
def _(np, plt, stats):
|
1071 |
+
def create_range_probability_example(range_mu=3, range_sigma=4, range_lower=2, range_upper=5):
|
1072 |
+
|
1073 |
+
x = np.linspace(range_mu - 4*range_sigma, range_mu + 4*range_sigma, 1000)
|
1074 |
+
pdf = stats.norm.pdf(x, range_mu, range_sigma)
|
1075 |
+
|
1076 |
+
# probability
|
1077 |
+
range_prob = stats.norm.cdf(range_upper, range_mu, range_sigma) - stats.norm.cdf(range_lower, range_mu, range_sigma)
|
1078 |
+
range_z_lower = (range_lower - range_mu) / range_sigma
|
1079 |
+
range_z_upper = (range_upper - range_mu) / range_sigma
|
1080 |
+
|
1081 |
+
# Create visualization
|
1082 |
+
range_fig, ax = plt.subplots(figsize=(10, 6))
|
1083 |
+
|
1084 |
+
# Plot PDF
|
1085 |
+
ax.plot(x, pdf, 'royalblue', linewidth=2)
|
1086 |
+
|
1087 |
+
# Shade the area representing the probability
|
1088 |
+
mask = (x >= range_lower) & (x <= range_upper)
|
1089 |
+
ax.fill_between(x[mask], pdf[mask], color='darkorange', alpha=0.6)
|
1090 |
+
|
1091 |
+
# Add vertical lines at query points
|
1092 |
+
ax.axvline(x=range_lower, color='red', linestyle='--', linewidth=1.5)
|
1093 |
+
ax.axvline(x=range_upper, color='red', linestyle='--', linewidth=1.5)
|
1094 |
+
|
1095 |
+
# Annotations
|
1096 |
+
ax.annotate(f'x = {range_lower}', xy=(range_lower, 0), xytext=(range_lower, -0.005),
|
1097 |
+
horizontalalignment='center')
|
1098 |
+
ax.annotate(f'x = {range_upper}', xy=(range_upper, 0), xytext=(range_upper, -0.005),
|
1099 |
+
horizontalalignment='center')
|
1100 |
+
|
1101 |
+
ax.annotate(f'P({range_lower} < X < {range_upper}) = {range_prob:.3f}',
|
1102 |
+
xy=((range_lower + range_upper)/2, max(pdf[mask])/2),
|
1103 |
+
xytext=((range_lower + range_upper)/2, max(pdf[mask])*1.5),
|
1104 |
+
arrowprops=dict(facecolor='black', width=1, shrink=0.05),
|
1105 |
+
bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1),
|
1106 |
+
horizontalalignment='center')
|
1107 |
+
|
1108 |
+
# Standard normal calculation annotation
|
1109 |
+
ax.annotate(f'= P({range_z_lower:.3f} < Z < {range_z_upper:.3f}) = {range_prob:.3f}',
|
1110 |
+
xy=((range_lower + range_upper)/2, max(pdf[mask])/3),
|
1111 |
+
xytext=(range_mu - 2*range_sigma, max(pdf[mask])/1.5),
|
1112 |
+
arrowprops=dict(facecolor='black', width=1, shrink=0.05),
|
1113 |
+
bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1))
|
1114 |
+
|
1115 |
+
ax.set_title(f'Example: P({range_lower} < X < {range_upper}) where X ~ N({range_mu}, {range_sigma}²)')
|
1116 |
+
ax.set_xlabel('x')
|
1117 |
+
ax.set_ylabel('Probability Density')
|
1118 |
+
ax.grid(alpha=0.3)
|
1119 |
+
|
1120 |
+
plt.tight_layout()
|
1121 |
+
plt.gca()
|
1122 |
+
return range_fig, range_prob, range_z_lower, range_z_upper
|
1123 |
+
return (create_range_probability_example,)
|
1124 |
+
|
1125 |
+
|
1126 |
+
if __name__ == "__main__":
|
1127 |
+
app.run()
|
probability/18_central_limit_theorem.py
ADDED
@@ -0,0 +1,943 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.1",
|
6 |
+
# "scipy==1.15.2",
|
7 |
+
# "numpy==2.2.4",
|
8 |
+
# "plotly==5.18.0",
|
9 |
+
# ]
|
10 |
+
# ///
|
11 |
+
|
12 |
+
import marimo
|
13 |
+
|
14 |
+
__generated_with = "0.11.30"
|
15 |
+
app = marimo.App(width="medium", app_title="Central Limit Theorem")
|
16 |
+
|
17 |
+
|
18 |
+
@app.cell(hide_code=True)
|
19 |
+
def _(mo):
|
20 |
+
mo.md(
|
21 |
+
r"""
|
22 |
+
# Central Limit Theorem
|
23 |
+
|
24 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part4/clt/), by Stanford professor Chris Piech._
|
25 |
+
|
26 |
+
The Central Limit Theorem (CLT) is one of the most important concepts in probability theory and statistics. It explains why many real-world distributions tend to be normal, even when the underlying processes are not.
|
27 |
+
"""
|
28 |
+
)
|
29 |
+
return
|
30 |
+
|
31 |
+
|
32 |
+
@app.cell(hide_code=True)
|
33 |
+
def _(mo):
|
34 |
+
mo.md(
|
35 |
+
r"""
|
36 |
+
## Central Limit Theorem Statement
|
37 |
+
|
38 |
+
There are two ways to state the central limit theorem:
|
39 |
+
|
40 |
+
### Sum Version
|
41 |
+
|
42 |
+
Let $X_1, X_2, \dots, X_n$ be independent and identically distributed random variables. The sum of these random variables approaches a normal distribution as $n \rightarrow \infty$:
|
43 |
+
|
44 |
+
$n∑i=1Xi∼N(n⋅μ,n⋅σ2)\sum_{i=1}^{n}X_i \sim \mathcal{N}(n \cdot \mu, n \cdot \sigma^2)$
|
45 |
+
|
46 |
+
Where $\mu = E[X_i]$ and $\sigma^2 = \text{Var}(X_i)$. Since each $X_i$ is identically distributed, they share the same expectation and variance.
|
47 |
+
|
48 |
+
### Average Version
|
49 |
+
|
50 |
+
Let $X_1, X_2, \dots, X_n$ be independent and identically distributed random variables. The average of these random variables approaches a normal distribution as $n \rightarrow \infty$:
|
51 |
+
|
52 |
+
$\frac{1}{n} ∑i=1Xi∼N(μ,σ2n)\frac{1}{n}\sum_{i=1}^{n}X_i \sim \mathcal{N}(\mu, \frac{\sigma^2}{n})$
|
53 |
+
|
54 |
+
Where $\mu = E[X_i]$ and $\sigma^2 = \text{Var}(X_i)$.
|
55 |
+
|
56 |
+
The CLT is incredible because it applies to almost any distribution (as long as it has a finite mean and variance), regardless of its shape.
|
57 |
+
"""
|
58 |
+
)
|
59 |
+
return
|
60 |
+
|
61 |
+
|
62 |
+
@app.cell(hide_code=True)
|
63 |
+
def _(mo):
|
64 |
+
mo.md(
|
65 |
+
r"""
|
66 |
+
## Central Limit Theorem Intuition
|
67 |
+
|
68 |
+
Let's explore what happens when you add random variables together. For example, what if we add 100 different uniform random variables?
|
69 |
+
|
70 |
+
```python
|
71 |
+
from random import random
|
72 |
+
|
73 |
+
def add_100_uniforms():
|
74 |
+
total = 0
|
75 |
+
for i in range(100):
|
76 |
+
# returns a sample from uniform(0, 1)
|
77 |
+
x_i = random()
|
78 |
+
total += x_i
|
79 |
+
return total
|
80 |
+
```
|
81 |
+
|
82 |
+
The value returned by this function will be a random variable. Click the button below to run the function and observe the resulting value of total:
|
83 |
+
"""
|
84 |
+
)
|
85 |
+
return
|
86 |
+
|
87 |
+
|
88 |
+
@app.cell(hide_code=True)
|
89 |
+
def _(mo):
|
90 |
+
run_button = mo.ui.run_button(label="Run add_100_uniforms()")
|
91 |
+
|
92 |
+
run_button.center()
|
93 |
+
return (run_button,)
|
94 |
+
|
95 |
+
|
96 |
+
@app.cell(hide_code=True)
|
97 |
+
def _(mo, random, run_button):
|
98 |
+
def add_100_uniforms():
|
99 |
+
total = 0
|
100 |
+
for i in range(100):
|
101 |
+
# returns a sample from uniform(0, 1)
|
102 |
+
x_i = random.random()
|
103 |
+
total += x_i
|
104 |
+
return total
|
105 |
+
|
106 |
+
# Display the result when the button is clicked
|
107 |
+
if run_button.value:
|
108 |
+
uniform_result = add_100_uniforms()
|
109 |
+
display = mo.md(f"**total**: {uniform_result:.5f}")
|
110 |
+
else:
|
111 |
+
display = mo.md("")
|
112 |
+
|
113 |
+
display
|
114 |
+
return add_100_uniforms, display, uniform_result
|
115 |
+
|
116 |
+
|
117 |
+
@app.cell(hide_code=True)
|
118 |
+
def _(mo):
|
119 |
+
mo.md(r"""What does total look like as a distribution? Let's calculate total many times and visualize the histogram of values it produces.""")
|
120 |
+
return
|
121 |
+
|
122 |
+
|
123 |
+
@app.cell(hide_code=True)
|
124 |
+
def _(mo):
|
125 |
+
# Simulation control
|
126 |
+
run_simulation_button = mo.ui.button(
|
127 |
+
value=0,
|
128 |
+
on_click=lambda value: value + 1,
|
129 |
+
label="Run 10,000 more samples",
|
130 |
+
kind="warn"
|
131 |
+
)
|
132 |
+
|
133 |
+
run_simulation_button.center()
|
134 |
+
return (run_simulation_button,)
|
135 |
+
|
136 |
+
|
137 |
+
@app.cell(hide_code=True)
|
138 |
+
def _(add_100_uniforms, go, mo, np, run_simulation_button, stats, time):
|
139 |
+
# store the results
|
140 |
+
def get_simulation_results():
|
141 |
+
if not hasattr(get_simulation_results, "results"):
|
142 |
+
get_simulation_results.results = []
|
143 |
+
get_simulation_results.last_button_value = -1 # track button clicks
|
144 |
+
return get_simulation_results
|
145 |
+
|
146 |
+
# grab the results
|
147 |
+
sim_storage = get_simulation_results()
|
148 |
+
simulation_results = sim_storage.results
|
149 |
+
|
150 |
+
# Check if button was clicked (value changed)
|
151 |
+
if run_simulation_button.value != sim_storage.last_button_value:
|
152 |
+
# Update the last seen button value
|
153 |
+
sim_storage.last_button_value = run_simulation_button.value
|
154 |
+
|
155 |
+
with mo.status.spinner(title="Running simulation...") as progress_status:
|
156 |
+
sim_count = 10000
|
157 |
+
new_results = []
|
158 |
+
for _ in mo.status.progress_bar(range(sim_count)):
|
159 |
+
sim_result = add_100_uniforms()
|
160 |
+
new_results.append(sim_result)
|
161 |
+
time.sleep(0.0001) # tiny pause
|
162 |
+
|
163 |
+
simulation_results.extend(new_results)
|
164 |
+
|
165 |
+
progress_status.update(f"✅ Added {sim_count:,} samples (total: {len(simulation_results):,})")
|
166 |
+
|
167 |
+
if simulation_results:
|
168 |
+
# Numbers
|
169 |
+
mean = np.mean(simulation_results)
|
170 |
+
std_dev = np.std(simulation_results)
|
171 |
+
|
172 |
+
theoretical_mean = 100 * 0.5 # = 50
|
173 |
+
theoretical_variance = 100 * (1/12) # = 8.33...
|
174 |
+
theoretical_std = np.sqrt(theoretical_variance) # ≈ 2.89
|
175 |
+
|
176 |
+
# should be 10k times the click number (mainly for the y-axis label)
|
177 |
+
total_samples = run_simulation_button.value * 10000
|
178 |
+
|
179 |
+
fig = go.Figure()
|
180 |
+
|
181 |
+
# histogram of samples
|
182 |
+
fig.add_trace(go.Histogram(
|
183 |
+
x=simulation_results,
|
184 |
+
histnorm='probability density',
|
185 |
+
name='Sum Distribution',
|
186 |
+
marker_color='royalblue',
|
187 |
+
opacity=0.7
|
188 |
+
))
|
189 |
+
|
190 |
+
x_vals = np.linspace(min(simulation_results), max(simulation_results), 1000)
|
191 |
+
y_vals = stats.norm.pdf(x_vals, theoretical_mean, theoretical_std)
|
192 |
+
|
193 |
+
fig.add_trace(go.Scatter(
|
194 |
+
x=x_vals,
|
195 |
+
y=y_vals,
|
196 |
+
mode='lines',
|
197 |
+
name='Normal approximation',
|
198 |
+
line=dict(color='red', width=2)
|
199 |
+
))
|
200 |
+
|
201 |
+
fig.add_vline(
|
202 |
+
x=mean,
|
203 |
+
line_dash="dash",
|
204 |
+
line_width=1.5,
|
205 |
+
line_color="green",
|
206 |
+
annotation_text=f"Sample Mean: {mean:.2f}",
|
207 |
+
annotation_position="top right"
|
208 |
+
)
|
209 |
+
|
210 |
+
# some notes
|
211 |
+
fig.add_annotation(
|
212 |
+
x=0.02, y=0.95,
|
213 |
+
xref="paper", yref="paper",
|
214 |
+
text=f"Sum of 100 Uniform(0,1) variables<br>" +
|
215 |
+
f"Sample size: {total_samples:,}<br>" +
|
216 |
+
f"Sample mean: {mean:.2f} (expected: {theoretical_mean})<br>" +
|
217 |
+
f"Sample std: {std_dev:.2f} (expected: {theoretical_std:.2f})<br>" +
|
218 |
+
f"According to CLT: Normal({theoretical_mean}, {theoretical_variance:.2f})",
|
219 |
+
showarrow=False,
|
220 |
+
align="left",
|
221 |
+
bgcolor="white",
|
222 |
+
opacity=0.8
|
223 |
+
)
|
224 |
+
|
225 |
+
fig.update_layout(
|
226 |
+
title=f'Distribution of Sum of 100 Uniforms (Click #{run_simulation_button.value})',
|
227 |
+
xaxis_title='Values',
|
228 |
+
yaxis_title=f'Probability Density ({total_samples:,} runs)',
|
229 |
+
template='plotly_white',
|
230 |
+
height=500
|
231 |
+
)
|
232 |
+
|
233 |
+
# show
|
234 |
+
histogram = mo.ui.plotly(fig)
|
235 |
+
else:
|
236 |
+
histogram = mo.md("Click the button to run the simulation!")
|
237 |
+
|
238 |
+
# display
|
239 |
+
histogram
|
240 |
+
return (
|
241 |
+
fig,
|
242 |
+
get_simulation_results,
|
243 |
+
histogram,
|
244 |
+
mean,
|
245 |
+
new_results,
|
246 |
+
progress_status,
|
247 |
+
sim_count,
|
248 |
+
sim_result,
|
249 |
+
sim_storage,
|
250 |
+
simulation_results,
|
251 |
+
std_dev,
|
252 |
+
theoretical_mean,
|
253 |
+
theoretical_std,
|
254 |
+
theoretical_variance,
|
255 |
+
total_samples,
|
256 |
+
x_vals,
|
257 |
+
y_vals,
|
258 |
+
)
|
259 |
+
|
260 |
+
|
261 |
+
@app.cell(hide_code=True)
|
262 |
+
def _(mo):
|
263 |
+
mo.md(
|
264 |
+
r"""
|
265 |
+
That is interesting! The sum of 100 independent uniforms looks normal. Is that a special property of uniforms? No! It turns out to work for almost any type of distribution (as long as the distribution has finite mean and variance).
|
266 |
+
|
267 |
+
- Sum of 40 $X_i$ where $X_i \sim \text{Beta}(a = 5, b = 4)$? Normal.
|
268 |
+
- Sum of 90 $X_i$ where $X_i \sim \text{Poisson}(\lambda = 4)$? Normal.
|
269 |
+
- Sum of 50 dice-rolls? Normal.
|
270 |
+
- Average of 10000 $X_i$ where $X_i \sim \text{Exp}(\lambda = 8)$? Normal.
|
271 |
+
|
272 |
+
For any distribution, the sum or average of a sufficiently large number of independent, identically distributed random variables will be approximately normally distributed.
|
273 |
+
"""
|
274 |
+
)
|
275 |
+
return
|
276 |
+
|
277 |
+
|
278 |
+
@app.cell(hide_code=True)
|
279 |
+
def _(mo):
|
280 |
+
mo.md(
|
281 |
+
r"""
|
282 |
+
## Continuity Correction
|
283 |
+
|
284 |
+
When using the Central Limit Theorem with discrete random variables (like a Binomial or Poisson), we need to apply a continuity correction. This is because we're approximating a discrete distribution with a continuous one (normal).
|
285 |
+
|
286 |
+
The continuity correction involves adjusting the boundaries in probability calculations by ±0.5 to account for the discrete nature of the original variable.
|
287 |
+
|
288 |
+
You should use a continuity correction any time your normal is approximating a discrete random variable. The rules for a general continuity correction are the same as the rules for the [binomial-approximation continuity correction](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/14_binomial_distribution.py).
|
289 |
+
|
290 |
+
In our example above, where we added 100 uniforms, a continuity correction isn't needed because the sum of uniforms is continuous. However, in examples with dice or other discrete distributions, a continuity correction would be necessary.
|
291 |
+
"""
|
292 |
+
)
|
293 |
+
return
|
294 |
+
|
295 |
+
|
296 |
+
@app.cell(hide_code=True)
|
297 |
+
def _(mo):
|
298 |
+
mo.md(
|
299 |
+
r"""
|
300 |
+
## Examples
|
301 |
+
|
302 |
+
Let's work through some practical examples to see how the Central Limit Theorem is applied.
|
303 |
+
"""
|
304 |
+
)
|
305 |
+
return
|
306 |
+
|
307 |
+
|
308 |
+
@app.cell(hide_code=True)
|
309 |
+
def _(mo):
|
310 |
+
mo.md(
|
311 |
+
r"""
|
312 |
+
### Example 1: Dice Game
|
313 |
+
|
314 |
+
You will roll a 6-sided dice 10 times. Let $X$ be the total value of all 10 dice: $X = X_1 + X_2 + \dots + X_{10}$. You win the game if $X \leq 25$ or $X \geq 45$. Use the central limit theorem to calculate the probability that you win.
|
315 |
+
|
316 |
+
Recall that for a single die roll $X_i$:
|
317 |
+
|
318 |
+
- $E[X_i] = 3.5$
|
319 |
+
- $\text{Var}(X_i) = \frac{35}{12}$
|
320 |
+
|
321 |
+
**Solution:**
|
322 |
+
|
323 |
+
Let $Y$ be the approximating normal distribution. By the Central Limit Theorem:
|
324 |
+
|
325 |
+
$Y∼N(10⋅E[Xi],10⋅Var(Xi))Y \sim \mathcal{N}(10 \cdot E[X_i], 10 \cdot \text{Var}(X_i))$
|
326 |
+
|
327 |
+
Substituting in the known values:
|
328 |
+
|
329 |
+
$Y∼N(10⋅3.5,10⋅3512)=N(35,29.2)Y \sim \mathcal{N}(10 \cdot 3.5, 10 \cdot \frac{35}{12}) = \mathcal{N}(35, 29.2)$
|
330 |
+
|
331 |
+
Now we calculate the probability:
|
332 |
+
|
333 |
+
$P(X≤25 or X≥45)P(X \leq 25 \text{ or } X \geq 45)$
|
334 |
+
|
335 |
+
$=P(X≤25)+P(X≥45)= P(X \leq 25) + P(X \geq 45)$
|
336 |
+
|
337 |
+
$≈P(Y<25.5)+P(Y>44.5) (Continuity Correction)\approx P(Y < 25.5) + P(Y > 44.5) \text{ (Continuity Correction)}$
|
338 |
+
|
339 |
+
$≈P(Y<25.5)+[1−P(Y<44.5)]\approx P(Y < 25.5) + [1 - P(Y < 44.5)]$
|
340 |
+
|
341 |
+
$≈Φ(25.5−35√29.2)+[1−Φ(44.5−35√29.2)]\approx \Phi\left(\frac{25.5 - 35}{\sqrt{29.2}}\right) + \left[1 - \Phi\left(\frac{44.5 - 35}{\sqrt{29.2}}\right)\right]$
|
342 |
+
|
343 |
+
$≈Φ(−1.76)+[1−Φ(1.76)]\approx \Phi(-1.76) + [1 - \Phi(1.76)]$
|
344 |
+
|
345 |
+
$≈0.039+(1−0.961)\approx 0.039 + (1 - 0.961)$
|
346 |
+
|
347 |
+
$≈0.078\approx 0.078$
|
348 |
+
So, the probability of winning the game is approximately 7.8%.
|
349 |
+
"""
|
350 |
+
)
|
351 |
+
return
|
352 |
+
|
353 |
+
|
354 |
+
@app.cell(hide_code=True)
|
355 |
+
def _(create_dice_game_visualization, fig_to_image, mo):
|
356 |
+
# Display visualization
|
357 |
+
dice_game_fig = create_dice_game_visualization()
|
358 |
+
dice_game_image = mo.image(fig_to_image(dice_game_fig), width="100%")
|
359 |
+
|
360 |
+
dice_explanation = mo.md(
|
361 |
+
r"""
|
362 |
+
**Visualization Explanation:**
|
363 |
+
|
364 |
+
The graph shows the distribution of the sum of 10 dice rolls. The blue bars represent the actual probability mass function (PMF), while the red curve shows the normal approximation using the Central Limit Theorem.
|
365 |
+
|
366 |
+
The winning regions are shaded in orange:
|
367 |
+
- The left region where $X \leq 25$
|
368 |
+
- The right region where $X \geq 45$
|
369 |
+
|
370 |
+
The total probability of these regions is approximately 0.078 or 7.8%.
|
371 |
+
|
372 |
+
Notice how the normal approximation provides a good fit to the discrete distribution, demonstrating the power of the Central Limit Theorem.
|
373 |
+
"""
|
374 |
+
)
|
375 |
+
|
376 |
+
mo.vstack([dice_game_image, dice_explanation])
|
377 |
+
return dice_explanation, dice_game_fig, dice_game_image
|
378 |
+
|
379 |
+
|
380 |
+
@app.cell(hide_code=True)
|
381 |
+
def _(mo):
|
382 |
+
mo.md(
|
383 |
+
r"""
|
384 |
+
### Example 2: Algorithm Runtime Estimation
|
385 |
+
|
386 |
+
Say you have a new algorithm and you want to test its running time. You know the variance of the algorithm's run time is $\sigma^2 = 4 \text{ sec}^2$, but you want to estimate the mean run time $t$ in seconds.
|
387 |
+
|
388 |
+
You can run the algorithm repeatedly (IID trials). How many trials do you have to run so that your estimated runtime is within ±0.5 seconds of $t$ with 95% certainty?
|
389 |
+
|
390 |
+
Let $X_i$ be the run time of the $i$-th run (for $1 \leq i \leq n$).
|
391 |
+
|
392 |
+
**Solution:**
|
393 |
+
|
394 |
+
We need to find $n$ such that:
|
395 |
+
|
396 |
+
$0.95=P(−0.5≤∑ni=1Xin−t≤0.5)0.95 = P\left(-0.5 \leq \frac{\sum_{i=1}^n X_i}{n} - t \leq 0.5\right)$
|
397 |
+
|
398 |
+
By the central limit theorem, the sample mean follows a normal distribution.
|
399 |
+
We can standardize this to work with the standard normal:
|
400 |
+
|
401 |
+
$Z=(∑ni=1Xi)−nμσ√nZ = \frac{\left(\sum_{i=1}^n X_i\right) - n\mu}{\sigma \sqrt{n}}$
|
402 |
+
|
403 |
+
$=(∑ni=1Xi)−nt2√n= \frac{\left(\sum_{i=1}^n X_i\right) - nt}{2 \sqrt{n}}$
|
404 |
+
|
405 |
+
Rewriting our probability inequality so that the central term is $Z$:
|
406 |
+
|
407 |
+
$0.95=P(−0.5≤∑ni=1Xin−t≤0.5)0.95 = P\left(-0.5 \leq \frac{\sum_{i=1}^n X_i}{n} - t \leq 0.5\right)$
|
408 |
+
|
409 |
+
$=P(−0.5√n2≤Z≤0.5√n2)= P\left(\frac{-0.5 \sqrt{n}}{2} \leq Z \leq \frac{0.5 \sqrt{n}}{2}\right)$
|
410 |
+
|
411 |
+
And now we find the value of $n$ that makes this equation hold:
|
412 |
+
|
413 |
+
$0.95=Φ(√n4)−Φ(−√n4)0.95 = \Phi\left(\frac{\sqrt{n}}{4}\right) - \Phi\left(-\frac{\sqrt{n}}{4}\right)$
|
414 |
+
|
415 |
+
$4=Φ(√n4)−(1−Φ(√n4))= \Phi\left(\frac{\sqrt{n}}{4}\right) - \left(1 - \Phi\left(\frac{\sqrt{n}}{4}\right)\right)$
|
416 |
+
|
417 |
+
$=2Φ(√n4)−1= 2\Phi\left(\frac{\sqrt{n}}{4}\right) - 1$
|
418 |
+
|
419 |
+
Solving for $\Phi\left(\frac{\sqrt{n}}{4}\right)$:
|
420 |
+
|
421 |
+
$0.975=Φ(√n4)0.975 = \Phi\left(\frac{\sqrt{n}}{4}\right)$
|
422 |
+
|
423 |
+
$Φ−1(0.975)=√n4\Phi^{-1}(0.975) = \frac{\sqrt{n}}{4}$
|
424 |
+
|
425 |
+
$1.96=√n41.96 = \frac{\sqrt{n}}{4}$
|
426 |
+
|
427 |
+
$n=61.4n = 61.4$
|
428 |
+
|
429 |
+
Therefore, we need to run the algorithm 62 times to estimate the mean runtime within ±0.5 seconds with 95% confidence.
|
430 |
+
"""
|
431 |
+
)
|
432 |
+
return
|
433 |
+
|
434 |
+
|
435 |
+
@app.cell(hide_code=True)
|
436 |
+
def _(create_algorithm_runtime_visualization, fig_to_image, mo):
|
437 |
+
# Display visualization
|
438 |
+
runtime_fig = create_algorithm_runtime_visualization()
|
439 |
+
runtime_image = mo.image(fig_to_image(runtime_fig), width="100%")
|
440 |
+
|
441 |
+
runtime_explanation = mo.md(
|
442 |
+
r"""
|
443 |
+
**Visualization Explanation:**
|
444 |
+
|
445 |
+
The graph illustrates how the standard error of the mean (SEM) decreases as the number of trials increases. The standard error is calculated as $\frac{\sigma}{\sqrt{n}}$.
|
446 |
+
|
447 |
+
- When we conduct 62 trials, the standard error is approximately 0.254 seconds.
|
448 |
+
- With a 95% confidence level, this gives us a margin of error of about ±0.5 seconds (1.96 × 0.254 ≈ 0.5).
|
449 |
+
- The shaded region shows how the confidence interval narrows as the number of trials increases.
|
450 |
+
|
451 |
+
This demonstrates why 62 trials are sufficient to meet our requirements of estimating the mean runtime within ±0.5 seconds with 95% confidence.
|
452 |
+
"""
|
453 |
+
)
|
454 |
+
|
455 |
+
mo.vstack([runtime_image, runtime_explanation])
|
456 |
+
return runtime_explanation, runtime_fig, runtime_image
|
457 |
+
|
458 |
+
|
459 |
+
@app.cell(hide_code=True)
|
460 |
+
def _(mo):
|
461 |
+
mo.md(
|
462 |
+
r"""
|
463 |
+
## Interactive CLT Explorer
|
464 |
+
|
465 |
+
Let's explore how the Central Limit Theorem works with different underlying distributions. You can select a distribution type and see how the distribution of the sample mean changes as the sample size increases.
|
466 |
+
"""
|
467 |
+
)
|
468 |
+
return
|
469 |
+
|
470 |
+
|
471 |
+
@app.cell(hide_code=True)
|
472 |
+
def _(controls):
|
473 |
+
controls
|
474 |
+
return
|
475 |
+
|
476 |
+
|
477 |
+
@app.cell(hide_code=True)
|
478 |
+
def _(
|
479 |
+
distribution_type,
|
480 |
+
fig_to_image,
|
481 |
+
mo,
|
482 |
+
np,
|
483 |
+
plt,
|
484 |
+
run_explorer_button,
|
485 |
+
sample_size,
|
486 |
+
sim_count_slider,
|
487 |
+
stats,
|
488 |
+
):
|
489 |
+
# Run simulation when button is clicked
|
490 |
+
if run_explorer_button.value:
|
491 |
+
# Set distribution parameters based on selection
|
492 |
+
if distribution_type.value == "uniform":
|
493 |
+
dist_name = "Uniform(0, 1)"
|
494 |
+
# For uniform(0,1): mean = 0.5, variance = 1/12
|
495 |
+
true_mean = 0.5
|
496 |
+
true_var = 1/12
|
497 |
+
|
498 |
+
# generate samples
|
499 |
+
def generate_sample():
|
500 |
+
return np.random.uniform(0, 1, sample_size.value)
|
501 |
+
|
502 |
+
elif distribution_type.value == "exponential":
|
503 |
+
rate = 1.0
|
504 |
+
dist_name = f"Exponential(λ={rate})"
|
505 |
+
# For exponential(λ): mean = 1/λ, variance = 1/λ²
|
506 |
+
true_mean = 1/rate
|
507 |
+
true_var = 1/(rate**2)
|
508 |
+
|
509 |
+
def generate_sample():
|
510 |
+
return np.random.exponential(1/rate, sample_size.value)
|
511 |
+
|
512 |
+
elif distribution_type.value == "binomial":
|
513 |
+
n_param, p = 10, 0.3
|
514 |
+
dist_name = f"Binomial(n={n_param}, p={p})"
|
515 |
+
# For binomial(n,p): mean = np, variance = np(1-p)
|
516 |
+
true_mean = n_param * p
|
517 |
+
true_var = n_param * p * (1-p)
|
518 |
+
|
519 |
+
def generate_sample():
|
520 |
+
return np.random.binomial(n_param, p, sample_size.value)
|
521 |
+
|
522 |
+
elif distribution_type.value == "poisson":
|
523 |
+
rate = 3.0
|
524 |
+
dist_name = f"Poisson(λ={rate})"
|
525 |
+
# For poisson(λ): mean = λ, variance = λ
|
526 |
+
true_mean = rate
|
527 |
+
true_var = rate
|
528 |
+
|
529 |
+
def generate_sample():
|
530 |
+
return np.random.poisson(rate, sample_size.value)
|
531 |
+
|
532 |
+
# Generate the simulation data using a spinner for progress
|
533 |
+
with mo.status.spinner(title="Running simulation...") as explorer_progress:
|
534 |
+
sample_means = []
|
535 |
+
original_samples = []
|
536 |
+
|
537 |
+
# Run simulations
|
538 |
+
for _ in mo.status.progress_bar(range(sim_count_slider.value)):
|
539 |
+
sample = generate_sample()
|
540 |
+
|
541 |
+
# Store the first simulation's individual values for visualizing original distribution
|
542 |
+
if len(original_samples) < 1000: # limit to prevent memory issues
|
543 |
+
original_samples.extend(sample)
|
544 |
+
|
545 |
+
# sample mean
|
546 |
+
sample_means.append(np.mean(sample))
|
547 |
+
|
548 |
+
# progress
|
549 |
+
explorer_progress.update(f"✅ Completed {sim_count_slider.value:,} simulations")
|
550 |
+
|
551 |
+
# Create visualization
|
552 |
+
explorer_fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
553 |
+
|
554 |
+
# Original distribution histogram
|
555 |
+
ax1.hist(original_samples, bins=30, density=True, alpha=0.7, color='royalblue')
|
556 |
+
ax1.set_title(f"Original Distribution: {dist_name}")
|
557 |
+
|
558 |
+
# Theoretical mean line
|
559 |
+
ax1.axvline(x=true_mean, color='red', linestyle='--',
|
560 |
+
label=f'True Mean = {true_mean:.3f}')
|
561 |
+
|
562 |
+
ax1.set_xlabel("Value")
|
563 |
+
ax1.set_ylabel("Density")
|
564 |
+
ax1.legend()
|
565 |
+
|
566 |
+
# Sample means histogram and normal approximation
|
567 |
+
sample_mean_mean = np.mean(sample_means)
|
568 |
+
sample_mean_std = np.std(sample_means)
|
569 |
+
expected_std = np.sqrt(true_var / sample_size.value) # CLT prediction
|
570 |
+
|
571 |
+
ax2.hist(sample_means, bins=30, density=True, alpha=0.7, color='forestgreen',
|
572 |
+
label=f'Sample Size = {sample_size.value}')
|
573 |
+
|
574 |
+
# Normal approximation from CLT
|
575 |
+
explorer_x = np.linspace(min(sample_means), max(sample_means), 1000)
|
576 |
+
explorer_y = stats.norm.pdf(explorer_x, true_mean, expected_std)
|
577 |
+
ax2.plot(explorer_x, explorer_y, 'r-', linewidth=2, label='CLT Normal Approximation')
|
578 |
+
|
579 |
+
# Add mean line
|
580 |
+
ax2.axvline(x=true_mean, color='purple', linestyle='--',
|
581 |
+
label=f'True Mean = {true_mean:.3f}')
|
582 |
+
|
583 |
+
ax2.set_title(f"Distribution of Sample Means\n(CLT Prediction: N({true_mean:.3f}, {true_var/sample_size.value:.5f}))")
|
584 |
+
ax2.set_xlabel("Sample Mean")
|
585 |
+
ax2.set_ylabel("Density")
|
586 |
+
ax2.legend()
|
587 |
+
|
588 |
+
# Add CLT description
|
589 |
+
explorer_fig.text(0.5, 0.01,
|
590 |
+
f"Central Limit Theorem: As sample size increases, the distribution of sample means approaches\n" +
|
591 |
+
f"a normal distribution with mean = {true_mean:.3f} and variance = {true_var:.3f}/{sample_size.value} = {true_var/sample_size.value:.5f}",
|
592 |
+
ha='center', fontsize=10, bbox=dict(facecolor='white', alpha=0.8))
|
593 |
+
|
594 |
+
plt.tight_layout(rect=[0, 0.05, 1, 1])
|
595 |
+
|
596 |
+
# Display plot
|
597 |
+
explorer_image = mo.image(fig_to_image(explorer_fig), width="100%")
|
598 |
+
else:
|
599 |
+
explorer_image = mo.md("Click the 'Run Simulation' button to see how the Central Limit Theorem works.")
|
600 |
+
|
601 |
+
explorer_image
|
602 |
+
return (
|
603 |
+
ax1,
|
604 |
+
ax2,
|
605 |
+
dist_name,
|
606 |
+
expected_std,
|
607 |
+
explorer_fig,
|
608 |
+
explorer_image,
|
609 |
+
explorer_progress,
|
610 |
+
explorer_x,
|
611 |
+
explorer_y,
|
612 |
+
generate_sample,
|
613 |
+
n_param,
|
614 |
+
original_samples,
|
615 |
+
p,
|
616 |
+
rate,
|
617 |
+
sample,
|
618 |
+
sample_mean_mean,
|
619 |
+
sample_mean_std,
|
620 |
+
sample_means,
|
621 |
+
true_mean,
|
622 |
+
true_var,
|
623 |
+
)
|
624 |
+
|
625 |
+
|
626 |
+
@app.cell(hide_code=True)
|
627 |
+
def _(mo):
|
628 |
+
mo.md(
|
629 |
+
r"""
|
630 |
+
## 🤔 Test Your Understanding
|
631 |
+
|
632 |
+
/// details | What is the shape of the distribution of the sum of many independent random variables?
|
633 |
+
The sum of many independent random variables approaches a normal distribution, regardless of the shape of the original distributions (as long as they have finite mean and variance). This is the essence of the Central Limit Theorem.
|
634 |
+
///
|
635 |
+
|
636 |
+
/// details | If $X_1, X_2, \dots, X_{100}$ are IID random variables with $E[X_i] = 5$ and $Var(X_i) = 9$, what is the distribution of their sum?
|
637 |
+
By the Central Limit Theorem, the sum $S = X_1 + X_2 + \dots + X_{100}$ follows a normal distribution with:
|
638 |
+
|
639 |
+
- Mean: $E[S] = 100 \cdot E[X_i] = 100 \cdot 5 = 500$
|
640 |
+
- Variance: $Var(S) = 100 \cdot Var(X_i) = 100 \cdot 9 = 900$
|
641 |
+
|
642 |
+
Therefore, $S \sim \mathcal{N}(500, 900)$, or equivalently $S \sim \mathcal{N}(500, 30^2)$.
|
643 |
+
///
|
644 |
+
|
645 |
+
/// details | When do you need to apply a continuity correction when using the Central Limit Theorem?
|
646 |
+
You need to apply a continuity correction when you're using the normal approximation (through CLT) for a discrete random variable.
|
647 |
+
|
648 |
+
For example, when approximating a binomial or Poisson distribution with a normal distribution, you should adjust boundaries by ±0.5 to account for the discrete nature of the original variable. This makes the approximation more accurate.
|
649 |
+
///
|
650 |
+
|
651 |
+
/// details | If $X_1, X_2, \dots, X_{n}$ are IID random variables, how does the variance of their sample mean $\bar{X} = \frac{1}{n}\sum_{i=1}^{n}X_i$ change as $n$ increases?
|
652 |
+
The variance of the sample mean decreases as the sample size $n$ increases. Specifically:
|
653 |
+
|
654 |
+
$Var(\bar{X}) = \frac{Var(X_i)}{n}$
|
655 |
+
|
656 |
+
This means that as we take more samples, the sample mean becomes more concentrated around the true mean of the distribution. This is why larger samples give more precise estimates.
|
657 |
+
///
|
658 |
+
|
659 |
+
/// details | Why is the Central Limit Theorem so important in statistics?
|
660 |
+
The Central Limit Theorem is foundational in statistics because:
|
661 |
+
|
662 |
+
1. It allows us to make inferences about population parameters using sample statistics, regardless of the population's distribution.
|
663 |
+
2. It explains why the normal distribution appears so frequently in natural phenomena.
|
664 |
+
3. It enables the construction of confidence intervals and hypothesis tests for means, even when the underlying population distribution is unknown.
|
665 |
+
4. It justifies many statistical methods that assume normality, even when working with non-normal data, provided the sample size is large enough.
|
666 |
+
|
667 |
+
In essence, the CLT provides the theoretical justification for much of statistical inference.
|
668 |
+
///
|
669 |
+
"""
|
670 |
+
)
|
671 |
+
return
|
672 |
+
|
673 |
+
|
674 |
+
@app.cell(hide_code=True)
|
675 |
+
def _(mo):
|
676 |
+
mo.md(r"""## Appendix (helper code and functions)""")
|
677 |
+
return
|
678 |
+
|
679 |
+
|
680 |
+
@app.cell
|
681 |
+
def _():
|
682 |
+
import marimo as mo
|
683 |
+
return (mo,)
|
684 |
+
|
685 |
+
|
686 |
+
@app.cell(hide_code=True)
|
687 |
+
def _():
|
688 |
+
from wigglystuff import TangleSlider
|
689 |
+
return (TangleSlider,)
|
690 |
+
|
691 |
+
|
692 |
+
@app.cell(hide_code=True)
|
693 |
+
def _():
|
694 |
+
# Import libraries
|
695 |
+
import numpy as np
|
696 |
+
import matplotlib.pyplot as plt
|
697 |
+
from scipy import stats
|
698 |
+
import io
|
699 |
+
import base64
|
700 |
+
import random
|
701 |
+
import time
|
702 |
+
import plotly.graph_objects as go
|
703 |
+
import plotly.io as pio
|
704 |
+
return base64, go, io, np, pio, plt, random, stats, time
|
705 |
+
|
706 |
+
|
707 |
+
@app.cell(hide_code=True)
|
708 |
+
def _(base64, io):
|
709 |
+
from matplotlib.figure import Figure
|
710 |
+
|
711 |
+
# Helper function to convert matplotlib figures to images
|
712 |
+
def fig_to_image(fig):
|
713 |
+
buf = io.BytesIO()
|
714 |
+
fig.savefig(buf, format='png', bbox_inches='tight')
|
715 |
+
buf.seek(0)
|
716 |
+
img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
|
717 |
+
return f"data:image/png;base64,{img_str}"
|
718 |
+
return Figure, fig_to_image
|
719 |
+
|
720 |
+
|
721 |
+
@app.cell(hide_code=True)
|
722 |
+
def _(np, plt, stats):
|
723 |
+
def create_dice_game_visualization():
|
724 |
+
"""Create a visualization for the dice game example."""
|
725 |
+
# Parameters
|
726 |
+
n_dice = 10
|
727 |
+
dice_values = np.arange(1, 7) # 1 to 6
|
728 |
+
|
729 |
+
# Theoretical values
|
730 |
+
single_die_mean = np.mean(dice_values) # 3.5
|
731 |
+
single_die_var = np.var(dice_values) # 35/12
|
732 |
+
|
733 |
+
# Sum distribution parameters
|
734 |
+
sum_mean = n_dice * single_die_mean
|
735 |
+
sum_var = n_dice * single_die_var
|
736 |
+
sum_std = np.sqrt(sum_var)
|
737 |
+
|
738 |
+
# Possible outcomes for the sum of 10 dice
|
739 |
+
min_sum = n_dice * min(dice_values) # 10
|
740 |
+
max_sum = n_dice * max(dice_values) # 60
|
741 |
+
sum_values = np.arange(min_sum, max_sum + 1)
|
742 |
+
|
743 |
+
# Create figure
|
744 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
745 |
+
|
746 |
+
# Calculate PMF through convolution
|
747 |
+
# For one die
|
748 |
+
single_pmf = np.ones(6) / 6
|
749 |
+
|
750 |
+
sum_pmf = single_pmf.copy()
|
751 |
+
for _ in range(n_dice - 1):
|
752 |
+
sum_pmf = np.convolve(sum_pmf, single_pmf)
|
753 |
+
|
754 |
+
# Plot the PMF
|
755 |
+
ax.bar(sum_values, sum_pmf, alpha=0.7, color='royalblue', label='Exact PMF')
|
756 |
+
|
757 |
+
# Normal approximation
|
758 |
+
x = np.linspace(min_sum - 5, max_sum + 5, 1000)
|
759 |
+
y = stats.norm.pdf(x, sum_mean, sum_std)
|
760 |
+
ax.plot(x, y, 'r-', linewidth=2, label='Normal Approximation')
|
761 |
+
|
762 |
+
# Win conditions (x ≤ 25 or x ≥ 45)
|
763 |
+
win_region_left = sum_values <= 25
|
764 |
+
win_region_right = sum_values >= 45
|
765 |
+
|
766 |
+
# Shade win regions
|
767 |
+
ax.bar(sum_values[win_region_left], sum_pmf[win_region_left],
|
768 |
+
color='darkorange', alpha=0.7, label='Win Region')
|
769 |
+
ax.bar(sum_values[win_region_right], sum_pmf[win_region_right],
|
770 |
+
color='darkorange', alpha=0.7)
|
771 |
+
|
772 |
+
# Calculate win probability
|
773 |
+
win_prob = np.sum(sum_pmf[win_region_left]) + np.sum(sum_pmf[win_region_right])
|
774 |
+
|
775 |
+
# Add vertical lines for critical values
|
776 |
+
ax.axvline(x=25.5, color='red', linestyle='--', linewidth=1.5, label='Critical Points')
|
777 |
+
ax.axvline(x=44.5, color='red', linestyle='--', linewidth=1.5)
|
778 |
+
|
779 |
+
# Add mean line
|
780 |
+
ax.axvline(x=sum_mean, color='green', linestyle='--', linewidth=1.5,
|
781 |
+
label=f'Mean = {sum_mean}')
|
782 |
+
|
783 |
+
# Text box with relevant information
|
784 |
+
textstr = '\n'.join((
|
785 |
+
f'Number of dice: {n_dice}',
|
786 |
+
f'Sum Mean: {sum_mean}',
|
787 |
+
f'Sum Std Dev: {sum_std:.2f}',
|
788 |
+
f'Win Probability: {win_prob:.4f}',
|
789 |
+
f'CLT Approximation: {0.078:.4f}'
|
790 |
+
))
|
791 |
+
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
|
792 |
+
ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
|
793 |
+
verticalalignment='top', bbox=props)
|
794 |
+
|
795 |
+
# Formatting
|
796 |
+
ax.set_xlabel('Sum of 10 Dice')
|
797 |
+
ax.set_ylabel('Probability')
|
798 |
+
ax.set_title('Central Limit Theorem: Dice Game Example')
|
799 |
+
ax.legend()
|
800 |
+
ax.grid(alpha=0.3)
|
801 |
+
|
802 |
+
plt.tight_layout()
|
803 |
+
plt.gca()
|
804 |
+
return fig
|
805 |
+
return (create_dice_game_visualization,)
|
806 |
+
|
807 |
+
|
808 |
+
@app.cell(hide_code=True)
|
809 |
+
def _(np, plt):
|
810 |
+
def create_algorithm_runtime_visualization():
|
811 |
+
"""Create a visualization for the algorithm runtime example."""
|
812 |
+
# Parameters
|
813 |
+
variance = 4 # σ² = 4 sec²
|
814 |
+
std_dev = np.sqrt(variance) # σ = 2 sec
|
815 |
+
confidence_level = 0.95
|
816 |
+
z_score = 1.96 # for 95% confidence
|
817 |
+
target_error = 0.5 # ±0.5 seconds
|
818 |
+
|
819 |
+
# Calculate n needed for desired precision
|
820 |
+
n_required = int(np.ceil((z_score * std_dev / target_error) ** 2)) # ≈ 62
|
821 |
+
|
822 |
+
n_values = np.arange(1, 100)
|
823 |
+
|
824 |
+
# standard error
|
825 |
+
standard_errors = std_dev / np.sqrt(n_values)
|
826 |
+
|
827 |
+
# margin of error
|
828 |
+
margins_of_error = z_score * standard_errors
|
829 |
+
|
830 |
+
# Create figure
|
831 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
832 |
+
|
833 |
+
# standard error vs sample size plot
|
834 |
+
ax.plot(n_values, standard_errors, 'b-', linewidth=2, label='Standard Error of Mean')
|
835 |
+
|
836 |
+
# Plot margin of error vs sample size
|
837 |
+
ax.plot(n_values, margins_of_error, 'r--', linewidth=2,
|
838 |
+
label=f'{confidence_level*100}% Margin of Error')
|
839 |
+
|
840 |
+
ax.axvline(x=n_required, color='green', linestyle='-', linewidth=1.5,
|
841 |
+
label=f'Required n = {n_required}')
|
842 |
+
|
843 |
+
ax.axhline(y=target_error, color='purple', linestyle='--', linewidth=1.5,
|
844 |
+
label=f'Target Error = ±{target_error} sec')
|
845 |
+
|
846 |
+
# Shade the region below target error
|
847 |
+
ax.fill_between(n_values, 0, target_error, alpha=0.2, color='green')
|
848 |
+
|
849 |
+
# intersection point
|
850 |
+
ax.plot(n_required, target_error, 'ro', markersize=8)
|
851 |
+
ax.annotate(f'({n_required}, {target_error} sec)',
|
852 |
+
xy=(n_required, target_error),
|
853 |
+
xytext=(n_required + 5, target_error + 0.1),
|
854 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
855 |
+
|
856 |
+
# Text box with appropriate information
|
857 |
+
textstr = '\n'.join((
|
858 |
+
f'Algorithm Variance: {variance} sec²',
|
859 |
+
f'Standard Deviation: {std_dev} sec',
|
860 |
+
f'Confidence Level: {confidence_level*100}%',
|
861 |
+
f'Z-score: {z_score}',
|
862 |
+
f'Target Error: ±{target_error} sec',
|
863 |
+
f'Required Sample Size: {n_required}'
|
864 |
+
))
|
865 |
+
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
|
866 |
+
ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
|
867 |
+
verticalalignment='top', bbox=props)
|
868 |
+
|
869 |
+
# Formatting
|
870 |
+
ax.set_xlabel('Sample Size (n)')
|
871 |
+
ax.set_ylabel('Error (seconds)')
|
872 |
+
ax.set_title('Sample Size Determination for Algorithm Runtime Estimation')
|
873 |
+
ax.set_xlim(0, 100)
|
874 |
+
ax.set_ylim(0, 2)
|
875 |
+
ax.legend()
|
876 |
+
ax.grid(alpha=0.3)
|
877 |
+
|
878 |
+
plt.tight_layout()
|
879 |
+
return fig
|
880 |
+
return (create_algorithm_runtime_visualization,)
|
881 |
+
|
882 |
+
|
883 |
+
@app.cell(hide_code=True)
|
884 |
+
def _(mo):
|
885 |
+
mo.md(
|
886 |
+
r"""
|
887 |
+
## Summary
|
888 |
+
|
889 |
+
The Central Limit Theorem is truly one of the most remarkable ideas in all of statistics. It tells us that when we add up many independent random variables, their sum will follow a normal distribution, regardless of what the original distributions looked like. This is why we see normal distributions so often in real life – many natural phenomena are the result of numerous small, independent factors adding up.
|
890 |
+
|
891 |
+
What makes the CLT so powerful is its universality. Whether we're working with dice rolls, measurement errors, or stock market returns, as long as we have enough independent samples, their average or sum will be approximately normal. For sums, the distribution will be $\mathcal{N}(n\mu, n\sigma^2)$, and for averages, it's $\mathcal{N}(\mu, \frac{\sigma^2}{n})$.
|
892 |
+
|
893 |
+
The CLT gives us the foundation for confidence intervals, hypothesis testing, and many other statistical tools. Without it, we'd have a much harder time making sense of data when we don't know the underlying population distribution. Just remember that if you're working with discrete distributions, you'll need to apply a continuity correction to get more accurate results.
|
894 |
+
|
895 |
+
Next time you see a normal distribution in data, think about the Central Limit Theorem – it might be the reason behind that familiar bell curve!
|
896 |
+
"""
|
897 |
+
)
|
898 |
+
return
|
899 |
+
|
900 |
+
|
901 |
+
@app.cell(hide_code=True)
|
902 |
+
def _(mo):
|
903 |
+
# controls for the interactive explorer
|
904 |
+
distribution_type = mo.ui.dropdown(
|
905 |
+
options=["uniform", "exponential", "binomial", "poisson"],
|
906 |
+
value="uniform",
|
907 |
+
label="Distribution Type"
|
908 |
+
)
|
909 |
+
|
910 |
+
sample_size = mo.ui.slider(
|
911 |
+
start =1,
|
912 |
+
stop =100,
|
913 |
+
step=1,
|
914 |
+
value=30,
|
915 |
+
label="Sample Size (n)"
|
916 |
+
)
|
917 |
+
|
918 |
+
sim_count_slider = mo.ui.slider(
|
919 |
+
start =100,
|
920 |
+
stop =10000,
|
921 |
+
step=100,
|
922 |
+
value=1000,
|
923 |
+
label="Number of Simulations"
|
924 |
+
)
|
925 |
+
|
926 |
+
run_explorer_button = mo.ui.run_button(label="Run Simulation", kind="warn")
|
927 |
+
|
928 |
+
controls = mo.hstack([
|
929 |
+
mo.vstack([distribution_type, sample_size, sim_count_slider]),
|
930 |
+
run_explorer_button
|
931 |
+
], justify='space-around')
|
932 |
+
|
933 |
+
return (
|
934 |
+
controls,
|
935 |
+
distribution_type,
|
936 |
+
run_explorer_button,
|
937 |
+
sample_size,
|
938 |
+
sim_count_slider,
|
939 |
+
)
|
940 |
+
|
941 |
+
|
942 |
+
if __name__ == "__main__":
|
943 |
+
app.run()
|
probability/19_maximum_likelihood_estimation.py
ADDED
@@ -0,0 +1,1231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.1",
|
6 |
+
# "scipy==1.15.2",
|
7 |
+
# "numpy==2.2.4",
|
8 |
+
# "polars==0.20.2",
|
9 |
+
# "plotly==5.18.0",
|
10 |
+
# ]
|
11 |
+
# ///
|
12 |
+
|
13 |
+
import marimo
|
14 |
+
|
15 |
+
__generated_with = "0.12.0"
|
16 |
+
app = marimo.App(width="medium", app_title="Maximum Likelihood Estimation")
|
17 |
+
|
18 |
+
|
19 |
+
@app.cell(hide_code=True)
|
20 |
+
def _(mo):
|
21 |
+
mo.md(
|
22 |
+
r"""
|
23 |
+
# Maximum Likelihood Estimation
|
24 |
+
|
25 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part5/mle/), by Stanford professor Chris Piech._
|
26 |
+
|
27 |
+
Maximum Likelihood Estimation (MLE) is a fundamental method in statistics for estimating parameters of a probability distribution. The central idea is elegantly simple: **choose the parameters that make the observed data most likely**.
|
28 |
+
|
29 |
+
In this notebook, we'll try to understand MLE, starting with the core concept of likelihood and how it differs from probability. We'll explore how to formulate MLE problems mathematically and then solve them for various common distributions. Along the way, I've included some interactive visualizations to help build your intuition about these concepts. You'll see how MLE applies to real-world scenarios like linear regression, and hopefully gain a deeper appreciation for why this technique is so widely used in statistics and machine learning. Think of MLE as detective work - we have some evidence (our data) and we're trying to figure out the most plausible explanation (our parameters) for what we've observed.
|
30 |
+
"""
|
31 |
+
)
|
32 |
+
return
|
33 |
+
|
34 |
+
|
35 |
+
@app.cell(hide_code=True)
|
36 |
+
def _(mo):
|
37 |
+
mo.md(
|
38 |
+
r"""
|
39 |
+
## Likelihood: The Core Concept
|
40 |
+
|
41 |
+
Before diving into MLE, we need to understand what "likelihood" means in a statistical context.
|
42 |
+
|
43 |
+
### Data and Parameters
|
44 |
+
|
45 |
+
Suppose we have collected some data $X_1, X_2, \ldots, X_n$ that are independent and identically distributed (IID). We assume these data points come from a specific type of distribution (like Normal, Bernoulli, etc.) with unknown parameters $\theta$.
|
46 |
+
|
47 |
+
### What is Likelihood?
|
48 |
+
|
49 |
+
Likelihood measures how probable our observed data is, given specific values of the parameters $\theta$.
|
50 |
+
|
51 |
+
/// note
|
52 |
+
**Probability vs. Likelihood**
|
53 |
+
|
54 |
+
- **Probability**: Given parameters $\theta$, what's the chance of observing data $X$?
|
55 |
+
- **Likelihood**: Given observed data $X$, how likely are different parameter values $\theta$?
|
56 |
+
///
|
57 |
+
|
58 |
+
To simplify notation, we'll use $f(X=x|\Theta=\theta)$ to represent either the PMF or PDF of our data, conditioned on the parameters.
|
59 |
+
"""
|
60 |
+
)
|
61 |
+
return
|
62 |
+
|
63 |
+
|
64 |
+
@app.cell(hide_code=True)
|
65 |
+
def _(mo):
|
66 |
+
mo.md(
|
67 |
+
r"""
|
68 |
+
### The Likelihood Function
|
69 |
+
|
70 |
+
Since we assume our data points are independent, the likelihood of all our data is the product of the likelihoods of each individual data point:
|
71 |
+
|
72 |
+
$$L(\theta) = \prod_{i=1}^n f(X_i = x_i|\Theta = \theta)$$
|
73 |
+
|
74 |
+
This function $L(\theta)$ gives us the likelihood of observing our entire dataset for different parameter values $\theta$.
|
75 |
+
|
76 |
+
/// tip
|
77 |
+
**Key Insight**: Different parameter values produce different likelihoods for the same data. Better parameter values will make the observed data more likely.
|
78 |
+
///
|
79 |
+
"""
|
80 |
+
)
|
81 |
+
return
|
82 |
+
|
83 |
+
|
84 |
+
@app.cell(hide_code=True)
|
85 |
+
def _(mo):
|
86 |
+
mo.md(
|
87 |
+
r"""
|
88 |
+
## Maximum Likelihood Estimation
|
89 |
+
|
90 |
+
The core idea of MLE is to find the parameter values $\hat{\theta}$ that maximize the likelihood function:
|
91 |
+
|
92 |
+
$$\hat{\theta} = \underset{\theta}{\operatorname{argmax}} \, L(\theta)$$
|
93 |
+
|
94 |
+
The notation $\hat{\theta}$ represents our best estimate of the true parameters based on the observed data.
|
95 |
+
|
96 |
+
### Working with Log-Likelihood
|
97 |
+
|
98 |
+
In practice, we usually work with the **log-likelihood** instead of the likelihood directly. Since logarithm is a monotonically increasing function, the maximum of $L(\theta)$ occurs at the same value of $\theta$ as the maximum of $\log L(\theta)$.
|
99 |
+
|
100 |
+
Taking the logarithm transforms our product into a sum, which is much easier to work with:
|
101 |
+
|
102 |
+
$$LL(\theta) = \log L(\theta) = \log \prod_{i=1}^n f(X_i=x_i|\Theta = \theta) = \sum_{i=1}^n \log f(X_i = x_i|\Theta = \theta)$$
|
103 |
+
|
104 |
+
/// warning
|
105 |
+
Working with products of many small probabilities can lead to numerical underflow. Taking the logarithm converts these products to sums, which is numerically more stable.
|
106 |
+
///
|
107 |
+
"""
|
108 |
+
)
|
109 |
+
return
|
110 |
+
|
111 |
+
|
112 |
+
@app.cell(hide_code=True)
|
113 |
+
def _(mo):
|
114 |
+
mo.md(
|
115 |
+
r"""
|
116 |
+
### Finding the Maximum
|
117 |
+
|
118 |
+
To find the values of $\theta$ that maximize the log-likelihood, we typically:
|
119 |
+
|
120 |
+
1. Take the derivative of $LL(\theta)$ with respect to each parameter
|
121 |
+
2. Set each derivative equal to zero
|
122 |
+
3. Solve for the parameters
|
123 |
+
|
124 |
+
Let's see this approach in action with some common distributions.
|
125 |
+
"""
|
126 |
+
)
|
127 |
+
return
|
128 |
+
|
129 |
+
|
130 |
+
@app.cell(hide_code=True)
|
131 |
+
def _(mo):
|
132 |
+
mo.md(
|
133 |
+
r"""
|
134 |
+
## MLE for Bernoulli Distribution
|
135 |
+
|
136 |
+
Let's start with a simple example: estimating the parameter $p$ of a Bernoulli distribution.
|
137 |
+
|
138 |
+
### The Model
|
139 |
+
|
140 |
+
A Bernoulli distribution has a single parameter $p$ which represents the probability of success (getting a value of 1). Its probability mass function (PMF) can be written as:
|
141 |
+
|
142 |
+
$$f(x|p) = p^x(1-p)^{1-x}, \quad x \in \{0, 1\}$$
|
143 |
+
|
144 |
+
This elegant formula works because:
|
145 |
+
|
146 |
+
- When $x = 1$: $f(1|p) = p^1(1-p)^0 = p$
|
147 |
+
- When $x = 0$: $f(0|p) = p^0(1-p)^1 = 1-p$
|
148 |
+
|
149 |
+
### Deriving the MLE
|
150 |
+
|
151 |
+
Given $n$ independent Bernoulli trials $X_1, X_2, \ldots, X_n$, we want to find the value of $p$ that maximizes the likelihood of our observed data.
|
152 |
+
|
153 |
+
Step 1: Write the likelihood function
|
154 |
+
$$L(p) = \prod_{i=1}^n p^{x_i}(1-p)^{1-x_i}$$
|
155 |
+
|
156 |
+
Step 2: Take the logarithm to get the log-likelihood
|
157 |
+
$$\begin{align*}
|
158 |
+
LL(p) &= \sum_{i=1}^n \log(p^{x_i}(1-p)^{1-x_i}) \\
|
159 |
+
&= \sum_{i=1}^n \left[x_i \log(p) + (1-x_i)\log(1-p)\right] \\
|
160 |
+
&= \left(\sum_{i=1}^n x_i\right) \log(p) + \left(n - \sum_{i=1}^n x_i\right) \log(1-p) \\
|
161 |
+
&= Y\log(p) + (n-Y)\log(1-p)
|
162 |
+
\end{align*}$$
|
163 |
+
|
164 |
+
where $Y = \sum_{i=1}^n x_i$ is the total number of successes.
|
165 |
+
|
166 |
+
Step 3: Find the value of $p$ that maximizes $LL(p)$ by setting the derivative to zero
|
167 |
+
$$\begin{align*}
|
168 |
+
\frac{d\,LL(p)}{dp} &= \frac{Y}{p} - \frac{n-Y}{1-p} = 0 \\
|
169 |
+
\frac{Y}{p} &= \frac{n-Y}{1-p} \\
|
170 |
+
Y(1-p) &= p(n-Y) \\
|
171 |
+
Y - Yp &= pn - pY \\
|
172 |
+
Y &= pn \\
|
173 |
+
\hat{p} &= \frac{Y}{n} = \frac{\sum_{i=1}^n x_i}{n}
|
174 |
+
\end{align*}$$
|
175 |
+
|
176 |
+
/// tip
|
177 |
+
The MLE for the parameter $p$ in a Bernoulli distribution is simply the **sample mean** - the proportion of successes in our data!
|
178 |
+
///
|
179 |
+
"""
|
180 |
+
)
|
181 |
+
return
|
182 |
+
|
183 |
+
|
184 |
+
@app.cell(hide_code=True)
|
185 |
+
def _(controls):
|
186 |
+
controls.center()
|
187 |
+
return
|
188 |
+
|
189 |
+
|
190 |
+
@app.cell(hide_code=True)
|
191 |
+
def _(generate_button, mo, np, plt, sample_size_slider, true_p_slider):
|
192 |
+
# generate bernoulli samples when button is clicked
|
193 |
+
bernoulli_button_value = generate_button.value
|
194 |
+
|
195 |
+
# get parameter values
|
196 |
+
bernoulli_true_p = true_p_slider.value
|
197 |
+
bernoulli_n = sample_size_slider.value
|
198 |
+
|
199 |
+
# generate data
|
200 |
+
bernoulli_data = np.random.binomial(1, bernoulli_true_p, size=bernoulli_n)
|
201 |
+
bernoulli_Y = np.sum(bernoulli_data)
|
202 |
+
bernoulli_p_hat = bernoulli_Y / bernoulli_n
|
203 |
+
|
204 |
+
# create visualization
|
205 |
+
bernoulli_fig, (bernoulli_ax1, bernoulli_ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
206 |
+
|
207 |
+
# plot data histogram
|
208 |
+
bernoulli_ax1.hist(bernoulli_data, bins=[-0.5, 0.5, 1.5], rwidth=0.8, color='lightblue')
|
209 |
+
bernoulli_ax1.set_xticks([0, 1])
|
210 |
+
bernoulli_ax1.set_xticklabels(['Failure (0)', 'Success (1)'])
|
211 |
+
bernoulli_ax1.set_title(f'Bernoulli Data: {bernoulli_n} samples')
|
212 |
+
bernoulli_ax1.set_ylabel('Count')
|
213 |
+
bernoulli_y_counts = [bernoulli_n - bernoulli_Y, bernoulli_Y]
|
214 |
+
for bernoulli_idx, bernoulli_count in enumerate(bernoulli_y_counts):
|
215 |
+
bernoulli_ax1.text(bernoulli_idx, bernoulli_count/2, f"{bernoulli_count}",
|
216 |
+
ha='center', va='center',
|
217 |
+
color='white' if bernoulli_idx == 0 else 'black',
|
218 |
+
fontweight='bold')
|
219 |
+
|
220 |
+
# calculate log-likelihood function
|
221 |
+
bernoulli_p_values = np.linspace(0.01, 0.99, 100)
|
222 |
+
bernoulli_ll_values = np.zeros_like(bernoulli_p_values)
|
223 |
+
|
224 |
+
for bernoulli_i, bernoulli_p in enumerate(bernoulli_p_values):
|
225 |
+
bernoulli_ll_values[bernoulli_i] = bernoulli_Y * np.log(bernoulli_p) + (bernoulli_n - bernoulli_Y) * np.log(1 - bernoulli_p)
|
226 |
+
|
227 |
+
# plot log-likelihood
|
228 |
+
bernoulli_ax2.plot(bernoulli_p_values, bernoulli_ll_values, 'b-', linewidth=2)
|
229 |
+
bernoulli_ax2.axvline(x=bernoulli_p_hat, color='r', linestyle='--', label=f'MLE: $\\hat{{p}} = {bernoulli_p_hat:.3f}$')
|
230 |
+
bernoulli_ax2.axvline(x=bernoulli_true_p, color='g', linestyle='--', label=f'True: $p = {bernoulli_true_p:.3f}$')
|
231 |
+
bernoulli_ax2.set_xlabel('$p$ (probability of success)')
|
232 |
+
bernoulli_ax2.set_ylabel('Log-Likelihood')
|
233 |
+
bernoulli_ax2.set_title('Log-Likelihood Function')
|
234 |
+
bernoulli_ax2.legend()
|
235 |
+
|
236 |
+
plt.tight_layout()
|
237 |
+
plt.gca()
|
238 |
+
|
239 |
+
# Create markdown to explain the results
|
240 |
+
bernoulli_explanation = mo.md(
|
241 |
+
f"""
|
242 |
+
### Bernoulli MLE Results
|
243 |
+
|
244 |
+
**True parameter**: $p = {bernoulli_true_p:.3f}$
|
245 |
+
**Sample statistics**: {bernoulli_Y} successes out of {bernoulli_n} trials
|
246 |
+
**MLE estimate**: $\\hat{{p}} = \\frac{{{bernoulli_Y}}}{{{bernoulli_n}}} = {bernoulli_p_hat:.3f}$
|
247 |
+
|
248 |
+
The plot on the right shows the log-likelihood function $LL(p) = Y\\log(p) + (n-Y)\\log(1-p)$.
|
249 |
+
The red dashed line marks the maximum likelihood estimate $\\hat{{p}}$, and the green dashed line
|
250 |
+
shows the true parameter value.
|
251 |
+
|
252 |
+
/// note
|
253 |
+
Try increasing the sample size to see how the MLE estimate gets closer to the true parameter value!
|
254 |
+
///
|
255 |
+
"""
|
256 |
+
)
|
257 |
+
|
258 |
+
# Display plot and explanation together
|
259 |
+
mo.vstack([
|
260 |
+
bernoulli_fig,
|
261 |
+
bernoulli_explanation
|
262 |
+
])
|
263 |
+
return (
|
264 |
+
bernoulli_Y,
|
265 |
+
bernoulli_ax1,
|
266 |
+
bernoulli_ax2,
|
267 |
+
bernoulli_button_value,
|
268 |
+
bernoulli_count,
|
269 |
+
bernoulli_data,
|
270 |
+
bernoulli_explanation,
|
271 |
+
bernoulli_fig,
|
272 |
+
bernoulli_i,
|
273 |
+
bernoulli_idx,
|
274 |
+
bernoulli_ll_values,
|
275 |
+
bernoulli_n,
|
276 |
+
bernoulli_p,
|
277 |
+
bernoulli_p_hat,
|
278 |
+
bernoulli_p_values,
|
279 |
+
bernoulli_true_p,
|
280 |
+
bernoulli_y_counts,
|
281 |
+
)
|
282 |
+
|
283 |
+
|
284 |
+
@app.cell(hide_code=True)
|
285 |
+
def _(mo):
|
286 |
+
mo.md(
|
287 |
+
r"""
|
288 |
+
## MLE for Normal Distribution
|
289 |
+
|
290 |
+
Next, let's look at a more complex example: estimating the parameters $\mu$ and $\sigma^2$ of a Normal distribution.
|
291 |
+
|
292 |
+
### The Model
|
293 |
+
|
294 |
+
A Normal (Gaussian) distribution has two parameters:
|
295 |
+
- $\mu$: the mean
|
296 |
+
- $\sigma^2$: the variance
|
297 |
+
|
298 |
+
Its probability density function (PDF) is:
|
299 |
+
|
300 |
+
$$f(x|\mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right)$$
|
301 |
+
|
302 |
+
### Deriving the MLE
|
303 |
+
|
304 |
+
Given $n$ independent samples $X_1, X_2, \ldots, X_n$ from a Normal distribution, we want to find the values of $\mu$ and $\sigma^2$ that maximize the likelihood of our observed data.
|
305 |
+
|
306 |
+
Step 1: Write the likelihood function
|
307 |
+
$$L(\mu, \sigma^2) = \prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right)$$
|
308 |
+
|
309 |
+
Step 2: Take the logarithm to get the log-likelihood
|
310 |
+
$$\begin{align*}
|
311 |
+
LL(\mu, \sigma^2) &= \log\prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right) \\
|
312 |
+
&= \sum_{i=1}^n \log\left[\frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right)\right] \\
|
313 |
+
&= \sum_{i=1}^n \left[-\frac{1}{2}\log(2\pi\sigma^2) - \frac{(x_i - \mu)^2}{2\sigma^2}\right] \\
|
314 |
+
&= -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\sum_{i=1}^n (x_i - \mu)^2
|
315 |
+
\end{align*}$$
|
316 |
+
|
317 |
+
Step 3: Find the values of $\mu$ and $\sigma^2$ that maximize $LL(\mu, \sigma^2)$ by setting the partial derivatives to zero.
|
318 |
+
|
319 |
+
For $\mu$:
|
320 |
+
$$\begin{align*}
|
321 |
+
\frac{\partial LL(\mu, \sigma^2)}{\partial \mu} &= \frac{1}{\sigma^2}\sum_{i=1}^n (x_i - \mu) = 0 \\
|
322 |
+
\sum_{i=1}^n (x_i - \mu) &= 0 \\
|
323 |
+
\sum_{i=1}^n x_i &= n\mu \\
|
324 |
+
\hat{\mu} &= \frac{1}{n}\sum_{i=1}^n x_i
|
325 |
+
\end{align*}$$
|
326 |
+
|
327 |
+
For $\sigma^2$:
|
328 |
+
$$\begin{align*}
|
329 |
+
\frac{\partial LL(\mu, \sigma^2)}{\partial \sigma^2} &= -\frac{n}{2\sigma^2} + \frac{1}{2(\sigma^2)^2}\sum_{i=1}^n (x_i - \mu)^2 = 0 \\
|
330 |
+
\frac{n}{2\sigma^2} &= \frac{1}{2(\sigma^2)^2}\sum_{i=1}^n (x_i - \mu)^2 \\
|
331 |
+
n\sigma^2 &= \sum_{i=1}^n (x_i - \mu)^2 \\
|
332 |
+
\hat{\sigma}^2 &= \frac{1}{n}\sum_{i=1}^n (x_i - \hat{\mu})^2
|
333 |
+
\end{align*}$$
|
334 |
+
|
335 |
+
/// tip
|
336 |
+
The MLE for a Normal distribution gives us:
|
337 |
+
|
338 |
+
- $\hat{\mu}$ = sample mean
|
339 |
+
- $\hat{\sigma}^2$ = sample variance (using $n$ in the denominator, not $n-1$)
|
340 |
+
///
|
341 |
+
"""
|
342 |
+
)
|
343 |
+
return
|
344 |
+
|
345 |
+
|
346 |
+
@app.cell(hide_code=True)
|
347 |
+
def _(normal_controls):
|
348 |
+
normal_controls.center()
|
349 |
+
return
|
350 |
+
|
351 |
+
|
352 |
+
@app.cell(hide_code=True)
|
353 |
+
def _(
|
354 |
+
mo,
|
355 |
+
normal_generate_button,
|
356 |
+
normal_sample_size_slider,
|
357 |
+
np,
|
358 |
+
plt,
|
359 |
+
true_mu_slider,
|
360 |
+
true_sigma_slider,
|
361 |
+
):
|
362 |
+
# generate normal samples when button is clicked
|
363 |
+
normal_button_value = normal_generate_button.value
|
364 |
+
|
365 |
+
# get parameter values
|
366 |
+
normal_true_mu = true_mu_slider.value
|
367 |
+
normal_true_sigma = true_sigma_slider.value
|
368 |
+
normal_true_var = normal_true_sigma**2
|
369 |
+
normal_n = normal_sample_size_slider.value
|
370 |
+
|
371 |
+
# generate random data
|
372 |
+
normal_data = np.random.normal(normal_true_mu, normal_true_sigma, size=normal_n)
|
373 |
+
|
374 |
+
# calculate mle estimates
|
375 |
+
normal_mu_hat = np.mean(normal_data)
|
376 |
+
normal_sigma2_hat = np.mean((normal_data - normal_mu_hat)**2) # mle variance using n
|
377 |
+
normal_sigma_hat = np.sqrt(normal_sigma2_hat)
|
378 |
+
|
379 |
+
# create visualization
|
380 |
+
normal_fig, (normal_ax1, normal_ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
381 |
+
|
382 |
+
# plot histogram and density curves
|
383 |
+
normal_bins = np.linspace(min(normal_data) - 1, max(normal_data) + 1, 30)
|
384 |
+
normal_ax1.hist(normal_data, bins=normal_bins, density=True, alpha=0.6, color='lightblue', label='Data Histogram')
|
385 |
+
|
386 |
+
# plot range for density curves
|
387 |
+
normal_x = np.linspace(min(normal_data) - 2*normal_true_sigma, max(normal_data) + 2*normal_true_sigma, 1000)
|
388 |
+
|
389 |
+
# plot true and mle densities
|
390 |
+
normal_true_pdf = (1/(normal_true_sigma * np.sqrt(2*np.pi))) * np.exp(-0.5 * ((normal_x - normal_true_mu)/normal_true_sigma)**2)
|
391 |
+
normal_ax1.plot(normal_x, normal_true_pdf, 'g-', linewidth=2, label=f'True: N({normal_true_mu:.2f}, {normal_true_var:.2f})')
|
392 |
+
|
393 |
+
normal_mle_pdf = (1/(normal_sigma_hat * np.sqrt(2*np.pi))) * np.exp(-0.5 * ((normal_x - normal_mu_hat)/normal_sigma_hat)**2)
|
394 |
+
normal_ax1.plot(normal_x, normal_mle_pdf, 'r--', linewidth=2, label=f'MLE: N({normal_mu_hat:.2f}, {normal_sigma2_hat:.2f})')
|
395 |
+
|
396 |
+
normal_ax1.set_xlabel('x')
|
397 |
+
normal_ax1.set_ylabel('Density')
|
398 |
+
normal_ax1.set_title(f'Normal Distribution: {normal_n} samples')
|
399 |
+
normal_ax1.legend()
|
400 |
+
|
401 |
+
# create contour plot of log-likelihood
|
402 |
+
normal_mu_range = np.linspace(normal_mu_hat - 2, normal_mu_hat + 2, 100)
|
403 |
+
normal_sigma_range = np.linspace(max(0.1, normal_sigma_hat - 1), normal_sigma_hat + 1, 100)
|
404 |
+
|
405 |
+
normal_mu_grid, normal_sigma_grid = np.meshgrid(normal_mu_range, normal_sigma_range)
|
406 |
+
normal_ll_grid = np.zeros_like(normal_mu_grid)
|
407 |
+
|
408 |
+
# calculate log-likelihood for each grid point
|
409 |
+
for normal_i in range(normal_mu_grid.shape[0]):
|
410 |
+
for normal_j in range(normal_mu_grid.shape[1]):
|
411 |
+
normal_mu = normal_mu_grid[normal_i, normal_j]
|
412 |
+
normal_sigma = normal_sigma_grid[normal_i, normal_j]
|
413 |
+
normal_ll = -normal_n/2 * np.log(2*np.pi*normal_sigma**2) - np.sum((normal_data - normal_mu)**2)/(2*normal_sigma**2)
|
414 |
+
normal_ll_grid[normal_i, normal_j] = normal_ll
|
415 |
+
|
416 |
+
# plot log-likelihood contour
|
417 |
+
normal_contour = normal_ax2.contourf(normal_mu_grid, normal_sigma_grid, normal_ll_grid, levels=50, cmap='viridis')
|
418 |
+
normal_ax2.set_xlabel('μ (mean)')
|
419 |
+
normal_ax2.set_ylabel('σ (standard deviation)')
|
420 |
+
normal_ax2.set_title('Log-Likelihood Contour')
|
421 |
+
|
422 |
+
# mark mle and true params
|
423 |
+
normal_ax2.plot(normal_mu_hat, normal_sigma_hat, 'rx', markersize=10, label='MLE Estimate')
|
424 |
+
normal_ax2.plot(normal_true_mu, normal_true_sigma, 'g*', markersize=10, label='True Parameters')
|
425 |
+
normal_ax2.legend()
|
426 |
+
|
427 |
+
plt.colorbar(normal_contour, ax=normal_ax2, label='Log-Likelihood')
|
428 |
+
plt.tight_layout()
|
429 |
+
plt.gca()
|
430 |
+
|
431 |
+
# relevant markdown for the results
|
432 |
+
normal_explanation = mo.md(
|
433 |
+
f"""
|
434 |
+
### Normal MLE Results
|
435 |
+
|
436 |
+
**True parameters**: $\mu = {normal_true_mu:.3f}$, $\sigma^2 = {normal_true_var:.3f}$
|
437 |
+
**MLE estimates**: $\hat{{\mu}} = {normal_mu_hat:.3f}$, $\hat{{\sigma}}^2 = {normal_sigma2_hat:.3f}$
|
438 |
+
|
439 |
+
The left plot shows the data histogram with the true Normal distribution (green) and the MLE-estimated distribution (red dashed).
|
440 |
+
|
441 |
+
The right plot shows the log-likelihood function as a contour map in the $(\mu, \sigma)$ parameter space. The maximum likelihood estimates are marked with a red X, while the true parameters are marked with a green star.
|
442 |
+
|
443 |
+
/// note
|
444 |
+
Notice how the log-likelihood contour is more stretched along the σ axis than the μ axis. This indicates that we typically estimate the mean with greater precision than the standard deviation.
|
445 |
+
///
|
446 |
+
|
447 |
+
/// tip
|
448 |
+
Increase the sample size to see how the MLE estimates converge to the true parameter values!
|
449 |
+
///
|
450 |
+
"""
|
451 |
+
)
|
452 |
+
|
453 |
+
# plot and explanation together
|
454 |
+
mo.vstack([
|
455 |
+
normal_fig,
|
456 |
+
normal_explanation
|
457 |
+
])
|
458 |
+
return (
|
459 |
+
normal_ax1,
|
460 |
+
normal_ax2,
|
461 |
+
normal_bins,
|
462 |
+
normal_button_value,
|
463 |
+
normal_contour,
|
464 |
+
normal_data,
|
465 |
+
normal_explanation,
|
466 |
+
normal_fig,
|
467 |
+
normal_i,
|
468 |
+
normal_j,
|
469 |
+
normal_ll,
|
470 |
+
normal_ll_grid,
|
471 |
+
normal_mle_pdf,
|
472 |
+
normal_mu,
|
473 |
+
normal_mu_grid,
|
474 |
+
normal_mu_hat,
|
475 |
+
normal_mu_range,
|
476 |
+
normal_n,
|
477 |
+
normal_sigma,
|
478 |
+
normal_sigma2_hat,
|
479 |
+
normal_sigma_grid,
|
480 |
+
normal_sigma_hat,
|
481 |
+
normal_sigma_range,
|
482 |
+
normal_true_mu,
|
483 |
+
normal_true_pdf,
|
484 |
+
normal_true_sigma,
|
485 |
+
normal_true_var,
|
486 |
+
normal_x,
|
487 |
+
)
|
488 |
+
|
489 |
+
|
490 |
+
@app.cell(hide_code=True)
|
491 |
+
def _(mo):
|
492 |
+
mo.md(
|
493 |
+
r"""
|
494 |
+
## MLE for Linear Regression
|
495 |
+
|
496 |
+
Now let's look at a more practical example: using MLE to derive linear regression.
|
497 |
+
|
498 |
+
### The Model
|
499 |
+
|
500 |
+
Consider a model where:
|
501 |
+
- We have pairs of observations $(X_1, Y_1), (X_2, Y_2), \ldots, (X_n, Y_n)$
|
502 |
+
- The relationship between $X$ and $Y$ follows: $Y = \theta X + Z$
|
503 |
+
- $Z \sim N(0, \sigma^2)$ is random noise
|
504 |
+
- Our goal is to estimate the parameter $\theta$
|
505 |
+
|
506 |
+
This means that for a given $X_i$, the conditional distribution of $Y_i$ is:
|
507 |
+
|
508 |
+
$$Y_i | X_i \sim N(\theta X_i, \sigma^2)$$
|
509 |
+
|
510 |
+
### Deriving the MLE
|
511 |
+
|
512 |
+
Step 1: Write the likelihood function for each data point $(X_i, Y_i)$
|
513 |
+
$$f(Y_i | X_i, \theta) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(Y_i - \theta X_i)^2}{2\sigma^2}\right)$$
|
514 |
+
|
515 |
+
Step 2: Write the likelihood for all data
|
516 |
+
$$\begin{align*}
|
517 |
+
L(\theta) &= \prod_{i=1}^n f(Y_i, X_i | \theta) \\
|
518 |
+
&= \prod_{i=1}^n f(Y_i | X_i, \theta) \cdot f(X_i)
|
519 |
+
\end{align*}$$
|
520 |
+
|
521 |
+
Since $f(X_i)$ doesn't depend on $\theta$, we can simplify:
|
522 |
+
$$L(\theta) = \prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(Y_i - \theta X_i)^2}{2\sigma^2}\right) \cdot f(X_i)$$
|
523 |
+
|
524 |
+
Step 3: Take the logarithm to get the log-likelihood
|
525 |
+
$$\begin{align*}
|
526 |
+
LL(\theta) &= \log \prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(Y_i - \theta X_i)^2}{2\sigma^2}\right) \cdot f(X_i) \\
|
527 |
+
&= \sum_{i=1}^n \log\left[\frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(Y_i - \theta X_i)^2}{2\sigma^2}\right)\right] + \sum_{i=1}^n \log f(X_i) \\
|
528 |
+
&= -\frac{n}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^n (Y_i - \theta X_i)^2 + \sum_{i=1}^n \log f(X_i)
|
529 |
+
\end{align*}$$
|
530 |
+
|
531 |
+
Step 4: Since we only care about maximizing with respect to $\theta$, we can drop terms that don't contain $\theta$:
|
532 |
+
$$\hat{\theta} = \underset{\theta}{\operatorname{argmax}} \left[ -\frac{1}{2\sigma^2} \sum_{i=1}^n (Y_i - \theta X_i)^2 \right]$$
|
533 |
+
|
534 |
+
This is equivalent to:
|
535 |
+
$$\hat{\theta} = \underset{\theta}{\operatorname{argmin}} \sum_{i=1}^n (Y_i - \theta X_i)^2$$
|
536 |
+
|
537 |
+
Step 5: Find the value of $\theta$ that minimizes the sum of squared errors by setting the derivative to zero:
|
538 |
+
$$\begin{align*}
|
539 |
+
\frac{d}{d\theta} \sum_{i=1}^n (Y_i - \theta X_i)^2 &= 0 \\
|
540 |
+
\sum_{i=1}^n -2X_i(Y_i - \theta X_i) &= 0 \\
|
541 |
+
\sum_{i=1}^n X_i Y_i - \theta X_i^2 &= 0 \\
|
542 |
+
\sum_{i=1}^n X_i Y_i &= \theta \sum_{i=1}^n X_i^2 \\
|
543 |
+
\hat{\theta} &= \frac{\sum_{i=1}^n X_i Y_i}{\sum_{i=1}^n X_i^2}
|
544 |
+
\end{align*}$$
|
545 |
+
|
546 |
+
/// tip
|
547 |
+
**Key Insight**: MLE for this simple linear model gives us the least squares estimator! This is an important connection between MLE and regression.
|
548 |
+
///
|
549 |
+
"""
|
550 |
+
)
|
551 |
+
return
|
552 |
+
|
553 |
+
|
554 |
+
@app.cell(hide_code=True)
|
555 |
+
def _(linear_controls):
|
556 |
+
linear_controls.center()
|
557 |
+
return
|
558 |
+
|
559 |
+
|
560 |
+
@app.cell(hide_code=True)
|
561 |
+
def _(
|
562 |
+
linear_generate_button,
|
563 |
+
linear_sample_size_slider,
|
564 |
+
mo,
|
565 |
+
noise_sigma_slider,
|
566 |
+
np,
|
567 |
+
plt,
|
568 |
+
true_theta_slider,
|
569 |
+
):
|
570 |
+
# linear model data calc when button is clicked
|
571 |
+
linear_button_value = linear_generate_button.value
|
572 |
+
|
573 |
+
# get parameter values
|
574 |
+
linear_true_theta = true_theta_slider.value
|
575 |
+
linear_noise_sigma = noise_sigma_slider.value
|
576 |
+
linear_n = linear_sample_size_slider.value
|
577 |
+
|
578 |
+
# generate x data (uniformly between -3 and 3)
|
579 |
+
linear_X = np.random.uniform(-3, 3, size=linear_n)
|
580 |
+
|
581 |
+
# generate y data according to the model y = θx + z
|
582 |
+
linear_Z = np.random.normal(0, linear_noise_sigma, size=linear_n)
|
583 |
+
linear_Y = linear_true_theta * linear_X + linear_Z
|
584 |
+
|
585 |
+
# calculate mle estimate
|
586 |
+
linear_theta_hat = np.sum(linear_X * linear_Y) / np.sum(linear_X**2)
|
587 |
+
|
588 |
+
# calculate sse for different theta values
|
589 |
+
linear_theta_range = np.linspace(linear_true_theta - 1.5, linear_true_theta + 1.5, 100)
|
590 |
+
linear_sse_values = np.zeros_like(linear_theta_range)
|
591 |
+
|
592 |
+
for linear_i, linear_theta in enumerate(linear_theta_range):
|
593 |
+
linear_y_pred = linear_theta * linear_X
|
594 |
+
linear_sse_values[linear_i] = np.sum((linear_Y - linear_y_pred)**2)
|
595 |
+
|
596 |
+
# convert sse to log-likelihood (ignoring constant terms)
|
597 |
+
linear_ll_values = -linear_sse_values / (2 * linear_noise_sigma**2)
|
598 |
+
|
599 |
+
# create visualization
|
600 |
+
linear_fig, (linear_ax1, linear_ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
601 |
+
|
602 |
+
# plot scatter plot with regression lines
|
603 |
+
linear_ax1.scatter(linear_X, linear_Y, color='blue', alpha=0.6, label='Data points')
|
604 |
+
|
605 |
+
# plot range for regression lines
|
606 |
+
linear_x_line = np.linspace(-3, 3, 100)
|
607 |
+
|
608 |
+
# plot true and mle regression lines
|
609 |
+
linear_ax1.plot(linear_x_line, linear_true_theta * linear_x_line, 'g-', linewidth=2, label=f'True: Y = {linear_true_theta:.2f}X')
|
610 |
+
linear_ax1.plot(linear_x_line, linear_theta_hat * linear_x_line, 'r--', linewidth=2, label=f'MLE: Y = {linear_theta_hat:.2f}X')
|
611 |
+
|
612 |
+
linear_ax1.set_xlabel('X')
|
613 |
+
linear_ax1.set_ylabel('Y')
|
614 |
+
linear_ax1.set_title(f'Linear Regression: {linear_n} data points')
|
615 |
+
linear_ax1.grid(True, alpha=0.3)
|
616 |
+
linear_ax1.legend()
|
617 |
+
|
618 |
+
# plot log-likelihood function
|
619 |
+
linear_ax2.plot(linear_theta_range, linear_ll_values, 'b-', linewidth=2)
|
620 |
+
linear_ax2.axvline(x=linear_theta_hat, color='r', linestyle='--', label=f'MLE: θ = {linear_theta_hat:.3f}')
|
621 |
+
linear_ax2.axvline(x=linear_true_theta, color='g', linestyle='--', label=f'True: θ = {linear_true_theta:.3f}')
|
622 |
+
linear_ax2.set_xlabel('θ (slope parameter)')
|
623 |
+
linear_ax2.set_ylabel('Log-Likelihood')
|
624 |
+
linear_ax2.set_title('Log-Likelihood Function')
|
625 |
+
linear_ax2.grid(True, alpha=0.3)
|
626 |
+
linear_ax2.legend()
|
627 |
+
|
628 |
+
plt.tight_layout()
|
629 |
+
plt.gca()
|
630 |
+
|
631 |
+
# relevant markdown to explain results
|
632 |
+
linear_explanation = mo.md(
|
633 |
+
f"""
|
634 |
+
### Linear Regression MLE Results
|
635 |
+
|
636 |
+
**True parameter**: $\\theta = {linear_true_theta:.3f}$
|
637 |
+
**MLE estimate**: $\\hat{{\\theta}} = {linear_theta_hat:.3f}$
|
638 |
+
|
639 |
+
The left plot shows the scatter plot of data points with the true regression line (green) and the MLE-estimated regression line (red dashed).
|
640 |
+
|
641 |
+
The right plot shows the log-likelihood function for different values of $\\theta$. The maximum likelihood estimate is marked with a red dashed line, and the true parameter is marked with a green dashed line.
|
642 |
+
|
643 |
+
/// note
|
644 |
+
The MLE estimate $\\hat{{\\theta}} = \\frac{{\\sum_{{i=1}}^n X_i Y_i}}{{\\sum_{{i=1}}^n X_i^2}}$ minimizes the sum of squared errors between the predicted and actual Y values.
|
645 |
+
///
|
646 |
+
|
647 |
+
/// tip
|
648 |
+
Try increasing the noise level to see how it affects the precision of the estimate!
|
649 |
+
///
|
650 |
+
"""
|
651 |
+
)
|
652 |
+
|
653 |
+
# show plot and explanation
|
654 |
+
mo.vstack([
|
655 |
+
linear_fig,
|
656 |
+
linear_explanation
|
657 |
+
])
|
658 |
+
return (
|
659 |
+
linear_X,
|
660 |
+
linear_Y,
|
661 |
+
linear_Z,
|
662 |
+
linear_ax1,
|
663 |
+
linear_ax2,
|
664 |
+
linear_button_value,
|
665 |
+
linear_explanation,
|
666 |
+
linear_fig,
|
667 |
+
linear_i,
|
668 |
+
linear_ll_values,
|
669 |
+
linear_n,
|
670 |
+
linear_noise_sigma,
|
671 |
+
linear_sse_values,
|
672 |
+
linear_theta,
|
673 |
+
linear_theta_hat,
|
674 |
+
linear_theta_range,
|
675 |
+
linear_true_theta,
|
676 |
+
linear_x_line,
|
677 |
+
linear_y_pred,
|
678 |
+
)
|
679 |
+
|
680 |
+
|
681 |
+
@app.cell(hide_code=True)
|
682 |
+
def _(mo):
|
683 |
+
mo.md(
|
684 |
+
r"""
|
685 |
+
## Interactive Concept: Density/Mass Functions vs. Likelihood
|
686 |
+
|
687 |
+
To better understand the distinction between likelihood and density/mass functions, let's create an interactive visualization. This concept is crucial for understanding why MLE works.
|
688 |
+
"""
|
689 |
+
)
|
690 |
+
return
|
691 |
+
|
692 |
+
|
693 |
+
@app.cell(hide_code=True)
|
694 |
+
def _(concept_controls):
|
695 |
+
concept_controls.center()
|
696 |
+
return
|
697 |
+
|
698 |
+
|
699 |
+
@app.cell(hide_code=True)
|
700 |
+
def _(concept_dist_type, mo, np, perspective_selector, plt, stats):
|
701 |
+
# current distribution type
|
702 |
+
concept_dist_type_value = concept_dist_type.value
|
703 |
+
|
704 |
+
# view mode from dropdown
|
705 |
+
concept_view_mode = "likelihood" if perspective_selector.value == "Likelihood Perspective" else "probability"
|
706 |
+
|
707 |
+
# visualization based on distribution type
|
708 |
+
concept_fig, concept_ax = plt.subplots(figsize=(10, 6))
|
709 |
+
|
710 |
+
if concept_dist_type_value == "Normal":
|
711 |
+
if concept_view_mode == "probability":
|
712 |
+
# density function perspective: fixed params, varying data
|
713 |
+
concept_mu = 0 # fixed parameter
|
714 |
+
concept_sigma = 1 # fixed parameter
|
715 |
+
|
716 |
+
# generate x values for the pdf
|
717 |
+
concept_x = np.linspace(-4, 4, 1000)
|
718 |
+
|
719 |
+
# plot pdf
|
720 |
+
concept_pdf = stats.norm.pdf(concept_x, concept_mu, concept_sigma)
|
721 |
+
concept_ax.plot(concept_x, concept_pdf, 'b-', linewidth=2, label='PDF: N(0, 1)')
|
722 |
+
|
723 |
+
# highlight specific data values
|
724 |
+
concept_data_points = [-2, -1, 0, 1, 2]
|
725 |
+
concept_colors = ['#FF9999', '#FFCC99', '#99FF99', '#99CCFF', '#CC99FF']
|
726 |
+
|
727 |
+
for concept_i, concept_data in enumerate(concept_data_points):
|
728 |
+
concept_prob = stats.norm.pdf(concept_data, concept_mu, concept_sigma)
|
729 |
+
concept_ax.plot([concept_data, concept_data], [0, concept_prob], concept_colors[concept_i], linewidth=2)
|
730 |
+
concept_ax.scatter(concept_data, concept_prob, color=concept_colors[concept_i], s=50,
|
731 |
+
label=f'PDF at x={concept_data}: {concept_prob:.3f}')
|
732 |
+
|
733 |
+
concept_ax.set_xlabel('Data (x)')
|
734 |
+
concept_ax.set_ylabel('Probability Density')
|
735 |
+
concept_ax.set_title('Density Function Perspective: Fixed Parameters (μ=0, σ=1), Different Data Points')
|
736 |
+
|
737 |
+
else: # likelihood perspective
|
738 |
+
# likelihood perspective: fixed data, varying parameters
|
739 |
+
concept_data_point = 1.5 # fixed observed data
|
740 |
+
|
741 |
+
# different possible parameter values (means)
|
742 |
+
concept_mus = [-1, 0, 1, 2, 3]
|
743 |
+
concept_sigma = 1
|
744 |
+
|
745 |
+
# generate x values for multiple pdfs
|
746 |
+
concept_x = np.linspace(-4, 6, 1000)
|
747 |
+
|
748 |
+
concept_colors = ['#FF9999', '#FFCC99', '#99FF99', '#99CCFF', '#CC99FF']
|
749 |
+
|
750 |
+
for concept_i, concept_mu in enumerate(concept_mus):
|
751 |
+
concept_pdf = stats.norm.pdf(concept_x, concept_mu, concept_sigma)
|
752 |
+
concept_ax.plot(concept_x, concept_pdf, color=concept_colors[concept_i], linewidth=2, alpha=0.7,
|
753 |
+
label=f'N({concept_mu}, 1)')
|
754 |
+
|
755 |
+
# mark the likelihood of the data point for this param
|
756 |
+
concept_likelihood = stats.norm.pdf(concept_data_point, concept_mu, concept_sigma)
|
757 |
+
concept_ax.plot([concept_data_point, concept_data_point], [0, concept_likelihood], concept_colors[concept_i], linewidth=2)
|
758 |
+
concept_ax.scatter(concept_data_point, concept_likelihood, color=concept_colors[concept_i], s=50,
|
759 |
+
label=f'L(μ={concept_mu}|X=1.5) = {concept_likelihood:.3f}')
|
760 |
+
|
761 |
+
# add vertical line at the observed data point
|
762 |
+
concept_ax.axvline(x=concept_data_point, color='black', linestyle='--',
|
763 |
+
label=f'Observed data: X=1.5')
|
764 |
+
|
765 |
+
concept_ax.set_xlabel('Data (x)')
|
766 |
+
concept_ax.set_ylabel('Probability Density / Likelihood')
|
767 |
+
concept_ax.set_title('Likelihood Perspective: Fixed Data Point (X=1.5), Different Parameter Values')
|
768 |
+
|
769 |
+
elif concept_dist_type_value == "Bernoulli":
|
770 |
+
if concept_view_mode == "probability":
|
771 |
+
# probability perspective: fixed parameter, two possible data values
|
772 |
+
concept_p = 0.3 # fixed parameter
|
773 |
+
|
774 |
+
# bar chart for p(x=0) and p(x=1)
|
775 |
+
concept_ax.bar([0, 1], [1-concept_p, concept_p], width=0.4, color=['#99CCFF', '#FF9999'],
|
776 |
+
alpha=0.7, label=f'PMF: Bernoulli({concept_p})')
|
777 |
+
|
778 |
+
# text showing probabilities
|
779 |
+
concept_ax.text(0, (1-concept_p)/2, f'P(X=0|p={concept_p}) = {1-concept_p:.3f}', ha='center', va='center', fontweight='bold')
|
780 |
+
concept_ax.text(1, concept_p/2, f'P(X=1|p={concept_p}) = {concept_p:.3f}', ha='center', va='center', fontweight='bold')
|
781 |
+
|
782 |
+
concept_ax.set_xlabel('Data (x)')
|
783 |
+
concept_ax.set_ylabel('Probability')
|
784 |
+
concept_ax.set_xticks([0, 1])
|
785 |
+
concept_ax.set_xticklabels(['X=0', 'X=1'])
|
786 |
+
concept_ax.set_ylim(0, 1)
|
787 |
+
concept_ax.set_title('Probability Perspective: Fixed Parameter (p=0.3), Different Data Values')
|
788 |
+
|
789 |
+
else: # likelihood perspective
|
790 |
+
# likelihood perspective: fixed data, varying parameter
|
791 |
+
concept_data_point = 1 # fixed observed data (success)
|
792 |
+
|
793 |
+
# different possible parameter values
|
794 |
+
concept_p_values = np.linspace(0.01, 0.99, 100)
|
795 |
+
|
796 |
+
# calculate likelihood for each p value
|
797 |
+
if concept_data_point == 1:
|
798 |
+
# for x=1, likelihood is p
|
799 |
+
concept_likelihood = concept_p_values
|
800 |
+
concept_ax.plot(concept_p_values, concept_likelihood, 'b-', linewidth=2,
|
801 |
+
label=f'L(p|X=1) = p')
|
802 |
+
|
803 |
+
# highlight specific values
|
804 |
+
concept_highlight_ps = [0.2, 0.5, 0.8]
|
805 |
+
concept_colors = ['#FF9999', '#99FF99', '#99CCFF']
|
806 |
+
|
807 |
+
for concept_i, concept_p in enumerate(concept_highlight_ps):
|
808 |
+
concept_ax.plot([concept_p, concept_p], [0, concept_p], concept_colors[concept_i], linewidth=2)
|
809 |
+
concept_ax.scatter(concept_p, concept_p, color=concept_colors[concept_i], s=50,
|
810 |
+
label=f'L(p={concept_p}|X=1) = {concept_p:.3f}')
|
811 |
+
|
812 |
+
concept_ax.set_title('Likelihood Perspective: Fixed Data Point (X=1), Different Parameter Values')
|
813 |
+
|
814 |
+
else: # x=0
|
815 |
+
# for x = 0, likelihood is (1-p)
|
816 |
+
concept_likelihood = 1 - concept_p_values
|
817 |
+
concept_ax.plot(concept_p_values, concept_likelihood, 'r-', linewidth=2,
|
818 |
+
label=f'L(p|X=0) = (1-p)')
|
819 |
+
|
820 |
+
# highlight some specific values
|
821 |
+
concept_highlight_ps = [0.2, 0.5, 0.8]
|
822 |
+
concept_colors = ['#FF9999', '#99FF99', '#99CCFF']
|
823 |
+
|
824 |
+
for concept_i, concept_p in enumerate(concept_highlight_ps):
|
825 |
+
concept_ax.plot([concept_p, concept_p], [0, 1-concept_p], concept_colors[concept_i], linewidth=2)
|
826 |
+
concept_ax.scatter(concept_p, 1-concept_p, color=concept_colors[concept_i], s=50,
|
827 |
+
label=f'L(p={concept_p}|X=0) = {1-concept_p:.3f}')
|
828 |
+
|
829 |
+
concept_ax.set_title('Likelihood Perspective: Fixed Data Point (X=0), Different Parameter Values')
|
830 |
+
|
831 |
+
concept_ax.set_xlabel('Parameter (p)')
|
832 |
+
concept_ax.set_ylabel('Likelihood')
|
833 |
+
concept_ax.set_xlim(0, 1)
|
834 |
+
concept_ax.set_ylim(0, 1)
|
835 |
+
|
836 |
+
elif concept_dist_type_value == "Poisson":
|
837 |
+
if concept_view_mode == "probability":
|
838 |
+
# probability perspective: fixed parameter, different data values
|
839 |
+
concept_lam = 2.5 # fixed parameter
|
840 |
+
|
841 |
+
# pmf for different x values plot
|
842 |
+
concept_x_values = np.arange(0, 10)
|
843 |
+
concept_pmf_values = stats.poisson.pmf(concept_x_values, concept_lam)
|
844 |
+
|
845 |
+
concept_ax.bar(concept_x_values, concept_pmf_values, width=0.4, color='#99CCFF',
|
846 |
+
alpha=0.7, label=f'PMF: Poisson({concept_lam})')
|
847 |
+
|
848 |
+
# highlight a few specific values
|
849 |
+
concept_highlight_xs = [1, 2, 3, 4]
|
850 |
+
concept_colors = ['#FF9999', '#99FF99', '#FFCC99', '#CC99FF']
|
851 |
+
|
852 |
+
for concept_i, concept_x in enumerate(concept_highlight_xs):
|
853 |
+
concept_prob = stats.poisson.pmf(concept_x, concept_lam)
|
854 |
+
concept_ax.scatter(concept_x, concept_prob, color=concept_colors[concept_i], s=50,
|
855 |
+
label=f'P(X={concept_x}|λ={concept_lam}) = {concept_prob:.3f}')
|
856 |
+
|
857 |
+
concept_ax.set_xlabel('Data (x)')
|
858 |
+
concept_ax.set_ylabel('Probability')
|
859 |
+
concept_ax.set_xticks(concept_x_values)
|
860 |
+
concept_ax.set_title('Probability Perspective: Fixed Parameter (λ=2.5), Different Data Values')
|
861 |
+
|
862 |
+
else: # likelihood perspective
|
863 |
+
# likelihood perspective: fixed data, varying parameter
|
864 |
+
concept_data_point = 4 # fixed observed data
|
865 |
+
|
866 |
+
# different possible param values
|
867 |
+
concept_lambda_values = np.linspace(0.1, 8, 100)
|
868 |
+
|
869 |
+
# calc likelihood for each lambda value
|
870 |
+
concept_likelihood = stats.poisson.pmf(concept_data_point, concept_lambda_values)
|
871 |
+
|
872 |
+
concept_ax.plot(concept_lambda_values, concept_likelihood, 'b-', linewidth=2,
|
873 |
+
label=f'L(λ|X={concept_data_point})')
|
874 |
+
|
875 |
+
# highlight some specific values
|
876 |
+
concept_highlight_lambdas = [1, 2, 4, 6]
|
877 |
+
concept_colors = ['#FF9999', '#99FF99', '#99CCFF', '#FFCC99']
|
878 |
+
|
879 |
+
for concept_i, concept_lam in enumerate(concept_highlight_lambdas):
|
880 |
+
concept_like_val = stats.poisson.pmf(concept_data_point, concept_lam)
|
881 |
+
concept_ax.plot([concept_lam, concept_lam], [0, concept_like_val], concept_colors[concept_i], linewidth=2)
|
882 |
+
concept_ax.scatter(concept_lam, concept_like_val, color=concept_colors[concept_i], s=50,
|
883 |
+
label=f'L(λ={concept_lam}|X={concept_data_point}) = {concept_like_val:.3f}')
|
884 |
+
|
885 |
+
concept_ax.set_xlabel('Parameter (λ)')
|
886 |
+
concept_ax.set_ylabel('Likelihood')
|
887 |
+
concept_ax.set_title(f'Likelihood Perspective: Fixed Data Point (X={concept_data_point}), Different Parameter Values')
|
888 |
+
|
889 |
+
concept_ax.legend(loc='best', fontsize=9)
|
890 |
+
concept_ax.grid(True, alpha=0.3)
|
891 |
+
plt.tight_layout()
|
892 |
+
plt.gca()
|
893 |
+
|
894 |
+
# relevant explanation based on view mode
|
895 |
+
if concept_view_mode == "probability":
|
896 |
+
concept_explanation = mo.md(
|
897 |
+
f"""
|
898 |
+
### Density/Mass Function Perspective
|
899 |
+
|
900 |
+
In the **density/mass function perspective**, the parameters of the distribution are **fixed and known**, and we evaluate the function at **different possible data values**.
|
901 |
+
|
902 |
+
For the {concept_dist_type_value} distribution, we've fixed the parameter{'s' if concept_dist_type_value == 'Normal' else ''} and shown the {'density' if concept_dist_type_value == 'Normal' else 'probability mass'} function evaluated at different data points.
|
903 |
+
|
904 |
+
This is the typical perspective when:
|
905 |
+
|
906 |
+
- We know the true parameters of a distribution
|
907 |
+
- We want to evaluate the {'density' if concept_dist_type_value == 'Normal' else 'probability mass'} at different observations
|
908 |
+
- We make predictions based on our model
|
909 |
+
|
910 |
+
**Mathematical notation**: $f(x | \theta)$
|
911 |
+
"""
|
912 |
+
)
|
913 |
+
else: # likelihood perspective
|
914 |
+
concept_explanation = mo.md(
|
915 |
+
f"""
|
916 |
+
### Likelihood Perspective
|
917 |
+
|
918 |
+
In the **likelihood perspective**, the observed data is **fixed and known**, and we calculate how likely different parameter values are to have generated that data.
|
919 |
+
|
920 |
+
For the {concept_dist_type_value} distribution, we've fixed the observed data point{'s' if concept_dist_type_value == 'Normal' else ''} and shown the likelihood of different parameter values.
|
921 |
+
|
922 |
+
This is the perspective used in MLE:
|
923 |
+
|
924 |
+
- We have observed data
|
925 |
+
- We don't know the true parameters
|
926 |
+
- We want to find parameters that best explain our observations
|
927 |
+
|
928 |
+
**Mathematical notation**: $L(\theta | X = x)$
|
929 |
+
|
930 |
+
/// tip
|
931 |
+
The value of $\\theta$ that maximizes this likelihood function is the MLE estimate $\\hat{{\\theta}}$!
|
932 |
+
///
|
933 |
+
"""
|
934 |
+
)
|
935 |
+
|
936 |
+
# Display plot and explanation together
|
937 |
+
mo.vstack([
|
938 |
+
concept_fig,
|
939 |
+
concept_explanation
|
940 |
+
])
|
941 |
+
return (
|
942 |
+
concept_ax,
|
943 |
+
concept_colors,
|
944 |
+
concept_data,
|
945 |
+
concept_data_point,
|
946 |
+
concept_data_points,
|
947 |
+
concept_dist_type_value,
|
948 |
+
concept_explanation,
|
949 |
+
concept_fig,
|
950 |
+
concept_highlight_lambdas,
|
951 |
+
concept_highlight_ps,
|
952 |
+
concept_highlight_xs,
|
953 |
+
concept_i,
|
954 |
+
concept_lam,
|
955 |
+
concept_lambda_values,
|
956 |
+
concept_like_val,
|
957 |
+
concept_likelihood,
|
958 |
+
concept_mu,
|
959 |
+
concept_mus,
|
960 |
+
concept_p,
|
961 |
+
concept_p_values,
|
962 |
+
concept_pdf,
|
963 |
+
concept_pmf_values,
|
964 |
+
concept_prob,
|
965 |
+
concept_sigma,
|
966 |
+
concept_view_mode,
|
967 |
+
concept_x,
|
968 |
+
concept_x_values,
|
969 |
+
)
|
970 |
+
|
971 |
+
|
972 |
+
@app.cell(hide_code=True)
|
973 |
+
def _(mo):
|
974 |
+
mo.md(
|
975 |
+
r"""
|
976 |
+
## 🤔 Test Your Understanding
|
977 |
+
|
978 |
+
Which of the following statements about Maximum Likelihood Estimation are correct? Click each statement to check your answer.
|
979 |
+
|
980 |
+
/// details | Probability and likelihood have different interpretations: probability measures the chance of data given parameters, while likelihood measures how likely parameters are given data.
|
981 |
+
✅ **Correct!**
|
982 |
+
|
983 |
+
Probability measures how likely it is to observe particular data when we know the parameters. Likelihood measures how likely particular parameter values are, given observed data.
|
984 |
+
|
985 |
+
Mathematically, probability is $P(X=x|\theta)$ while likelihood is $L(\theta|X=x)$.
|
986 |
+
///
|
987 |
+
|
988 |
+
/// details | We use log-likelihood instead of likelihood because it's mathematically simpler and numerically more stable.
|
989 |
+
✅ **Correct!**
|
990 |
+
|
991 |
+
We work with log-likelihood for several reasons:
|
992 |
+
1. It converts products into sums, which is easier to work with mathematically
|
993 |
+
2. It avoids numerical underflow when multiplying many small probabilities
|
994 |
+
3. Logarithm is a monotonically increasing function, so the maximum of the likelihood occurs at the same parameter values as the maximum of the log-likelihood
|
995 |
+
///
|
996 |
+
|
997 |
+
/// details | For a Bernoulli distribution, the MLE for parameter p is the sample mean of the observations.
|
998 |
+
✅ **Correct!**
|
999 |
+
|
1000 |
+
For a Bernoulli distribution with parameter $p$, given $n$ independent samples $X_1, X_2, \ldots, X_n$, the MLE estimator is:
|
1001 |
+
|
1002 |
+
$$\hat{p} = \frac{\sum_{i=1}^n X_i}{n}$$
|
1003 |
+
|
1004 |
+
This is simply the sample mean, or the proportion of successes (1s) in the data.
|
1005 |
+
///
|
1006 |
+
|
1007 |
+
/// details | For a Normal distribution, MLE gives unbiased estimates for both mean and variance parameters.
|
1008 |
+
❌ **Incorrect.**
|
1009 |
+
|
1010 |
+
While the MLE for the mean ($\hat{\mu} = \frac{1}{n}\sum_{i=1}^n X_i$) is unbiased, the MLE for variance:
|
1011 |
+
|
1012 |
+
$$\hat{\sigma}^2 = \frac{1}{n}\sum_{i=1}^n (X_i - \hat{\mu})^2$$
|
1013 |
+
|
1014 |
+
is a biased estimator. It uses $n$ in the denominator rather than $n-1$ used in the unbiased estimator.
|
1015 |
+
///
|
1016 |
+
|
1017 |
+
/// details | MLE estimators are always unbiased regardless of the distribution.
|
1018 |
+
❌ **Incorrect.**
|
1019 |
+
|
1020 |
+
MLE is not always unbiased, though it often is asymptotically unbiased (meaning the bias approaches zero as the sample size increases).
|
1021 |
+
|
1022 |
+
A notable example is the MLE estimator for the variance of a Normal distribution:
|
1023 |
+
$$\hat{\sigma}^2 = \frac{1}{n}\sum_{i=1}^n (X_i - \hat{\mu})^2$$
|
1024 |
+
|
1025 |
+
This estimator is biased, which is why we often use the unbiased estimator:
|
1026 |
+
$$s^2 = \frac{1}{n-1}\sum_{i=1}^n (X_i - \hat{\mu})^2$$
|
1027 |
+
|
1028 |
+
Despite occasional bias, MLE estimators have many desirable properties, including consistency and asymptotic efficiency.
|
1029 |
+
///
|
1030 |
+
"""
|
1031 |
+
)
|
1032 |
+
return
|
1033 |
+
|
1034 |
+
|
1035 |
+
@app.cell(hide_code=True)
|
1036 |
+
def _(mo):
|
1037 |
+
mo.md(
|
1038 |
+
r"""
|
1039 |
+
## Summary
|
1040 |
+
|
1041 |
+
Maximum Likelihood Estimation really is one of those elegant ideas that sits at the core of modern statistics. When you get down to it, MLE is just about finding the most plausible explanation for the data we've observed. It's like being a detective - you have some clues (your data), and you're trying to piece together the most likely story (your parameters) that explains them.
|
1042 |
+
|
1043 |
+
We've seen how this works with different distributions. For the Bernoulli, it simply gives us the sample proportion. For the Normal, it gives us the sample mean and a slightly biased estimate of variance. And for linear regression, it provides a mathematical justification for the least squares method that everyone learns in basic stats classes.
|
1044 |
+
|
1045 |
+
What makes MLE so useful in practice is that it tends to give us estimates with good properties. As you collect more data, the estimates generally get closer to the true values (consistency) and do so efficiently. That's why MLE is everywhere in statistics and machine learning - from simple regression models to complex neural networks.
|
1046 |
+
|
1047 |
+
The most important takeaway? Next time you're fitting a model to data, remember that you're not just following a recipe - you're finding the parameters that make your observed data most likely to have occurred. That's the essence of statistical inference.
|
1048 |
+
"""
|
1049 |
+
)
|
1050 |
+
return
|
1051 |
+
|
1052 |
+
|
1053 |
+
@app.cell(hide_code=True)
|
1054 |
+
def _(mo):
|
1055 |
+
mo.md(
|
1056 |
+
r"""
|
1057 |
+
## Further Reading
|
1058 |
+
|
1059 |
+
If you're curious to dive deeper into this topic, check out "Statistical Inference" by Casella and Berger - it's the classic text that many statisticians learned from. For a more machine learning angle, Bishop's "Pattern Recognition and Machine Learning" shows how MLE connects to more advanced topics like EM algorithms and Bayesian methods.
|
1060 |
+
|
1061 |
+
Beyond the basics we've covered, you might explore Bayesian estimation (which incorporates prior knowledge), Fisher Information (which tells us how precisely we can estimate parameters), or the EM algorithm (for when we have missing data or latent variables). Each of these builds on the foundation of likelihood that we've established here.
|
1062 |
+
"""
|
1063 |
+
)
|
1064 |
+
return
|
1065 |
+
|
1066 |
+
|
1067 |
+
@app.cell(hide_code=True)
|
1068 |
+
def _(mo):
|
1069 |
+
mo.md(r"""## Appendix (helper functions and imports)""")
|
1070 |
+
return
|
1071 |
+
|
1072 |
+
|
1073 |
+
@app.cell
|
1074 |
+
def _():
|
1075 |
+
import marimo as mo
|
1076 |
+
return (mo,)
|
1077 |
+
|
1078 |
+
|
1079 |
+
@app.cell
|
1080 |
+
def _():
|
1081 |
+
import numpy as np
|
1082 |
+
import matplotlib.pyplot as plt
|
1083 |
+
from scipy import stats
|
1084 |
+
import plotly.graph_objects as go
|
1085 |
+
import polars as pl
|
1086 |
+
from matplotlib import cm
|
1087 |
+
|
1088 |
+
# Set a consistent random seed for reproducibility
|
1089 |
+
np.random.seed(42)
|
1090 |
+
|
1091 |
+
# Set a nice style for matplotlib
|
1092 |
+
plt.style.use('seaborn-v0_8-darkgrid')
|
1093 |
+
return cm, go, np, pl, plt, stats
|
1094 |
+
|
1095 |
+
|
1096 |
+
@app.cell(hide_code=True)
|
1097 |
+
def _(mo):
|
1098 |
+
# Create interactive elements
|
1099 |
+
true_p_slider = mo.ui.slider(
|
1100 |
+
start =0.01,
|
1101 |
+
stop =0.99,
|
1102 |
+
value=0.3,
|
1103 |
+
step=0.01,
|
1104 |
+
label="True probability (p)"
|
1105 |
+
)
|
1106 |
+
|
1107 |
+
sample_size_slider = mo.ui.slider(
|
1108 |
+
start =10,
|
1109 |
+
stop =1000,
|
1110 |
+
value=100,
|
1111 |
+
step=10,
|
1112 |
+
label="Sample size (n)"
|
1113 |
+
)
|
1114 |
+
|
1115 |
+
generate_button = mo.ui.button(label="Generate New Sample", kind="success")
|
1116 |
+
|
1117 |
+
controls = mo.vstack([
|
1118 |
+
mo.vstack([true_p_slider, sample_size_slider]),
|
1119 |
+
generate_button
|
1120 |
+
], justify="space-between")
|
1121 |
+
return controls, generate_button, sample_size_slider, true_p_slider
|
1122 |
+
|
1123 |
+
|
1124 |
+
@app.cell(hide_code=True)
|
1125 |
+
def _(mo):
|
1126 |
+
# Create interactive elements for Normal distribution
|
1127 |
+
true_mu_slider = mo.ui.slider(
|
1128 |
+
start =-5,
|
1129 |
+
stop =5,
|
1130 |
+
value=0,
|
1131 |
+
step=0.1,
|
1132 |
+
label="True mean (μ)"
|
1133 |
+
)
|
1134 |
+
|
1135 |
+
true_sigma_slider = mo.ui.slider(
|
1136 |
+
start =0.5,
|
1137 |
+
stop =3,
|
1138 |
+
value=1,
|
1139 |
+
step=0.1,
|
1140 |
+
label="True standard deviation (σ)"
|
1141 |
+
)
|
1142 |
+
|
1143 |
+
normal_sample_size_slider = mo.ui.slider(
|
1144 |
+
start =10,
|
1145 |
+
stop =500,
|
1146 |
+
value=50,
|
1147 |
+
step=10,
|
1148 |
+
label="Sample size (n)"
|
1149 |
+
)
|
1150 |
+
|
1151 |
+
normal_generate_button = mo.ui.button(label="Generate New Sample", kind="warn")
|
1152 |
+
|
1153 |
+
normal_controls = mo.hstack([
|
1154 |
+
mo.vstack([true_mu_slider, true_sigma_slider, normal_sample_size_slider]),
|
1155 |
+
normal_generate_button
|
1156 |
+
], justify="space-between")
|
1157 |
+
return (
|
1158 |
+
normal_controls,
|
1159 |
+
normal_generate_button,
|
1160 |
+
normal_sample_size_slider,
|
1161 |
+
true_mu_slider,
|
1162 |
+
true_sigma_slider,
|
1163 |
+
)
|
1164 |
+
|
1165 |
+
|
1166 |
+
@app.cell(hide_code=True)
|
1167 |
+
def _(mo):
|
1168 |
+
# Create interactive elements for linear regression
|
1169 |
+
true_theta_slider = mo.ui.slider(
|
1170 |
+
start =-2,
|
1171 |
+
stop =2,
|
1172 |
+
value=0.5,
|
1173 |
+
step=0.1,
|
1174 |
+
label="True slope (θ)"
|
1175 |
+
)
|
1176 |
+
|
1177 |
+
noise_sigma_slider = mo.ui.slider(
|
1178 |
+
start =0.1,
|
1179 |
+
stop =2,
|
1180 |
+
value=0.5,
|
1181 |
+
step=0.1,
|
1182 |
+
label="Noise level (σ)"
|
1183 |
+
)
|
1184 |
+
|
1185 |
+
linear_sample_size_slider = mo.ui.slider(
|
1186 |
+
start =10,
|
1187 |
+
stop =200,
|
1188 |
+
value=50,
|
1189 |
+
step=10,
|
1190 |
+
label="Sample size (n)"
|
1191 |
+
)
|
1192 |
+
|
1193 |
+
linear_generate_button = mo.ui.button(label="Generate New Sample", kind="warn")
|
1194 |
+
|
1195 |
+
linear_controls = mo.hstack([
|
1196 |
+
mo.vstack([true_theta_slider, noise_sigma_slider, linear_sample_size_slider]),
|
1197 |
+
linear_generate_button
|
1198 |
+
], justify="space-between")
|
1199 |
+
return (
|
1200 |
+
linear_controls,
|
1201 |
+
linear_generate_button,
|
1202 |
+
linear_sample_size_slider,
|
1203 |
+
noise_sigma_slider,
|
1204 |
+
true_theta_slider,
|
1205 |
+
)
|
1206 |
+
|
1207 |
+
|
1208 |
+
@app.cell(hide_code=True)
|
1209 |
+
def _(mo):
|
1210 |
+
# Interactive elements for likelihood vs probability demo
|
1211 |
+
concept_dist_type = mo.ui.dropdown(
|
1212 |
+
options=["Normal", "Bernoulli", "Poisson"],
|
1213 |
+
value="Normal",
|
1214 |
+
label="Distribution"
|
1215 |
+
)
|
1216 |
+
|
1217 |
+
# Replace buttons with a simple dropdown selector
|
1218 |
+
perspective_selector = mo.ui.dropdown(
|
1219 |
+
options=["Probability Perspective", "Likelihood Perspective"],
|
1220 |
+
value="Probability Perspective",
|
1221 |
+
label="View"
|
1222 |
+
)
|
1223 |
+
|
1224 |
+
concept_controls = mo.vstack([
|
1225 |
+
mo.hstack([concept_dist_type, perspective_selector])
|
1226 |
+
])
|
1227 |
+
return concept_controls, concept_dist_type, perspective_selector
|
1228 |
+
|
1229 |
+
|
1230 |
+
if __name__ == "__main__":
|
1231 |
+
app.run()
|
python/006_dictionaries.py
CHANGED
@@ -196,13 +196,13 @@ def _():
|
|
196 |
|
197 |
@app.cell
|
198 |
def _(mo, nested_data):
|
199 |
-
mo.md(f"Alice's age: {nested_data[
|
200 |
return
|
201 |
|
202 |
|
203 |
@app.cell
|
204 |
def _(mo, nested_data):
|
205 |
-
mo.md(f"Bob's interests: {nested_data[
|
206 |
return
|
207 |
|
208 |
|
|
|
196 |
|
197 |
@app.cell
|
198 |
def _(mo, nested_data):
|
199 |
+
mo.md(f"Alice's age: {nested_data['users']['alice']['age']}")
|
200 |
return
|
201 |
|
202 |
|
203 |
@app.cell
|
204 |
def _(mo, nested_data):
|
205 |
+
mo.md(f"Bob's interests: {nested_data['users']['bob']['interests']}")
|
206 |
return
|
207 |
|
208 |
|
scripts/build.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
import subprocess
|
5 |
+
import argparse
|
6 |
+
import json
|
7 |
+
import datetime
|
8 |
+
import markdown
|
9 |
+
from datetime import date
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Dict, List, Any, Optional, Tuple
|
12 |
+
|
13 |
+
from jinja2 import Environment, FileSystemLoader
|
14 |
+
|
15 |
+
|
16 |
+
def export_html_wasm(notebook_path: str, output_dir: str, as_app: bool = False) -> bool:
|
17 |
+
"""Export a single marimo notebook to HTML format.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
notebook_path: Path to the notebook to export
|
21 |
+
output_dir: Directory to write the output HTML files
|
22 |
+
as_app: If True, export as app instead of notebook
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
bool: True if export succeeded, False otherwise
|
26 |
+
"""
|
27 |
+
# Create directory for the output
|
28 |
+
os.makedirs(output_dir, exist_ok=True)
|
29 |
+
|
30 |
+
# Determine the output path (preserving directory structure)
|
31 |
+
rel_path = os.path.basename(os.path.dirname(notebook_path))
|
32 |
+
if rel_path != os.path.dirname(notebook_path):
|
33 |
+
# Create subdirectory if needed
|
34 |
+
os.makedirs(os.path.join(output_dir, rel_path), exist_ok=True)
|
35 |
+
|
36 |
+
# Determine output filename (same as input but with .html extension)
|
37 |
+
output_filename = os.path.basename(notebook_path).replace(".py", ".html")
|
38 |
+
output_path = os.path.join(output_dir, rel_path, output_filename)
|
39 |
+
|
40 |
+
# Run marimo export command
|
41 |
+
mode = "--mode app" if as_app else "--mode edit"
|
42 |
+
cmd = f"marimo export html-wasm {mode} {notebook_path} -o {output_path} --sandbox"
|
43 |
+
print(f"Exporting {notebook_path} to {rel_path}/{output_filename} as {'app' if as_app else 'notebook'}")
|
44 |
+
print(f"Running command: {cmd}")
|
45 |
+
|
46 |
+
try:
|
47 |
+
result = subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
|
48 |
+
print(f"Successfully exported {notebook_path} to {output_path}")
|
49 |
+
return True
|
50 |
+
except subprocess.CalledProcessError as e:
|
51 |
+
print(f"Error exporting {notebook_path}: {e}")
|
52 |
+
print(f"Command output: {e.output}")
|
53 |
+
return False
|
54 |
+
|
55 |
+
|
56 |
+
def get_course_metadata(course_dir: Path) -> Dict[str, Any]:
|
57 |
+
"""Extract metadata from a course directory.
|
58 |
+
|
59 |
+
Reads the README.md file to extract title and description.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
course_dir: Path to the course directory
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
Dict: Dictionary containing course metadata (title, description)
|
66 |
+
"""
|
67 |
+
readme_path = course_dir / "README.md"
|
68 |
+
title = course_dir.name.replace("_", " ").title()
|
69 |
+
description = ""
|
70 |
+
description_html = ""
|
71 |
+
|
72 |
+
if readme_path.exists():
|
73 |
+
with open(readme_path, "r", encoding="utf-8") as f:
|
74 |
+
content = f.read()
|
75 |
+
|
76 |
+
# Try to extract title from first heading
|
77 |
+
title_match = content.split("\n")[0]
|
78 |
+
if title_match.startswith("# "):
|
79 |
+
title = title_match[2:].strip()
|
80 |
+
|
81 |
+
# Extract description from content after first heading
|
82 |
+
desc_content = "\n".join(content.split("\n")[1:]).strip()
|
83 |
+
if desc_content:
|
84 |
+
# Take first paragraph as description, preserve markdown formatting
|
85 |
+
description = desc_content.split("\n\n")[0].strip()
|
86 |
+
# Convert markdown to HTML
|
87 |
+
description_html = markdown.markdown(description)
|
88 |
+
|
89 |
+
return {
|
90 |
+
"title": title,
|
91 |
+
"description": description,
|
92 |
+
"description_html": description_html
|
93 |
+
}
|
94 |
+
|
95 |
+
|
96 |
+
def organize_notebooks_by_course(all_notebooks: List[str]) -> Dict[str, Dict[str, Any]]:
|
97 |
+
"""Organize notebooks by course.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
all_notebooks: List of paths to notebooks
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
Dict: A dictionary where keys are course directories and values are
|
104 |
+
metadata about the course and its notebooks
|
105 |
+
"""
|
106 |
+
courses = {}
|
107 |
+
|
108 |
+
for notebook_path in sorted(all_notebooks):
|
109 |
+
# Parse the path to determine course
|
110 |
+
# The first directory in the path is the course
|
111 |
+
path_parts = Path(notebook_path).parts
|
112 |
+
|
113 |
+
if len(path_parts) < 2:
|
114 |
+
print(f"Skipping notebook with invalid path: {notebook_path}")
|
115 |
+
continue
|
116 |
+
|
117 |
+
course_id = path_parts[0]
|
118 |
+
|
119 |
+
# If this is a new course, initialize it
|
120 |
+
if course_id not in courses:
|
121 |
+
course_metadata = get_course_metadata(Path(course_id))
|
122 |
+
|
123 |
+
courses[course_id] = {
|
124 |
+
"id": course_id,
|
125 |
+
"title": course_metadata["title"],
|
126 |
+
"description": course_metadata["description"],
|
127 |
+
"description_html": course_metadata["description_html"],
|
128 |
+
"notebooks": []
|
129 |
+
}
|
130 |
+
|
131 |
+
# Extract the notebook number and name from the filename
|
132 |
+
filename = Path(notebook_path).name
|
133 |
+
basename = filename.replace(".py", "")
|
134 |
+
|
135 |
+
# Extract notebook metadata
|
136 |
+
notebook_title = basename.replace("_", " ").title()
|
137 |
+
|
138 |
+
# Try to extract a sequence number from the start of the filename
|
139 |
+
# Match patterns like: 01_xxx, 1_xxx, etc.
|
140 |
+
import re
|
141 |
+
number_match = re.match(r'^(\d+)(?:[_-]|$)', basename)
|
142 |
+
notebook_number = number_match.group(1) if number_match else None
|
143 |
+
|
144 |
+
# If we found a number, remove it from the title
|
145 |
+
if number_match:
|
146 |
+
notebook_title = re.sub(r'^\d+\s*[_-]?\s*', '', notebook_title)
|
147 |
+
|
148 |
+
# Calculate the HTML output path (for linking)
|
149 |
+
html_path = f"{course_id}/{filename.replace('.py', '.html')}"
|
150 |
+
|
151 |
+
# Add the notebook to the course
|
152 |
+
courses[course_id]["notebooks"].append({
|
153 |
+
"path": notebook_path,
|
154 |
+
"html_path": html_path,
|
155 |
+
"title": notebook_title,
|
156 |
+
"display_name": notebook_title,
|
157 |
+
"original_number": notebook_number
|
158 |
+
})
|
159 |
+
|
160 |
+
# Sort notebooks by number if available, otherwise by title
|
161 |
+
for course_id, course_data in courses.items():
|
162 |
+
# Sort the notebooks list by number and title
|
163 |
+
course_data["notebooks"] = sorted(
|
164 |
+
course_data["notebooks"],
|
165 |
+
key=lambda x: (
|
166 |
+
int(x["original_number"]) if x["original_number"] is not None else float('inf'),
|
167 |
+
x["title"]
|
168 |
+
)
|
169 |
+
)
|
170 |
+
|
171 |
+
return courses
|
172 |
+
|
173 |
+
|
174 |
+
def generate_clean_tailwind_landing_page(courses: Dict[str, Dict[str, Any]], output_dir: str) -> None:
|
175 |
+
"""Generate a clean tailwindcss landing page with green accents.
|
176 |
+
|
177 |
+
This generates a modern, minimal landing page for marimo notebooks using tailwindcss.
|
178 |
+
The page is designed with clean aesthetics and green color accents using Jinja2 templates.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
courses: Dictionary of courses metadata
|
182 |
+
output_dir: Directory to write the output index.html file
|
183 |
+
"""
|
184 |
+
print("Generating clean tailwindcss landing page")
|
185 |
+
|
186 |
+
index_path = os.path.join(output_dir, "index.html")
|
187 |
+
os.makedirs(output_dir, exist_ok=True)
|
188 |
+
|
189 |
+
# Load Jinja2 template
|
190 |
+
current_dir = Path(__file__).parent
|
191 |
+
templates_dir = current_dir / "templates"
|
192 |
+
env = Environment(loader=FileSystemLoader(templates_dir))
|
193 |
+
template = env.get_template('index.html')
|
194 |
+
|
195 |
+
try:
|
196 |
+
with open(index_path, "w", encoding="utf-8") as f:
|
197 |
+
# Render the template with the provided data
|
198 |
+
rendered_html = template.render(
|
199 |
+
courses=courses,
|
200 |
+
current_year=datetime.date.today().year
|
201 |
+
)
|
202 |
+
f.write(rendered_html)
|
203 |
+
|
204 |
+
print(f"Successfully generated clean tailwindcss landing page at {index_path}")
|
205 |
+
|
206 |
+
except IOError as e:
|
207 |
+
print(f"Error generating clean tailwindcss landing page: {e}")
|
208 |
+
|
209 |
+
|
210 |
+
def main() -> None:
|
211 |
+
parser = argparse.ArgumentParser(description="Build marimo notebooks")
|
212 |
+
parser.add_argument(
|
213 |
+
"--output-dir", default="_site", help="Output directory for built files"
|
214 |
+
)
|
215 |
+
parser.add_argument(
|
216 |
+
"--course-dirs", nargs="+", default=None,
|
217 |
+
help="Specific course directories to build (default: all directories with .py files)"
|
218 |
+
)
|
219 |
+
args = parser.parse_args()
|
220 |
+
|
221 |
+
# Find all course directories (directories containing .py files)
|
222 |
+
all_notebooks: List[str] = []
|
223 |
+
|
224 |
+
# Directories to exclude from course detection
|
225 |
+
excluded_dirs = ["scripts", "env", "__pycache__", ".git", ".github", "assets"]
|
226 |
+
|
227 |
+
if args.course_dirs:
|
228 |
+
course_dirs = args.course_dirs
|
229 |
+
else:
|
230 |
+
# Automatically detect course directories (any directory with .py files)
|
231 |
+
course_dirs = []
|
232 |
+
for item in os.listdir("."):
|
233 |
+
if (os.path.isdir(item) and
|
234 |
+
not item.startswith(".") and
|
235 |
+
not item.startswith("_") and
|
236 |
+
item not in excluded_dirs):
|
237 |
+
# Check if directory contains .py files
|
238 |
+
if list(Path(item).glob("*.py")):
|
239 |
+
course_dirs.append(item)
|
240 |
+
|
241 |
+
print(f"Found course directories: {', '.join(course_dirs)}")
|
242 |
+
|
243 |
+
for directory in course_dirs:
|
244 |
+
dir_path = Path(directory)
|
245 |
+
if not dir_path.exists():
|
246 |
+
print(f"Warning: Directory not found: {dir_path}")
|
247 |
+
continue
|
248 |
+
|
249 |
+
notebooks = [str(path) for path in dir_path.rglob("*.py")
|
250 |
+
if not path.name.startswith("_") and "/__pycache__/" not in str(path)]
|
251 |
+
all_notebooks.extend(notebooks)
|
252 |
+
|
253 |
+
if not all_notebooks:
|
254 |
+
print("No notebooks found!")
|
255 |
+
return
|
256 |
+
|
257 |
+
# Export notebooks sequentially
|
258 |
+
successful_notebooks = []
|
259 |
+
for nb in all_notebooks:
|
260 |
+
# Determine if notebook should be exported as app or notebook
|
261 |
+
# For now, export all as notebooks
|
262 |
+
if export_html_wasm(nb, args.output_dir, as_app=False):
|
263 |
+
successful_notebooks.append(nb)
|
264 |
+
|
265 |
+
# Organize notebooks by course (only include successfully exported notebooks)
|
266 |
+
courses = organize_notebooks_by_course(successful_notebooks)
|
267 |
+
|
268 |
+
# Generate landing page using Tailwind CSS
|
269 |
+
generate_clean_tailwind_landing_page(courses, args.output_dir)
|
270 |
+
|
271 |
+
# Save course data as JSON for potential use by other tools
|
272 |
+
courses_json_path = os.path.join(args.output_dir, "courses.json")
|
273 |
+
with open(courses_json_path, "w", encoding="utf-8") as f:
|
274 |
+
json.dump(courses, f, indent=2)
|
275 |
+
|
276 |
+
print(f"Build complete! Site generated in {args.output_dir}")
|
277 |
+
print(f"Successfully exported {len(successful_notebooks)} out of {len(all_notebooks)} notebooks")
|
278 |
+
|
279 |
+
|
280 |
+
if __name__ == "__main__":
|
281 |
+
main()
|
scripts/preview.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
import subprocess
|
5 |
+
import argparse
|
6 |
+
import webbrowser
|
7 |
+
import time
|
8 |
+
import sys
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
def main():
|
12 |
+
parser = argparse.ArgumentParser(description="Build and preview marimo notebooks site")
|
13 |
+
parser.add_argument(
|
14 |
+
"--port", default=8000, type=int, help="Port to run the server on"
|
15 |
+
)
|
16 |
+
parser.add_argument(
|
17 |
+
"--no-build", action="store_true", help="Skip building the site (just serve existing files)"
|
18 |
+
)
|
19 |
+
parser.add_argument(
|
20 |
+
"--output-dir", default="_site", help="Output directory for built files"
|
21 |
+
)
|
22 |
+
args = parser.parse_args()
|
23 |
+
|
24 |
+
# Store the current directory
|
25 |
+
original_dir = os.getcwd()
|
26 |
+
|
27 |
+
try:
|
28 |
+
# Build the site if not skipped
|
29 |
+
if not args.no_build:
|
30 |
+
print("Building site...")
|
31 |
+
build_script = Path("scripts/build.py")
|
32 |
+
if not build_script.exists():
|
33 |
+
print(f"Error: Build script not found at {build_script}")
|
34 |
+
return 1
|
35 |
+
|
36 |
+
result = subprocess.run(
|
37 |
+
[sys.executable, str(build_script), "--output-dir", args.output_dir],
|
38 |
+
check=False
|
39 |
+
)
|
40 |
+
if result.returncode != 0:
|
41 |
+
print("Warning: Build process completed with errors.")
|
42 |
+
|
43 |
+
# Check if the output directory exists
|
44 |
+
output_dir = Path(args.output_dir)
|
45 |
+
if not output_dir.exists():
|
46 |
+
print(f"Error: Output directory '{args.output_dir}' does not exist.")
|
47 |
+
return 1
|
48 |
+
|
49 |
+
# Change to the output directory
|
50 |
+
os.chdir(args.output_dir)
|
51 |
+
|
52 |
+
# Open the browser
|
53 |
+
url = f"http://localhost:{args.port}"
|
54 |
+
print(f"Opening {url} in your browser...")
|
55 |
+
webbrowser.open(url)
|
56 |
+
|
57 |
+
# Start the server
|
58 |
+
print(f"Starting server on port {args.port}...")
|
59 |
+
print("Press Ctrl+C to stop the server")
|
60 |
+
|
61 |
+
# Use the appropriate Python executable
|
62 |
+
subprocess.run([sys.executable, "-m", "http.server", str(args.port)])
|
63 |
+
|
64 |
+
return 0
|
65 |
+
except KeyboardInterrupt:
|
66 |
+
print("\nServer stopped.")
|
67 |
+
return 0
|
68 |
+
except Exception as e:
|
69 |
+
print(f"Error: {e}")
|
70 |
+
return 1
|
71 |
+
finally:
|
72 |
+
# Always return to the original directory
|
73 |
+
os.chdir(original_dir)
|
74 |
+
|
75 |
+
if __name__ == "__main__":
|
76 |
+
sys.exit(main())
|
scripts/templates/index.html
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
6 |
+
<title>Marimo Learn - Interactive Python Notebooks</title>
|
7 |
+
<meta name="description" content="Learn Python, data science, and machine learning with interactive marimo notebooks">
|
8 |
+
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/tailwind.min.css" rel="stylesheet">
|
9 |
+
<style>
|
10 |
+
:root {
|
11 |
+
--primary-green: #10B981;
|
12 |
+
--dark-green: #047857;
|
13 |
+
--light-green: #D1FAE5;
|
14 |
+
}
|
15 |
+
.bg-primary { background-color: var(--primary-green); }
|
16 |
+
.text-primary { color: var(--primary-green); }
|
17 |
+
.border-primary { border-color: var(--primary-green); }
|
18 |
+
.bg-light { background-color: var(--light-green); }
|
19 |
+
.hover-grow { transition: transform 0.2s ease; }
|
20 |
+
.hover-grow:hover { transform: scale(1.02); }
|
21 |
+
.card-shadow { box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05), 0 1px 3px rgba(0, 0, 0, 0.1); }
|
22 |
+
</style>
|
23 |
+
</head>
|
24 |
+
<body class="bg-gray-50 text-gray-800 font-sans">
|
25 |
+
<!-- Hero Section -->
|
26 |
+
<header class="bg-white">
|
27 |
+
<div class="container mx-auto px-4 py-12 max-w-6xl">
|
28 |
+
<div class="flex flex-col md:flex-row items-center justify-between">
|
29 |
+
<div class="md:w-1/2 mb-8 md:mb-0 md:pr-12">
|
30 |
+
<h1 class="text-4xl md:text-5xl font-bold mb-4">Interactive Python Learning with <span class="text-primary">marimo</span></h1>
|
31 |
+
<p class="text-lg text-gray-600 mb-6">Explore our collection of interactive notebooks for Python, data science, and machine learning.</p>
|
32 |
+
<div class="flex flex-wrap gap-4">
|
33 |
+
<a href="#courses" class="bg-primary hover:bg-dark-green text-white font-medium py-2 px-6 rounded-md transition duration-300">Explore Courses</a>
|
34 |
+
<a href="https://github.com/marimo-team/learn" target="_blank" class="bg-white border border-gray-300 hover:border-primary text-gray-700 font-medium py-2 px-6 rounded-md transition duration-300">View on GitHub</a>
|
35 |
+
</div>
|
36 |
+
</div>
|
37 |
+
<div class="md:w-1/2">
|
38 |
+
<div class="bg-light p-1 rounded-lg">
|
39 |
+
<img src="https://github.com/marimo-team/learn/blob/main/assets/marimo-learn.png?raw=true" alt="Marimo Logo" class="w-64 h-64 mx-auto object-contain">
|
40 |
+
</div>
|
41 |
+
</div>
|
42 |
+
</div>
|
43 |
+
</div>
|
44 |
+
</header>
|
45 |
+
|
46 |
+
<!-- Features Section -->
|
47 |
+
<section class="py-16 bg-gray-50">
|
48 |
+
<div class="container mx-auto px-4 max-w-6xl">
|
49 |
+
<h2 class="text-3xl font-bold text-center mb-12">Why Learn with <span class="text-primary">Marimo</span>?</h2>
|
50 |
+
<div class="grid md:grid-cols-3 gap-8">
|
51 |
+
<div class="bg-white p-6 rounded-lg card-shadow">
|
52 |
+
<div class="w-12 h-12 bg-light rounded-full flex items-center justify-center mb-4">
|
53 |
+
<svg xmlns="http://www.w3.org/2000/svg" class="h-6 w-6 text-primary" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
54 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 10V3L4 14h7v7l9-11h-7z" />
|
55 |
+
</svg>
|
56 |
+
</div>
|
57 |
+
<h3 class="text-xl font-semibold mb-2">Interactive Learning</h3>
|
58 |
+
<p class="text-gray-600">Learn by doing with interactive notebooks that run directly in your browser.</p>
|
59 |
+
</div>
|
60 |
+
<div class="bg-white p-6 rounded-lg card-shadow">
|
61 |
+
<div class="w-12 h-12 bg-light rounded-full flex items-center justify-center mb-4">
|
62 |
+
<svg xmlns="http://www.w3.org/2000/svg" class="h-6 w-6 text-primary" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
63 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19.428 15.428a2 2 0 00-1.022-.547l-2.387-.477a6 6 0 00-3.86.517l-.318.158a6 6 0 01-3.86.517L6.05 15.21a2 2 0 00-1.806.547M8 4h8l-1 1v5.172a2 2 0 00.586 1.414l5 5c1.26 1.26.367 3.414-1.415 3.414H4.828c-1.782 0-2.674-2.154-1.414-3.414l5-5A2 2 0 009 10.172V5L8 4z" />
|
64 |
+
</svg>
|
65 |
+
</div>
|
66 |
+
<h3 class="text-xl font-semibold mb-2">Practical Examples</h3>
|
67 |
+
<p class="text-gray-600">Real-world examples and applications to reinforce your understanding.</p>
|
68 |
+
</div>
|
69 |
+
<div class="bg-white p-6 rounded-lg card-shadow">
|
70 |
+
<div class="w-12 h-12 bg-light rounded-full flex items-center justify-center mb-4">
|
71 |
+
<svg xmlns="http://www.w3.org/2000/svg" class="h-6 w-6 text-primary" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
72 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 6.253v13m0-13C10.832 5.477 9.246 5 7.5 5S4.168 5.477 3 6.253v13C4.168 18.477 5.754 18 7.5 18s3.332.477 4.5 1.253m0-13C13.168 5.477 14.754 5 16.5 5c1.747 0 3.332.477 4.5 1.253v13C19.832 18.477 18.247 18 16.5 18c-1.746 0-3.332.477-4.5 1.253" />
|
73 |
+
</svg>
|
74 |
+
</div>
|
75 |
+
<h3 class="text-xl font-semibold mb-2">Comprehensive Curriculum</h3>
|
76 |
+
<p class="text-gray-600">From Python basics to advanced machine learning concepts.</p>
|
77 |
+
</div>
|
78 |
+
</div>
|
79 |
+
</div>
|
80 |
+
</section>
|
81 |
+
|
82 |
+
<!-- Courses Section -->
|
83 |
+
<section id="courses" class="py-16 bg-white">
|
84 |
+
<div class="container mx-auto px-4 max-w-6xl">
|
85 |
+
<h2 class="text-3xl font-bold text-center mb-12">Explore Our <span class="text-primary">Courses</span></h2>
|
86 |
+
<div class="grid md:grid-cols-2 lg:grid-cols-3 gap-8">
|
87 |
+
{% for course_id, course in courses.items() %}
|
88 |
+
{% set notebooks = course.get('notebooks', []) %}
|
89 |
+
{% set notebook_count = notebooks|length %}
|
90 |
+
|
91 |
+
{% if notebook_count > 0 %}
|
92 |
+
{% set title = course.get('title', course_id|replace('_', ' ')|title) %}
|
93 |
+
|
94 |
+
<div class="bg-white border border-gray-200 rounded-lg overflow-hidden hover-grow card-shadow">
|
95 |
+
<div class="h-2 bg-primary"></div>
|
96 |
+
<div class="p-6">
|
97 |
+
<h3 class="text-xl font-semibold mb-2">{{ title }}</h3>
|
98 |
+
<p class="text-gray-600 mb-4">
|
99 |
+
{% if course.get('description_html') %}
|
100 |
+
{{ course.get('description_html')|safe }}
|
101 |
+
{% endif %}
|
102 |
+
</p>
|
103 |
+
<div class="mb-4">
|
104 |
+
<span class="text-sm text-gray-500 block mb-2">{{ notebook_count }} notebooks:</span>
|
105 |
+
<ol class="space-y-1 list-decimal pl-5">
|
106 |
+
{% for notebook in notebooks %}
|
107 |
+
{% set notebook_title = notebook.get('title', notebook.get('path', '').split('/')[-1].replace('.py', '').replace('_', ' ').title()) %}
|
108 |
+
<li>
|
109 |
+
<a href="{{ notebook.get('html_path', '#') }}" class="text-primary hover:text-dark-green text-sm flex items-center">
|
110 |
+
{{ notebook_title }}
|
111 |
+
</a>
|
112 |
+
</li>
|
113 |
+
{% endfor %}
|
114 |
+
</ol>
|
115 |
+
</div>
|
116 |
+
</div>
|
117 |
+
</div>
|
118 |
+
{% endif %}
|
119 |
+
{% endfor %}
|
120 |
+
</div>
|
121 |
+
</div>
|
122 |
+
</section>
|
123 |
+
|
124 |
+
<!-- Contribute Section -->
|
125 |
+
<section class="py-16 bg-light">
|
126 |
+
<div class="container mx-auto px-4 max-w-6xl text-center">
|
127 |
+
<h2 class="text-3xl font-bold mb-6">Want to Contribute?</h2>
|
128 |
+
<p class="text-lg text-gray-700 mb-8 max-w-2xl mx-auto">Help us improve these learning materials by contributing to the GitHub repository. We welcome new content, bug fixes, and improvements!</p>
|
129 |
+
<a href="https://github.com/marimo-team/learn" target="_blank" class="bg-primary hover:bg-dark-green text-white font-medium py-3 px-8 rounded-md transition duration-300 inline-flex items-center">
|
130 |
+
<svg class="w-5 h-5 mr-2" fill="currentColor" viewBox="0 0 20 20" xmlns="http://www.w3.org/2000/svg">
|
131 |
+
<path fill-rule="evenodd" d="M10 0C4.477 0 0 4.477 0 10c0 4.42 2.87 8.17 6.84 9.5.5.08.66-.23.66-.5v-1.69c-2.77.6-3.36-1.34-3.36-1.34-.46-1.16-1.11-1.47-1.11-1.47-.91-.62.07-.6.07-.6 1 .07 1.53 1.03 1.53 1.03.87 1.52 2.34 1.07 2.91.83.09-.65.35-1.09.63-1.34-2.22-.25-4.55-1.11-4.55-4.92 0-1.11.38-2 1.03-2.71-.1-.25-.45-1.29.1-2.64 0 0 1.005-.315 3.3 1.23.96-.27 1.98-.405 3-.405s2.04.135 3 .405c2.295-1.56 3.3-1.23 3.3-1.23.66 1.65.24 2.88.12 3.18.765.84 1.23 1.905 1.23 3.225 0 4.605-2.805 5.625-5.475 5.925.435.375.81 1.095.81 2.22 0 1.605-.015 2.895-.015 3.3 0 .315.225.69.825.57A12.02 12.02 0 0024 12c0-6.63-5.37-12-12-12z" clip-rule="evenodd"></path>
|
132 |
+
</svg>
|
133 |
+
Contribute on GitHub
|
134 |
+
</a>
|
135 |
+
</div>
|
136 |
+
</section>
|
137 |
+
|
138 |
+
<!-- Footer -->
|
139 |
+
<footer class="bg-gray-800 text-white py-8">
|
140 |
+
<div class="container mx-auto px-4 max-w-6xl">
|
141 |
+
<div class="flex flex-col md:flex-row justify-between items-center">
|
142 |
+
<div class="mb-4 md:mb-0">
|
143 |
+
<p>© {{ current_year }} marimo. All rights reserved.</p>
|
144 |
+
</div>
|
145 |
+
<div class="flex space-x-4">
|
146 |
+
<a href="https://github.com/marimo-team/learn" target="_blank" class="text-gray-300 hover:text-white transition duration-300">
|
147 |
+
<svg class="w-6 h-6" fill="currentColor" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg">
|
148 |
+
<path fill-rule="evenodd" d="M12 0C5.37 0 0 5.37 0 12c0 5.31 3.435 9.795 8.205 11.385.6.105.825-.255.825-.57 0-.285-.015-1.23-.015-2.235-3.015.555-3.795-.735-4.035-1.41-.135-.345-.72-1.41-1.23-1.695-.42-.225-1.02-.78-.015-.795.945-.015 1.62.87 1.845 1.23 1.08 1.815 2.805 1.305 3.495.99.105-.78.42-1.305.765-1.605-2.67-.3-5.46-1.335-5.46-5.925 0-1.305.465-2.385 1.23-3.225-.12-.3-.54-1.53.12-3.18 0 0 1.005-.315 3.3 1.23.96-.27 1.98-.405 3-.405s2.04.135 3 .405c2.295-1.56 3.3-1.23 3.3-1.23.66 1.65.24 2.88.12 3.18.765.84 1.23 1.905 1.23 3.225 0 4.605-2.805 5.625-5.475 5.925.435.375.81 1.095.81 2.22 0 1.605-.015 2.895-.015 3.3 0 .315.225.69.825.57A12.02 12.02 0 0024 12c0-6.63-5.37-12-12-12z" clip-rule="evenodd"></path>
|
149 |
+
</svg>
|
150 |
+
</a>
|
151 |
+
<a href="https://marimo.io" target="_blank" class="text-gray-300 hover:text-white transition duration-300">
|
152 |
+
<svg xmlns="http://www.w3.org/2000/svg" class="h-6 w-6" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
153 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M21 12a9 9 0 01-9 9m9-9a9 9 0 00-9-9m9 9H3m9 9a9 9 0 01-9-9m9 9c1.657 0 3-4.03 3-9s-1.343-9-3-9m0 18c-1.657 0-3-4.03-3-9s1.343-9 3-9m-9 9a9 9 0 019-9" />
|
154 |
+
</svg>
|
155 |
+
</a>
|
156 |
+
</div>
|
157 |
+
</div>
|
158 |
+
</div>
|
159 |
+
</footer>
|
160 |
+
|
161 |
+
<!-- Scripts -->
|
162 |
+
<script>
|
163 |
+
// Smooth scrolling for anchor links
|
164 |
+
document.querySelectorAll('a[href^="#"]').forEach(anchor => {
|
165 |
+
anchor.addEventListener('click', function (e) {
|
166 |
+
e.preventDefault();
|
167 |
+
document.querySelector(this.getAttribute('href')).scrollIntoView({
|
168 |
+
behavior: 'smooth'
|
169 |
+
});
|
170 |
+
});
|
171 |
+
});
|
172 |
+
</script>
|
173 |
+
</body>
|
174 |
+
</html>
|