metaboulie committed
Commit 8177ad2
2 Parent(s): e9e13d8 8bc43c8

Merge remote-tracking branch 'upstream/main' into fp/applicatives

.github/workflows/deploy.yml ADDED
@@ -0,0 +1,56 @@
1
+ name: Deploy to GitHub Pages
2
+
3
+ on:
4
+ push:
5
+ branches: ['main']
6
+ workflow_dispatch:
7
+
8
+ concurrency:
9
+ group: 'pages'
10
+ cancel-in-progress: false
11
+
12
+ env:
13
+ UV_SYSTEM_PYTHON: 1
14
+
15
+ jobs:
16
+ build:
17
+ runs-on: ubuntu-latest
18
+ steps:
19
+ - uses: actions/checkout@v4
20
+
21
+ - name: 🚀 Install uv
22
+ uses: astral-sh/setup-uv@v4
23
+
24
+ - name: 🐍 Set up Python
25
+ uses: actions/setup-python@v5
26
+ with:
27
+ python-version: 3.12
28
+
29
+ - name: 📦 Install dependencies
30
+ run: |
31
+ uv pip install marimo jinja2
32
+
33
+ - name: 🛠️ Export notebooks
34
+ run: |
35
+ python scripts/build.py
36
+
37
+ - name: 📤 Upload artifact
38
+ uses: actions/upload-pages-artifact@v3
39
+ with:
40
+ path: _site
41
+
42
+ deploy:
43
+ needs: build
44
+
45
+ permissions:
46
+ pages: write
47
+ id-token: write
48
+
49
+ environment:
50
+ name: github-pages
51
+ url: ${{ steps.deployment.outputs.page_url }}
52
+ runs-on: ubuntu-latest
53
+ steps:
54
+ - name: 🚀 Deploy to GitHub Pages
55
+ id: deployment
56
+ uses: actions/deploy-pages@v4
.github/workflows/hf_sync.yml CHANGED
@@ -13,7 +13,33 @@ jobs:
13
  - uses: actions/checkout@v4
14
  with:
15
  fetch-depth: 0
16
  - name: Push to hub
17
  env:
18
  HF_TOKEN: ${{ secrets.HF_TOKEN }}
19
- run: git push https://mylessss:$HF_TOKEN@huggingface.co/spaces/marimo-team/marimo-learn main
 
13
  - uses: actions/checkout@v4
14
  with:
15
  fetch-depth: 0
16
+
17
+ - name: Configure Git
18
+ run: |
19
+ git config --global user.name "GitHub Action"
20
+ git config --global user.email "[email protected]"
21
+
22
+ - name: Prepend frontmatter to README
23
+ run: |
24
+ if [ -f README.md ] && ! grep -q "^---" README.md; then
25
+ FRONTMATTER="---
26
+ title: marimo learn
27
+ emoji: 🧠
28
+ colorFrom: blue
29
+ colorTo: indigo
30
+ sdk: docker
31
+ sdk_version: \"latest\"
32
+ app_file: app.py
33
+ pinned: false
34
+ ---
35
+
36
+ "
37
+ echo "$FRONTMATTER$(cat README.md)" > README.md
38
+ git add README.md
39
+ git commit -m "Add HF frontmatter to README" || echo "No changes to commit"
40
+ fi
41
+
42
  - name: Push to hub
43
  env:
44
  HF_TOKEN: ${{ secrets.HF_TOKEN }}
45
+ run: git push -f https://mylessss:$HF_TOKEN@huggingface.co/spaces/marimo-team/marimo-learn main
.gitignore CHANGED
@@ -168,4 +168,7 @@ cython_debug/
168
  #.idea/
169
 
170
  # PyPI configuration file
171
- .pypirc
168
  #.idea/
169
 
170
  # PyPI configuration file
171
+ .pypirc
172
+
173
+ # Generated site content
174
+ _site/
Dockerfile CHANGED
@@ -1,9 +1,22 @@
1
  FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
2
 
3
  COPY _server/main.py _server/main.py
4
  COPY polars/ polars/
5
  COPY duckdb/ duckdb/
6
 
7
  RUN uv venv
8
  RUN uv export --script _server/main.py | uv pip install -r -
9
 
 
1
  FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
2
 
3
+ WORKDIR /app
4
+
5
+ # Create a non-root user
6
+ RUN useradd -m appuser
7
+
8
+ # Copy application files
9
  COPY _server/main.py _server/main.py
10
  COPY polars/ polars/
11
  COPY duckdb/ duckdb/
12
 
13
+ # Set proper ownership
14
+ RUN chown -R appuser:appuser /app
15
+
16
+ # Switch to non-root user
17
+ USER appuser
18
+
19
+ # Create virtual environment and install dependencies
20
  RUN uv venv
21
  RUN uv export --script _server/main.py | uv pip install -r -
22
 
Makefile ADDED
@@ -0,0 +1,9 @@
1
+ install:
2
+ uv pip install marimo jinja2 markdown
3
+
4
+ build:
5
+ rm -rf _site
6
+ uv run scripts/build.py
7
+
8
+ serve:
9
+ uv run python -m http.server --directory _site
README.md CHANGED
@@ -58,6 +58,21 @@ Here's a contribution checklist:
58
  If you aren't comfortable adding a new notebook or course, you can also request
59
  what you'd like to see by [filing an issue](https://github.com/marimo-team/learn/issues/new?template=example_request.yaml).
60
 
61
  ## Community
62
 
63
  We're building a community. Come hang out with us!
 
58
  If you aren't comfortable adding a new notebook or course, you can also request
59
  what you'd like to see by [filing an issue](https://github.com/marimo-team/learn/issues/new?template=example_request.yaml).
60
 
61
+ ## Building and Previewing
62
+
63
+ The site is built using a Python script that exports marimo notebooks to HTML and generates an index page.
64
+
65
+ ```bash
66
+ # Build the site
67
+ python scripts/build.py --output-dir _site
68
+
69
+ # Preview the site (builds first)
70
+ python scripts/preview.py
71
+
72
+ # Preview without rebuilding
73
+ python scripts/preview.py --no-build
74
+ ```
75
+
76
  ## Community
77
 
78
  We're building a community. Come hang out with us!
_server/README.md CHANGED
@@ -5,11 +5,14 @@ This folder contains server code for hosting marimo apps.
5
  ## Running the server
6
 
7
  ```bash
 
8
  uv run --no-project main.py
9
  ```
10
 
11
  ## Building a Docker image
12
 
13
  ```bash
14
  docker build -t marimo-learn .
15
  ```
 
5
  ## Running the server
6
 
7
  ```bash
8
+ cd _server
9
  uv run --no-project main.py
10
  ```
11
 
12
  ## Building a Docker image
13
 
14
+ From the root directory, run:
15
+
16
  ```bash
17
  docker build -t marimo-learn .
18
  ```
polars/01_why_polars.py CHANGED
@@ -57,7 +57,7 @@ def _(mo):
57
  """
58
  Unlike Python's earliest DataFrame library Pandas, Polars was designed with performance and usability in mind — Polars can scale to large datasets with ease while maintaining a simple and intuitive API.
59
 
60
- Polars' performance is due to a number of factors, including its implementation and rust and its ability to perform operations in a parallelized and vectorized manner. It supports a wide range of data types, advanced query optimizations, and seamless integration with other Python libraries, making it a versatile tool for data scientists, engineers, and analysts. Additionally, Polars provides a lazy API for deferred execution, allowing users to optimize their workflows by chaining operations and executing them in a single pass.
61
 
62
  With its focus on speed, scalability, and ease of use, Polars is quickly becoming a go-to choice for data professionals looking to streamline their data processing pipelines and tackle large-scale data challenges.
63
  """
 
57
  """
58
  Unlike Python's earliest DataFrame library Pandas, Polars was designed with performance and usability in mind — Polars can scale to large datasets with ease while maintaining a simple and intuitive API.
59
 
60
+ Polars' performance is due to a number of factors, including its implementation in Rust and its ability to perform operations in a parallelized and vectorized manner. It supports a wide range of data types, advanced query optimizations, and seamless integration with other Python libraries, making it a versatile tool for data scientists, engineers, and analysts. Additionally, Polars provides a lazy API for deferred execution, allowing users to optimize their workflows by chaining operations and executing them in a single pass.
61
 
62
  With its focus on speed, scalability, and ease of use, Polars is quickly becoming a go-to choice for data professionals looking to streamline their data processing pipelines and tackle large-scale data challenges.
63
  """
probability/11_expectation.py ADDED
@@ -0,0 +1,860 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "matplotlib==3.10.0",
6
+ # "numpy==2.2.3",
7
+ # "scipy==1.15.2",
8
+ # ]
9
+ # ///
10
+
11
+ import marimo
12
+
13
+ __generated_with = "0.11.19"
14
+ app = marimo.App(width="medium", app_title="Expectation")
15
+
16
+
17
+ @app.cell(hide_code=True)
18
+ def _(mo):
19
+ mo.md(
20
+ r"""
21
+ # Expectation
22
+
23
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/expectation/), by Stanford professor Chris Piech._
24
+
25
+ A random variable is fully represented by its Probability Mass Function (PMF), which describes each value the random variable can take on and the corresponding probabilities. However, a PMF can contain a lot of information. Sometimes it's useful to summarize a random variable with a single value!
26
+
27
+ The most common, and arguably the most useful, summary of a random variable is its **Expectation** (also called the expected value or mean).
28
+ """
29
+ )
30
+ return
31
+
32
+
33
+ @app.cell(hide_code=True)
34
+ def _(mo):
35
+ mo.md(
36
+ r"""
37
+ ## Definition of Expectation
38
+
39
+ The expectation of a random variable $X$, written $E[X]$, is the average of all the values the random variable can take on, each weighted by the probability that the random variable will take on that value.
40
+
41
+ $$E[X] = \sum_x x \cdot P(X=x)$$
42
+
43
+ Expectation goes by many other names: Mean, Weighted Average, Center of Mass, 1st Moment. All of these are calculated using the same formula.
44
+ """
45
+ )
46
+ return
47
+
48
+
49
+ @app.cell(hide_code=True)
50
+ def _(mo):
51
+ mo.md(
52
+ r"""
53
+ ## Intuition Behind Expectation
54
+
55
+ The expected value represents the long-run average value of a random variable over many independent repetitions of an experiment.
56
+
57
+ For example, if you roll a fair six-sided die many times and calculate the average of all rolls, that average will approach the expected value of 3.5 as the number of rolls increases.
58
+
59
+ Let's visualize this concept:
60
+ """
61
+ )
62
+ return
63
+
64
+
65
+ @app.cell(hide_code=True)
66
+ def _(np, plt):
67
+ # Set random seed for reproducibility
68
+ np.random.seed(42)
69
+
70
+ # Simulate rolling a die many times
71
+ exp_num_rolls = 1000
72
+ exp_die_rolls = np.random.randint(1, 7, size=exp_num_rolls)
73
+
74
+ # Calculate the running average
75
+ exp_running_avg = np.cumsum(exp_die_rolls) / np.arange(1, exp_num_rolls + 1)
76
+
77
+ # Create the plot
78
+ plt.figure(figsize=(10, 5))
79
+ plt.plot(range(1, exp_num_rolls + 1), exp_running_avg, label='Running Average')
80
+ plt.axhline(y=3.5, color='r', linestyle='--', label='Expected Value (3.5)')
81
+ plt.xlabel('Number of Rolls')
82
+ plt.ylabel('Average Value')
83
+ plt.title('Running Average of Die Rolls Approaching Expected Value')
84
+ plt.legend()
85
+ plt.grid(alpha=0.3)
86
+ plt.xscale('log') # Log scale to better see convergence
87
+
88
+ # Add annotations
89
+ plt.annotate('As the number of rolls increases,\nthe average approaches the expected value',
90
+ xy=(exp_num_rolls, exp_running_avg[-1]), xytext=(exp_num_rolls/3, 4),
91
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1.5))
92
+
93
+ plt.gca()
94
+ return exp_die_rolls, exp_num_rolls, exp_running_avg
95
+
96
+
97
+ @app.cell(hide_code=True)
98
+ def _(mo):
99
+ mo.md(r"""## Properties of Expectation""")
100
+ return
101
+
102
+
103
+ @app.cell(hide_code=True)
104
+ def _(mo):
105
+ mo.accordion(
106
+ {
107
+ "1. Linearity of Expectation": mo.md(
108
+ r"""
109
+ $$E[aX + b] = a \cdot E[X] + b$$
110
+
111
+ Where $a$ and $b$ are constants (not random variables).
112
+
113
+ This means that if you multiply a random variable by a constant, the expectation is multiplied by that constant. And if you add a constant to a random variable, the expectation increases by that constant.
114
+ """
115
+ ),
116
+ "2. Expectation of the Sum of Random Variables": mo.md(
117
+ r"""
118
+ $$E[X + Y] = E[X] + E[Y]$$
119
+
120
+ This is true regardless of the relationship between $X$ and $Y$. They can be dependent, and they can have different distributions. This also applies with more than two random variables:
121
+
122
+ $$E\left[\sum_{i=1}^n X_i\right] = \sum_{i=1}^n E[X_i]$$
123
+ """
124
+ ),
125
+ "3. Law of the Unconscious Statistician (LOTUS)": mo.md(
126
+ r"""
127
+ $$E[g(X)] = \sum_x g(x) \cdot P(X=x)$$
128
+
129
+ This allows us to calculate the expected value of a function $g(X)$ of a random variable $X$ when we know the probability distribution of $X$ but don't explicitly know the distribution of $g(X)$.
130
+
131
+ This theorem has the humorous name "Law of the Unconscious Statistician" (LOTUS) because it's so useful that you should be able to employ it unconsciously.
132
+ """
133
+ ),
134
+ "4. Expectation of a Constant": mo.md(
135
+ r"""
136
+ $$E[a] = a$$
137
+
138
+ Sometimes in proofs, you'll end up with the expectation of a constant (rather than a random variable). Since a constant doesn't change, its expected value is just the constant itself.
139
+ """
140
+ ),
141
+ }
142
+ )
143
+ return
144
+
145
+
146
+ @app.cell(hide_code=True)
147
+ def _(mo):
148
+ mo.md(
149
+ r"""
150
+ ## Calculating Expectation
151
+
152
+ Let's calculate the expected value for some common examples:
153
+
154
+ ### Example 1: Fair Die Roll
155
+
156
+ For a fair six-sided die, the PMF is:
157
+
158
+ $$P(X=x) = \frac{1}{6} \text{ for } x \in \{1, 2, 3, 4, 5, 6\}$$
159
+
160
+ The expected value is:
161
+
162
+ $$E[X] = 1 \cdot \frac{1}{6} + 2 \cdot \frac{1}{6} + 3 \cdot \frac{1}{6} + 4 \cdot \frac{1}{6} + 5 \cdot \frac{1}{6} + 6 \cdot \frac{1}{6} = \frac{21}{6} = 3.5$$
163
+
164
+ Let's implement this calculation in Python:
165
+ """
166
+ )
167
+ return
168
+
169
+
170
+ @app.cell
171
+ def _():
172
+ def calc_expectation_die():
173
+ """Calculate the expected value of a fair six-sided die roll."""
174
+ exp_die_values = range(1, 7)
175
+ exp_die_probs = [1/6] * 6
176
+
177
+ exp_die_expected = sum(x * p for x, p in zip(exp_die_values, exp_die_probs))
178
+ return exp_die_expected
179
+
180
+ exp_die_result = calc_expectation_die()
181
+ print(f"Expected value of a fair die roll: {exp_die_result}")
182
+ return calc_expectation_die, exp_die_result
183
+
184
+
185
+ @app.cell(hide_code=True)
186
+ def _(mo):
187
+ mo.md(
188
+ r"""
189
+ ### Example 2: Sum of Two Dice
190
+
191
+ Now let's calculate the expected value for the sum of two fair dice. First, we need the PMF:
192
+ """
193
+ )
194
+ return
195
+
196
+
197
+ @app.cell
198
+ def _():
199
+ def pmf_sum_two_dice(y_val):
200
+ """Returns the probability that the sum of two dice is y."""
201
+ # Count the number of ways to get sum y
202
+ exp_count = 0
203
+ for dice1 in range(1, 7):
204
+ for dice2 in range(1, 7):
205
+ if dice1 + dice2 == y_val:
206
+ exp_count += 1
207
+ return exp_count / 36 # There are 36 possible outcomes (6×6)
208
+
209
+ # Test the function for a few values
210
+ exp_test_values = [2, 7, 12]
211
+ for exp_test_y in exp_test_values:
212
+ print(f"P(Y = {exp_test_y}) = {pmf_sum_two_dice(exp_test_y)}")
213
+ return exp_test_values, exp_test_y, pmf_sum_two_dice
214
+
215
+
216
+ @app.cell
217
+ def _(pmf_sum_two_dice):
218
+ def calc_expectation_sum_two_dice():
219
+ """Calculate the expected value of the sum of two dice."""
220
+ exp_sum_two_dice = 0
221
+ # Sum of dice can take on the values 2 through 12
222
+ for exp_x in range(2, 13):
223
+ exp_pr_x = pmf_sum_two_dice(exp_x) # PMF gives P(sum is x)
224
+ exp_sum_two_dice += exp_x * exp_pr_x
225
+ return exp_sum_two_dice
226
+
227
+ exp_sum_result = calc_expectation_sum_two_dice()
228
+
229
+ # Round to 2 decimal places for display
230
+ exp_sum_result_rounded = round(exp_sum_result, 2)
231
+
232
+ print(f"Expected value of the sum of two dice: {exp_sum_result_rounded}")
233
+
234
+ # Let's also verify this with a direct calculation
235
+ exp_direct_calc = sum(x * pmf_sum_two_dice(x) for x in range(2, 13))
236
+ exp_direct_calc_rounded = round(exp_direct_calc, 2)
237
+
238
+ print(f"Direct calculation: {exp_direct_calc_rounded}")
239
+
240
+ # Verify that this equals 7
241
+ print(f"Is the expected value exactly 7? {abs(exp_sum_result - 7) < 1e-10}")
242
+ return (
243
+ calc_expectation_sum_two_dice,
244
+ exp_direct_calc,
245
+ exp_direct_calc_rounded,
246
+ exp_sum_result,
247
+ exp_sum_result_rounded,
248
+ )
249
+
250
+
251
+ @app.cell(hide_code=True)
252
+ def _(mo):
253
+ mo.md(
254
+ r"""
255
+ ### Visualizing Expectation
256
+
257
+ Let's visualize the expectation for the sum of two dice. The expected value is the "center of mass" of the PMF:
258
+ """
259
+ )
260
+ return
261
+
262
+
263
+ @app.cell(hide_code=True)
264
+ def _(plt, pmf_sum_two_dice):
265
+ # Create the visualization
266
+ exp_y_values = list(range(2, 13))
267
+ exp_probabilities = [pmf_sum_two_dice(y) for y in exp_y_values]
268
+
269
+ dice_fig, dice_ax = plt.subplots(figsize=(10, 5))
270
+ dice_ax.bar(exp_y_values, exp_probabilities, width=0.4)
271
+ dice_ax.axvline(x=7, color='r', linestyle='--', linewidth=2, label='Expected Value (7)')
272
+
273
+ dice_ax.set_xticks(exp_y_values)
274
+ dice_ax.set_xlabel('Sum of two dice (y)')
275
+ dice_ax.set_ylabel('Probability: P(Y = y)')
276
+ dice_ax.set_title('PMF of Sum of Two Dice with Expected Value')
277
+ dice_ax.grid(alpha=0.3)
278
+ dice_ax.legend()
279
+
280
+ # Add probability values on top of bars
281
+ for exp_i, exp_prob in enumerate(exp_probabilities):
282
+ dice_ax.text(exp_y_values[exp_i], exp_prob + 0.001, f'{exp_prob:.3f}', ha='center')
283
+
284
+ plt.tight_layout()
285
+ plt.gca()
286
+ return dice_ax, dice_fig, exp_i, exp_prob, exp_probabilities, exp_y_values
287
+
288
+
289
+ @app.cell(hide_code=True)
290
+ def _(mo):
291
+ mo.md(
292
+ r"""
293
+ ## Demonstrating the Properties of Expectation
294
+
295
+ Let's demonstrate some of these properties with examples:
296
+ """
297
+ )
298
+ return
299
+
300
+
301
+ @app.cell
302
+ def _(exp_die_result):
303
+ # Demonstrate linearity of expectation (1)
304
+ # E[aX + b] = a*E[X] + b
305
+
306
+ # For a die roll X with E[X] = 3.5
307
+ prop_a = 2
308
+ prop_b = 10
309
+
310
+ # Calculate E[2X + 10] using the property
311
+ prop_expected_using_property = prop_a * exp_die_result + prop_b
312
+ prop_expected_using_property_rounded = round(prop_expected_using_property, 2)
313
+
314
+ print(f"Using linearity property: E[{prop_a}X + {prop_b}] = {prop_a} * E[X] + {prop_b} = {prop_expected_using_property_rounded}")
315
+
316
+ # Calculate E[2X + 10] directly
317
+ prop_expected_direct = sum((prop_a * x + prop_b) * (1/6) for x in range(1, 7))
318
+ prop_expected_direct_rounded = round(prop_expected_direct, 2)
319
+
320
+ print(f"Direct calculation: E[{prop_a}X + {prop_b}] = {prop_expected_direct_rounded}")
321
+
322
+ # Verify they match
323
+ print(f"Do they match? {abs(prop_expected_using_property - prop_expected_direct) < 1e-10}")
324
+ return (
325
+ prop_a,
326
+ prop_b,
327
+ prop_expected_direct,
328
+ prop_expected_direct_rounded,
329
+ prop_expected_using_property,
330
+ prop_expected_using_property_rounded,
331
+ )
332
+
333
+
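The cell above demonstrates linearity (property 1); property 2, the expectation of a sum, is stated in the accordion but not demonstrated in this commit. A minimal standalone sketch (not part of the diff, standard library only) checking that $E[X+Y] = E[X] + E[Y]$ holds even when $Y$ is completely dependent on $X$:

```python
# Sketch: linearity of expectation holds even for dependent random variables.
# Here Y = 7 - X, so X and Y are perfectly (negatively) dependent.
from fractions import Fraction

die = range(1, 7)
p = Fraction(1, 6)

e_x = sum(x * p for x in die)                 # E[X]     = 7/2
e_y = sum((7 - x) * p for x in die)           # E[Y]     = 7/2
e_sum = sum((x + (7 - x)) * p for x in die)   # E[X + Y] = 7

assert e_sum == e_x + e_y
print(e_x, e_y, e_sum)  # 7/2 7/2 7
```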
334
+ @app.cell(hide_code=True)
335
+ def _(mo):
336
+ mo.md(
337
+ r"""
338
+ ### Law of the Unconscious Statistician (LOTUS)
339
+
340
+ Let's use LOTUS to calculate $E[X^2]$ for a die roll, which will be useful when we study variance:
341
+ """
342
+ )
343
+ return
344
+
345
+
346
+ @app.cell
347
+ def _():
348
+ # Calculate E[X^2] for a die roll using LOTUS (3)
349
+ lotus_die_values = range(1, 7)
350
+ lotus_die_probs = [1/6] * 6
351
+
352
+ # Using LOTUS: E[X^2] = sum(x^2 * P(X=x))
353
+ lotus_expected_x_squared = sum(x**2 * p for x, p in zip(lotus_die_values, lotus_die_probs))
354
+ lotus_expected_x_squared_rounded = round(lotus_expected_x_squared, 2)
355
+
356
+ expected_x_squared = 3.5**2
357
+ expected_x_squared_rounded = round(expected_x_squared, 2)
358
+
359
+ print(f"E[X^2] for a die roll = {lotus_expected_x_squared_rounded}")
360
+ print(f"(E[X])^2 for a die roll = {expected_x_squared_rounded}")
361
+ return (
362
+ expected_x_squared,
363
+ expected_x_squared_rounded,
364
+ lotus_die_probs,
365
+ lotus_die_values,
366
+ lotus_expected_x_squared,
367
+ lotus_expected_x_squared_rounded,
368
+ )
369
+
370
+
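As a cross-check on the LOTUS computation above (not part of the committed notebook), the distribution of $X^2$ can be built explicitly and its ordinary expectation compared with the LOTUS result; both routes give $91/6 \approx 15.17$ for a fair die:

```python
# Sketch: verify LOTUS by constructing the distribution of g(X) = X^2 directly.
from collections import defaultdict

pmf_x = {x: 1 / 6 for x in range(1, 7)}

# Route 1 (LOTUS): E[g(X)] = sum over x of g(x) * P(X = x)
lotus = sum(x**2 * p for x, p in pmf_x.items())

# Route 2: build the PMF of Y = X^2, then take the ordinary expectation of Y
pmf_y = defaultdict(float)
for x, p in pmf_x.items():
    pmf_y[x**2] += p
direct = sum(y * p for y, p in pmf_y.items())

assert abs(lotus - direct) < 1e-12
print(lotus, direct)  # both ~15.1667
```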
371
+ @app.cell(hide_code=True)
372
+ def _(mo):
373
+ mo.md(
374
+ r"""
375
+ /// Note
376
+ Note that $E[X^2] \neq (E[X])^2$
377
+ """
378
+ )
379
+ return
380
+
381
+
382
+ @app.cell(hide_code=True)
383
+ def _(mo):
384
+ mo.md(
385
+ r"""
386
+ ## Interactive Example
387
+
388
+ Let's explore how the expected value changes as we adjust the parameters of common probability distributions. This interactive visualization focuses specifically on the relationship between distribution parameters and expected values.
389
+
390
+ Use the controls below to select a distribution and adjust its parameters. The graph will show how the expected value changes across a range of parameter values.
391
+ """
392
+ )
393
+ return
394
+
395
+
396
+ @app.cell(hide_code=True)
397
+ def _(mo):
398
+ # Create UI elements for distribution selection
399
+ dist_selection = mo.ui.dropdown(
400
+ options=[
401
+ "bernoulli",
402
+ "binomial",
403
+ "geometric",
404
+ "poisson"
405
+ ],
406
+ value="bernoulli",
407
+ label="Select a distribution"
408
+ )
409
+ return (dist_selection,)
410
+
411
+
412
+ @app.cell(hide_code=True)
413
+ def _(dist_selection):
414
+ dist_selection.center()
415
+ return
416
+
417
+
418
+ @app.cell(hide_code=True)
419
+ def _(dist_description):
420
+ dist_description
421
+ return
422
+
423
+
424
+ @app.cell(hide_code=True)
425
+ def _(mo):
426
+ mo.md("""### Adjust Parameters""")
427
+ return
428
+
429
+
430
+ @app.cell(hide_code=True)
431
+ def _(controls):
432
+ controls
433
+ return
434
+
435
+
436
+ @app.cell(hide_code=True)
437
+ def _(
438
+ dist_selection,
439
+ lambda_range,
440
+ np,
441
+ param_lambda,
442
+ param_n,
443
+ param_p,
444
+ param_range,
445
+ plt,
446
+ ):
447
+ # Calculate expected values based on the selected distribution
448
+ if dist_selection.value == "bernoulli":
449
+ # Get parameter range for visualization
450
+ p_min, p_max = param_range.value
451
+ param_values = np.linspace(p_min, p_max, 100)
452
+
453
+ # E[X] = p for Bernoulli
454
+ expected_values = param_values
455
+ current_param = param_p.value
456
+ current_expected = round(current_param, 2)
457
+ x_label = "p (probability of success)"
458
+ title = "Expected Value of Bernoulli Distribution"
459
+ formula = "E[X] = p"
460
+
461
+ elif dist_selection.value == "binomial":
462
+ p_min, p_max = param_range.value
463
+ param_values = np.linspace(p_min, p_max, 100)
464
+
465
+ # E[X] = np for Binomial
466
+ n = int(param_n.value)
467
+ expected_values = [n * p for p in param_values]
468
+ current_param = param_p.value
469
+ current_expected = round(n * current_param, 2)
470
+ x_label = "p (probability of success)"
471
+ title = f"Expected Value of Binomial Distribution (n={n})"
472
+ formula = f"E[X] = n × p = {n} × p"
473
+
474
+ elif dist_selection.value == "geometric":
475
+ p_min, p_max = param_range.value
476
+ # Ensure p is not 0 for geometric distribution
477
+ p_min = max(0.01, p_min)
478
+ param_values = np.linspace(p_min, p_max, 100)
479
+
480
+ # E[X] = 1/p for Geometric
481
+ expected_values = [1/p for p in param_values]
482
+ current_param = param_p.value
483
+ current_expected = round(1 / current_param, 2)
484
+ x_label = "p (probability of success)"
485
+ title = "Expected Value of Geometric Distribution"
486
+ formula = "E[X] = 1/p"
487
+
488
+ else: # Poisson
489
+ lambda_min, lambda_max = lambda_range.value
490
+ param_values = np.linspace(lambda_min, lambda_max, 100)
491
+
492
+ # E[X] = lambda for Poisson
493
+ expected_values = param_values
494
+ current_param = param_lambda.value
495
+ current_expected = round(current_param, 2)
496
+ x_label = "λ (rate parameter)"
497
+ title = "Expected Value of Poisson Distribution"
498
+ formula = "E[X] = λ"
499
+
500
+ # Create the plot
501
+ dist_fig, dist_ax = plt.subplots(figsize=(10, 6))
502
+
503
+ # Plot the expected value function
504
+ dist_ax.plot(param_values, expected_values, 'b-', linewidth=2, label="Expected Value Function")
505
+
506
+ dist_ax.plot(current_param, current_expected, 'ro', markersize=10, label=f"Current Value: E[X] = {current_expected}")
507
+
508
+ dist_ax.hlines(current_expected, param_values[0], current_param, colors='r', linestyles='dashed')
509
+
510
+ dist_ax.vlines(current_param, 0, current_expected, colors='r', linestyles='dashed')
511
+
512
+ dist_ax.fill_between(param_values, 0, expected_values, alpha=0.2, color='blue')
513
+
514
+ dist_ax.set_xlabel(x_label, fontsize=12)
515
+ dist_ax.set_ylabel("Expected Value: E[X]", fontsize=12)
516
+ dist_ax.set_title(title, fontsize=14, fontweight='bold')
517
+ dist_ax.grid(True, alpha=0.3)
518
+
519
+ # Move legend to lower right to avoid overlap with formula
520
+ dist_ax.legend(loc='lower right', fontsize=10)
521
+
522
+ # Add formula text box in upper left
523
+ dist_props = dict(boxstyle='round', facecolor='white', alpha=0.8)
524
+ dist_ax.text(0.02, 0.95, formula, transform=dist_ax.transAxes, fontsize=12,
525
+ verticalalignment='top', bbox=dist_props)
526
+
527
+ if dist_selection.value == "geometric":
528
+ max_y = min(50, 2/max(0.01, param_values[0]))
529
+ dist_ax.set_ylim(0, max_y)
530
+ elif dist_selection.value == "binomial":
531
+ dist_ax.set_ylim(0, int(param_n.value) + 1)
532
+ else:
533
+ dist_ax.set_ylim(0, max(expected_values) * 1.1)
534
+
535
+ annotation_x = current_param + (param_values[-1] - param_values[0]) * 0.05
536
+ annotation_y = current_expected
537
+
538
+ # Adjust annotation position if it would go off the chart
539
+ if annotation_x > param_values[-1] * 0.9:
540
+ annotation_x = current_param - (param_values[-1] - param_values[0]) * 0.2
541
+
542
+ dist_ax.annotate(
543
+ f"Parameter: {current_param:.2f}\nE[X] = {current_expected}",
544
+ xy=(current_param, current_expected),
545
+ xytext=(annotation_x, annotation_y),
546
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1.5, alpha=0.7),
547
+ bbox=dist_props
548
+ )
549
+
550
+ plt.tight_layout()
551
+ plt.gca()
552
+ return (
553
+ annotation_x,
554
+ annotation_y,
555
+ current_expected,
556
+ current_param,
557
+ dist_ax,
558
+ dist_fig,
559
+ dist_props,
560
+ expected_values,
561
+ formula,
562
+ lambda_max,
563
+ lambda_min,
564
+ max_y,
565
+ n,
566
+ p_max,
567
+ p_min,
568
+ param_values,
569
+ title,
570
+ x_label,
571
+ )
572
+
573
+
574
+ @app.cell(hide_code=True)
575
+ def _(mo):
576
+ mo.md(
577
+ r"""
578
+ ## Expectation vs. Mode
579
+
580
+ The expected value (mean) of a random variable is not always the same as its most likely value (mode). Let's explore this with an example:
581
+ """
582
+ )
583
+ return
584
+
585
+
586
+ @app.cell(hide_code=True)
587
+ def _(np, plt, stats):
588
+ # Create a skewed distribution
589
+ skew_n = 10
590
+ skew_p = 0.25
591
+
592
+ # Binomial PMF
593
+ skew_x_values = np.arange(0, skew_n+1)
594
+ skew_pmf_values = stats.binom.pmf(skew_x_values, skew_n, skew_p)
595
+
596
+ # Find the mode (most likely value)
597
+ skew_mode = skew_x_values[np.argmax(skew_pmf_values)]
598
+
599
+ # Calculate the expected value
600
+ skew_expected = skew_n * skew_p
601
+ skew_expected_rounded = round(skew_expected, 2)
602
+
603
+ skew_fig, skew_ax = plt.subplots(figsize=(10, 5))
604
+ skew_ax.bar(skew_x_values, skew_pmf_values, alpha=0.7, width=0.4)
605
+
606
+ # Add vertical lines for mode and expected value
607
+ skew_ax.axvline(x=skew_mode, color='g', linestyle='--', linewidth=2,
608
+ label=f'Mode = {skew_mode} (Most likely value)')
609
+ skew_ax.axvline(x=skew_expected, color='r', linestyle='--', linewidth=2,
610
+ label=f'Expected Value = {skew_expected_rounded} (Mean)')
611
+
612
+ skew_ax.annotate('Mode', xy=(skew_mode, 0.05), xytext=(skew_mode-2.0, 0.1),
613
+ arrowprops=dict(facecolor='green', shrink=0.05, width=1.5), color='green')
614
+ skew_ax.annotate('Expected Value', xy=(skew_expected, 0.05), xytext=(skew_expected+1, 0.15),
615
+ arrowprops=dict(facecolor='red', shrink=0.05, width=1.5), color='red')
616
+
617
+ if skew_mode != int(skew_expected):
618
+ min_x = min(skew_mode, skew_expected)
619
+ max_x = max(skew_mode, skew_expected)
620
+ skew_ax.axvspan(min_x, max_x, alpha=0.2, color='purple')
621
+
622
+ # Add text explaining the difference
623
+ mid_x = (skew_mode + skew_expected) / 2
624
+ skew_ax.text(mid_x, max(skew_pmf_values) * 0.5,
625
+ f"Difference: {abs(skew_mode - skew_expected_rounded):.2f}",
626
+ ha='center', va='center', bbox=dict(facecolor='white', alpha=0.7))
627
+
628
+ skew_ax.set_xlabel('Number of Successes')
629
+ skew_ax.set_ylabel('Probability')
630
+ skew_ax.set_title(f'Binomial Distribution (n={skew_n}, p={skew_p})')
631
+ skew_ax.grid(alpha=0.3)
632
+ skew_ax.legend()
633
+
634
+ plt.tight_layout()
635
+ plt.gca()
636
+ return (
637
+ max_x,
638
+ mid_x,
639
+ min_x,
640
+ skew_ax,
641
+ skew_expected,
642
+ skew_expected_rounded,
643
+ skew_fig,
644
+ skew_mode,
645
+ skew_n,
646
+ skew_p,
647
+ skew_pmf_values,
648
+ skew_x_values,
649
+ )
650
+
651
+
652
+ @app.cell(hide_code=True)
653
+ def _(mo):
654
+ mo.md(
655
+ r"""
656
+ /// NOTE
657
+ For the sum of two dice we calculated earlier, we found the expected value to be exactly 7. In that case, 7 also happens to be the mode (most likely outcome) of the distribution. However, this is just a coincidence for this particular example!
658
+
659
+ As we can see from the binomial distribution above, the expected value (2.50) and the mode (2) are often different values (this is common in skewed distributions). The expected value represents the "center of mass" of the distribution, while the mode represents the most likely single outcome.
660
+ """
661
+ )
662
+ return
663
+
664
+
665
+ @app.cell(hide_code=True)
666
+ def _(mo):
667
+ mo.md(
668
+ r"""
669
+ ## 🤔 Test Your Understanding
670
+
671
+ Choose what you believe are the correct options in the questions below:
672
+
673
+ <details>
674
+ <summary>The expected value of a random variable is always one of the possible values the random variable can take.</summary>
675
+ ❌ False! The expected value is a weighted average and may not be a value the random variable can actually take. For example, the expected value of a fair die roll is 3.5, which is not a possible outcome.
676
+ </details>
677
+
678
+ <details>
679
+ <summary>If X and Y are independent random variables, then E[X·Y] = E[X]·E[Y].</summary>
680
+ ✅ True! For independent random variables, the expectation of their product equals the product of their expectations.
681
+ </details>
682
+
683
+ <details>
684
+ <summary>The expected value of a constant random variable (one that always takes the same value) is that constant.</summary>
685
+ ✅ True! If X = c with probability 1, then E[X] = c.
686
+ </details>
687
+
688
+ <details>
689
+ <summary>The expected value of the sum of two random variables is always the sum of their expected values, regardless of whether they are independent.</summary>
690
+ ✅ True! This is the linearity of expectation property: E[X + Y] = E[X] + E[Y], which holds regardless of dependence.
691
+ </details>
692
+ """
693
+ )
694
+ return
695
+
696
+
697
+ @app.cell(hide_code=True)
698
+ def _(mo):
699
+ mo.md(
700
+ r"""
701
+ ## Practical Applications of Expectation
702
+
703
+ Expected values show up everywhere - from investment decisions and insurance pricing to machine learning algorithms and game design. Engineers use them to predict system reliability, data scientists to understand customer behavior, and economists to model market outcomes. They're essential for risk assessment in project management and for optimizing resource allocation in operations research.
704
+ """
705
+ )
706
+ return
707
+
708
+
709
+ @app.cell(hide_code=True)
710
+ def _(mo):
711
+ mo.md(
712
+ r"""
713
+ ## Key Takeaways
714
+
715
+ Expectation gives us a single value that summarizes a random variable's central tendency - it's the weighted average of all possible outcomes, where the weights are probabilities. The linearity property makes expectations easy to work with, even for complex combinations of random variables. While a PMF gives the complete probability picture, expectation provides an essential summary that helps us make decisions under uncertainty. In our next notebook, we'll explore variance, which measures how spread out a random variable's values are around its expectation.
716
+ """
717
+ )
718
+ return
719
+
720
+
721
+ @app.cell(hide_code=True)
722
+ def _(mo):
723
+ mo.md(r"""#### Appendix (containing helper code)""")
724
+ return
725
+
726
+
727
+ @app.cell(hide_code=True)
728
+ def _():
729
+ import marimo as mo
730
+ return (mo,)
731
+
732
+
733
+ @app.cell(hide_code=True)
734
+ def _():
735
+ import matplotlib.pyplot as plt
736
+ import numpy as np
737
+ from scipy import stats
738
+ import collections
739
+ return collections, np, plt, stats
740
+
741
+
742
+ @app.cell(hide_code=True)
743
+ def _(dist_selection, mo):
744
+ # Parameter controls for probability-based distributions
745
+ param_p = mo.ui.slider(
746
+ start=0.01,
747
+ stop=0.99,
748
+ step=0.01,
749
+ value=0.5,
750
+ label="p (probability of success)",
751
+ full_width=True
752
+ )
753
+
754
+ # Parameter control for binomial distribution
755
+ param_n = mo.ui.slider(
756
+ start=1,
757
+ stop=50,
758
+ step=1,
759
+ value=10,
760
+ label="n (number of trials)",
761
+ full_width=True
762
+ )
763
+
764
+ # Parameter control for Poisson distribution
765
+ param_lambda = mo.ui.slider(
766
+ start=0.1,
767
+ stop=20,
768
+ step=0.1,
769
+ value=5,
770
+ label="λ (rate parameter)",
771
+ full_width=True
772
+ )
773
+
774
+ # Parameter range sliders for visualization
775
+ param_range = mo.ui.range_slider(
776
+ start=0,
777
+ stop=1,
778
+ step=0.01,
779
+ value=[0, 1],
780
+ label="Parameter range to visualize",
781
+ full_width=True
782
+ )
783
+
784
+ lambda_range = mo.ui.range_slider(
785
+ start=0,
786
+ stop=20,
787
+ step=0.1,
788
+ value=[0, 20],
789
+ label="λ range to visualize",
790
+ full_width=True
791
+ )
792
+
793
+ # Display appropriate controls based on the selected distribution
794
+ if dist_selection.value == "bernoulli":
795
+ controls = mo.hstack([param_p, param_range], justify="space-around")
796
+ elif dist_selection.value == "binomial":
797
+ controls = mo.hstack([param_p, param_n, param_range], justify="space-around")
798
+ elif dist_selection.value == "geometric":
799
+ controls = mo.hstack([param_p, param_range], justify="space-around")
800
+ else: # poisson
801
+ controls = mo.hstack([param_lambda, lambda_range], justify="space-around")
802
+ return controls, lambda_range, param_lambda, param_n, param_p, param_range
803
+
804
+
805
+ @app.cell(hide_code=True)
806
+ def _(dist_selection, mo):
807
+ # Create distribution descriptions based on selection
808
+ if dist_selection.value == "bernoulli":
809
+ dist_description = mo.md(
810
+ r"""
811
+ **Bernoulli Distribution**
812
+
813
+ A Bernoulli distribution models a single trial with two possible outcomes: success (1) or failure (0).
814
+
815
+ - Parameter: $p$ = probability of success
816
+ - Expected Value: $E[X] = p$
817
+ - Example: Flipping a coin once (p = 0.5 for a fair coin)
818
+ """
819
+ )
820
+ elif dist_selection.value == "binomial":
821
+ dist_description = mo.md(
822
+ r"""
823
+ **Binomial Distribution**
824
+
825
+ A Binomial distribution models the number of successes in $n$ independent trials.
826
+
827
+ - Parameters: $n$ = number of trials, $p$ = probability of success
828
+ - Expected Value: $E[X] = np$
829
+ - Example: Number of heads in 10 coin flips
830
+ """
831
+ )
832
+ elif dist_selection.value == "geometric":
833
+ dist_description = mo.md(
834
+ r"""
835
+ **Geometric Distribution**
836
+
837
+ A Geometric distribution models the number of trials until the first success.
838
+
839
+ - Parameter: $p$ = probability of success
840
+ - Expected Value: $E[X] = \frac{1}{p}$
841
+ - Example: Number of coin flips until first heads
842
+ """
843
+ )
844
+ else: # poisson
845
+ dist_description = mo.md(
846
+ r"""
847
+ **Poisson Distribution**
848
+
849
+ A Poisson distribution models the number of events occurring in a fixed interval.
850
+
851
+ - Parameter: $\lambda$ = average rate of events
852
+ - Expected Value: $E[X] = \lambda$
853
+ - Example: Number of emails received per hour
854
+ """
855
+ )
856
+ return (dist_description,)
857
+
858
+
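The closed-form expected values quoted in the distribution descriptions above can be sanity-checked against `scipy.stats`, which this notebook already depends on. A small sketch (not part of the commit) with arbitrarily chosen parameter values:

```python
# Sketch: check the closed-form expected values above against scipy.stats.
from scipy import stats

p, n, lam = 0.3, 10, 5.0

assert abs(stats.bernoulli.mean(p) - p) < 1e-9      # Bernoulli: E[X] = p
assert abs(stats.binom.mean(n, p) - n * p) < 1e-9   # Binomial:  E[X] = n*p
assert abs(stats.geom.mean(p) - 1 / p) < 1e-9       # Geometric: E[X] = 1/p
assert abs(stats.poisson.mean(lam) - lam) < 1e-9    # Poisson:   E[X] = lambda
print("closed-form means match scipy.stats")
```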
859
+ if __name__ == "__main__":
860
+ app.run()
probability/12_variance.py ADDED
@@ -0,0 +1,631 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "matplotlib==3.10.0",
6
+ # "numpy==2.2.3",
7
+ # "scipy==1.15.2",
8
+ # "wigglystuff==0.1.10",
9
+ # ]
10
+ # ///
11
+
12
+ import marimo
13
+
14
+ __generated_with = "0.11.20"
15
+ app = marimo.App(width="medium", app_title="Variance")
16
+
17
+
18
+ @app.cell(hide_code=True)
19
+ def _(mo):
20
+ mo.md(
21
+ r"""
22
+ # Variance
23
+
24
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/variance/), by Stanford professor Chris Piech._
25
+
26
+ In our previous exploration of random variables, we learned about expectation - a measure of central tendency. However, knowing the average value alone doesn't tell us everything about a distribution. Consider these questions:
27
+
28
+ - How spread out are the values around the mean?
29
+ - How reliable is the expectation as a predictor of individual outcomes?
30
+ - How much do individual samples typically deviate from the average?
31
+
32
+ This is where **variance** comes in - it measures the spread or dispersion of a random variable around its expected value.
33
+ """
34
+ )
35
+ return
36
+
37
+
38
+ @app.cell(hide_code=True)
39
+ def _(mo):
40
+ mo.md(
41
+ r"""
42
+ ## Definition of Variance
43
+
44
+ The variance of a random variable $X$ with expected value $\mu = E[X]$ is defined as:
45
+
46
+ $$\text{Var}(X) = E[(X-\mu)^2]$$
47
+
48
+ This definition captures the average squared deviation from the mean. There's also an equivalent, often more convenient formula:
49
+
50
+ $$\text{Var}(X) = E[X^2] - (E[X])^2$$
51
+
52
+ /// tip
53
+ The second formula is usually easier to compute, as it only requires calculating $E[X^2]$ and $E[X]$, rather than working with deviations from the mean.
54
+ """
55
+ )
56
+ return
57
+
58
+
59
+ @app.cell(hide_code=True)
60
+ def _(mo):
61
+ mo.md(
62
+ r"""
63
+ ## Intuition Through Example
64
+
65
+ Let's look at a real-world example that illustrates why variance is important. Consider three different groups of graders evaluating assignments in a massive online course. Each grader has their own "grading distribution" - their pattern of assigning scores to work that deserves a 70/100.
66
+
67
+ The visualization below shows the probability distributions for three types of graders. Try clicking and dragging the blue numbers to adjust the parameters and see how they affect the variance.
68
+ """
69
+ )
70
+ return
71
+
72
+
73
+ @app.cell(hide_code=True)
74
+ def _(mo):
75
+ mo.md(
76
+ r"""
77
+ /// TIP
78
+ Try adjusting the blue numbers below to see how:
79
+
80
+ - Increasing spread increases variance
81
+ - The mixture ratio affects how many outliers appear in Grader C's distribution
82
+ - Changing the true grade shifts all distributions but maintains their relative variances
83
+ """
84
+ )
85
+ return
86
+
87
+
88
+ @app.cell(hide_code=True)
89
+ def _(controls):
90
+ controls
91
+ return
92
+
93
+
94
+ @app.cell(hide_code=True)
95
+ def _(
96
+ grader_a_spread,
97
+ grader_b_spread,
98
+ grader_c_mix,
99
+ np,
100
+ plt,
101
+ stats,
102
+ true_grade,
103
+ ):
104
+ # Create data for three grader distributions
105
+ _grader_x = np.linspace(40, 100, 200)
106
+
107
+ # Calculate actual variances
108
+ var_a = grader_a_spread.amount**2
109
+ var_b = grader_b_spread.amount**2
110
+ var_c = (1-grader_c_mix.amount) * 3**2 + \
111
+ grader_c_mix.amount * 8**2 # mixture of equal-mean normals: weighted average of the component variances
112
+
113
+ # Grader A: Wide spread around true grade
114
+ grader_a = stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=grader_a_spread.amount)
115
+
116
+ # Grader B: Narrow spread around true grade
117
+ grader_b = stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=grader_b_spread.amount)
118
+
119
+ # Grader C: Mixture of distributions
120
+ grader_c = (1-grader_c_mix.amount) * stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=3) + \
121
+ grader_c_mix.amount * stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=8)
122
+
123
+ grader_fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
124
+
125
+ # Plot each distribution
126
+ ax1.fill_between(_grader_x, grader_a, alpha=0.3, color='green', label=f'Var ≈ {var_a:.2f}')
127
+ ax1.axvline(x=true_grade.amount, color='black', linestyle='--', label='True Grade')
128
+ ax1.set_title('Grader A: High Variance')
129
+ ax1.set_xlabel('Grade')
130
+ ax1.set_ylabel('Pr(G = g)')
131
+ ax1.set_ylim(0, max(grader_a)*1.1)
132
+
133
+ ax2.fill_between(_grader_x, grader_b, alpha=0.3, color='blue', label=f'Var ≈ {var_b:.2f}')
134
+ ax2.axvline(x=true_grade.amount, color='black', linestyle='--')
135
+ ax2.set_title('Grader B: Low Variance')
136
+ ax2.set_xlabel('Grade')
137
+ ax2.set_ylim(0, max(grader_b)*1.1)
138
+
139
+ ax3.fill_between(_grader_x, grader_c, alpha=0.3, color='purple', label=f'Var ≈ {var_c:.2f}')
140
+ ax3.axvline(x=true_grade.amount, color='black', linestyle='--')
141
+ ax3.set_title('Grader C: Mixed Distribution')
142
+ ax3.set_xlabel('Grade')
143
+ ax3.set_ylim(0, max(grader_c)*1.1)
144
+
145
+ # Add annotations to explain what's happening
146
+ ax1.annotate('Wide spread = high variance',
147
+ xy=(true_grade.amount, max(grader_a)*0.5),
148
+ xytext=(true_grade.amount-15, max(grader_a)*0.7),
149
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
150
+
151
+ ax2.annotate('Narrow spread = low variance',
152
+ xy=(true_grade.amount, max(grader_b)*0.5),
153
+ xytext=(true_grade.amount+8, max(grader_b)*0.7),
154
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
155
+
156
+ ax3.annotate('Mixture creates outliers',
157
+ xy=(true_grade.amount+15, grader_c[np.where(_grader_x >= true_grade.amount+15)[0][0]]),
158
+ xytext=(true_grade.amount+5, max(grader_c)*0.7),
159
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
160
+
161
+ # Add legends and adjust layout
162
+ for _ax in [ax1, ax2, ax3]:
163
+ _ax.legend()
164
+ _ax.grid(alpha=0.2)
165
+
166
+ plt.tight_layout()
167
+ plt.gca()
168
+ return (
169
+ ax1,
170
+ ax2,
171
+ ax3,
172
+ grader_a,
173
+ grader_b,
174
+ grader_c,
175
+ grader_fig,
176
+ var_a,
177
+ var_b,
178
+ var_c,
179
+ )
180
+
181
+
182
+ @app.cell(hide_code=True)
183
+ def _(mo):
184
+ mo.md(
185
+ r"""
186
+ /// note
187
+ All three distributions have the same expected value (the true grade), but they differ significantly in their spread:
188
+
189
+ - **Grader A** has high variance - grades vary widely from the true value
190
+ - **Grader B** has low variance - grades consistently stay close to the true value
191
+ - **Grader C** has a mixture distribution - mostly consistent but with occasional extreme values
192
+
193
+ This illustrates why variance is crucial: two distributions can have the same mean but behave very differently in practice.
194
+ """
195
+ )
196
+ return
197
+
198
+
199
+ @app.cell(hide_code=True)
200
+ def _(mo):
201
+ mo.md(
202
+ r"""
203
+ ## Computing Variance
204
+
205
+ Let's work through some concrete examples to understand how to calculate variance.
206
+
207
+ ### Example 1: Fair Die Roll
208
+
209
+ Consider rolling a fair six-sided die. We'll calculate its variance step by step:
210
+ """
211
+ )
212
+ return
213
+
214
+
215
+ @app.cell
216
+ def _(np):
217
+ # Define the die values and probabilities
218
+ die_values = np.array([1, 2, 3, 4, 5, 6])
219
+ die_probs = np.array([1/6] * 6)
220
+
221
+ # Calculate E[X]
222
+ expected_value = np.sum(die_values * die_probs)
223
+
224
+ # Calculate E[X^2]
225
+ expected_square = np.sum(die_values**2 * die_probs)
226
+
227
+ # Calculate Var(X) = E[X^2] - (E[X])^2
228
+ variance = expected_square - expected_value**2
229
+
230
+ # Calculate standard deviation
231
+ std_dev = np.sqrt(variance)
232
+
233
+ print(f"E[X] = {expected_value:.2f}")
234
+ print(f"E[X^2] = {expected_square:.2f}")
235
+ print(f"Var(X) = {variance:.2f}")
236
+ print(f"Standard Deviation = {std_dev:.2f}")
237
+ return (
238
+ die_probs,
239
+ die_values,
240
+ expected_square,
241
+ expected_value,
242
+ std_dev,
243
+ variance,
244
+ )
245
+
246
+
247
+ @app.cell(hide_code=True)
248
+ def _(mo):
249
+ mo.md(
250
+ r"""
251
+ /// NOTE
252
+ For a fair die:
253
+
254
+ - The expected value (3.50) tells us the average roll
255
+ - The variance (2.92) tells us how much typical rolls deviate from this average
256
+ - The standard deviation (1.71) gives us this spread in the original units
257
+ """
258
+ )
259
+ return
260
+
261
+
262
+ @app.cell(hide_code=True)
263
+ def _(mo):
264
+ mo.md(
265
+ r"""
266
+ ## Properties of Variance
267
+
268
+ Variance has several important properties that make it useful for analyzing random variables:
269
+
270
+ 1. **Non-negativity**: $\text{Var}(X) \geq 0$ for any random variable $X$
271
+ 2. **Variance of a constant**: $\text{Var}(c) = 0$ for any constant $c$
272
+ 3. **Scaling**: $\text{Var}(aX) = a^2\text{Var}(X)$ for any constant $a$
273
+ 4. **Translation**: $\text{Var}(X + b) = \text{Var}(X)$ for any constant $b$
274
+ 5. **Independence**: If $X$ and $Y$ are independent, then $\text{Var}(X + Y) = \text{Var}(X) + \text{Var}(Y)$
275
+
276
+ Let's verify a property with an example.
277
+ """
278
+ )
279
+ return
280
+
281
+
282
+ @app.cell(hide_code=True)
283
+ def _(mo):
284
+ mo.md(
285
+ r"""
286
+ ## Proof of Variance Formula
287
+
288
+ The equivalence of the two variance formulas is a fundamental result in probability theory. Here's the proof:
289
+
290
+ Starting with the definition $\text{Var}(X) = E[(X-\mu)^2]$ where $\mu = E[X]$:
291
+
292
+ \begin{align}
293
+ \text{Var}(X) &= E[(X-\mu)^2] \\
294
+ &= \sum_x(x-\mu)^2P(x) && \text{Definition of Expectation}\\
295
+ &= \sum_x (x^2 -2\mu x + \mu^2)P(x) && \text{Expanding the square}\\
296
+ &= \sum_x x^2P(x)- 2\mu \sum_x xP(x) + \mu^2 \sum_x P(x) && \text{Distributing the sum}\\
297
+ &= E[X^2]- 2\mu E[X] + \mu^2 && \text{Definition of expectation}\\
298
+ &= E[X^2]- 2(E[X])^2 + (E[X])^2 && \text{Since }\mu = E[X]\\
299
+ &= E[X^2]- (E[X])^2 && \text{Simplifying}
300
+ \end{align}
301
+
302
+ /// tip
303
+ This proof shows why the formula $\text{Var}(X) = E[X^2] - (E[X])^2$ is so useful - it's much easier to compute $E[X^2]$ and $E[X]$ separately than to work with deviations directly.
304
+ """
305
+ )
306
+ return
307
+
308
+
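A numeric companion to the proof above (not part of the committed notebook): computing $\text{Var}(X)$ for a fair die from the definition and from the shortcut formula gives the same value, $35/12 \approx 2.92$:

```python
# Sketch: the two variance formulas from the proof agree for a fair die.
import numpy as np

values = np.arange(1, 7)
probs = np.full(6, 1 / 6)

mu = np.sum(values * probs)
var_definition = np.sum((values - mu) ** 2 * probs)  # E[(X - mu)^2]
var_shortcut = np.sum(values**2 * probs) - mu**2     # E[X^2] - (E[X])^2

assert np.isclose(var_definition, var_shortcut)
print(var_definition, var_shortcut)  # both 2.9166...
```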
309
+ @app.cell
310
+ def _(die_probs, die_values, np):
311
+ # Demonstrate scaling property
312
+ a = 2 # Scale factor
313
+
314
+ # Original variance
315
+ original_var = np.sum(die_values**2 * die_probs) - (np.sum(die_values * die_probs))**2
316
+
317
+ # Scaled random variable variance
318
+ scaled_values = a * die_values
319
+ scaled_var = np.sum(scaled_values**2 * die_probs) - (np.sum(scaled_values * die_probs))**2
320
+
321
+ print(f"Original Variance: {original_var:.2f}")
322
+ print(f"Scaled Variance (a={a}): {scaled_var:.2f}")
323
+ print(f"a^2 * Original Variance: {a**2 * original_var:.2f}")
324
+ print(f"Property holds: {abs(scaled_var - a**2 * original_var) < 1e-10}")
325
+ return a, original_var, scaled_values, scaled_var
326
+
327
+
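The cell above verifies the scaling property; a similar standalone sketch (not part of the commit) covers the translation and independence properties by enumerating the joint distribution of two independent fair dice:

```python
# Sketch: check Var(X + b) = Var(X) and, for independent dice,
# Var(X + Y) = Var(X) + Var(Y), by enumerating the joint distribution.
import itertools
import numpy as np

values = np.arange(1, 7)
probs = np.full(6, 1 / 6)

def var_discrete(vals, ps):
    mu = np.sum(vals * ps)
    return np.sum(vals**2 * ps) - mu**2

var_x = var_discrete(values, probs)

# Translation: adding a constant shifts the mean but leaves the spread unchanged
assert np.isclose(var_discrete(values + 10, probs), var_x)

# Independence: all 36 (x, y) pairs of two fair dice are equally likely
sums = np.array([x + y for x, y in itertools.product(values, values)])
assert np.isclose(var_discrete(sums, np.full(36, 1 / 36)), 2 * var_x)
print("translation and independence properties hold")
```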
328
+ @app.cell
329
+ def _():
330
+ # DIY: verify the remaining properties (non-negativity, variance of a constant) in the same way
331
+ return
332
+
333
+
334
+ @app.cell(hide_code=True)
335
+ def _(mo):
336
+ mo.md(
337
+ r"""
338
+ ## Standard Deviation
339
+
340
+ While variance is mathematically convenient, it has one practical drawback: its units are squared. For example, if we're measuring grades (0-100), the variance is in "grade points squared." This makes it hard to interpret intuitively.
341
+
342
+ The **standard deviation**, denoted by $\sigma$ or $\text{SD}(X)$, is the square root of variance:
343
+
344
+ $$\sigma = \sqrt{\text{Var}(X)}$$
345
+
346
+ /// tip
347
+ Standard deviation is often more intuitive because it's in the same units as the original data. For a normal distribution, approximately:
348
+ - 68% of values fall within 1 standard deviation of the mean
349
+ - 95% of values fall within 2 standard deviations
350
+ - 99.7% of values fall within 3 standard deviations
351
+ """
352
+ )
353
+ return
354
+
355
+
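The 68-95-99.7 figures quoted above follow directly from the standard normal CDF and do not depend on the particular mean or standard deviation; a quick check (not part of the commit) using `scipy.stats`:

```python
# Sketch: the "68-95-99.7" coverages come straight from the standard normal CDF,
# so they hold for any mean and standard deviation.
from scipy import stats

for k in (1, 2, 3):
    coverage = stats.norm.cdf(k) - stats.norm.cdf(-k)  # P(|Z| <= k)
    print(f"within {k} sigma: {coverage:.4f}")
# within 1 sigma: 0.6827
# within 2 sigma: 0.9545
# within 3 sigma: 0.9973
```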
356
+ @app.cell(hide_code=True)
357
+ def _(controls1):
358
+ controls1
359
+ return
360
+
361
+
362
+ @app.cell(hide_code=True)
363
+ def _(TangleSlider, mo):
364
+ normal_mean = mo.ui.anywidget(TangleSlider(
365
+ amount=0,
366
+ min_value=-5,
367
+ max_value=5,
368
+ step=0.5,
369
+ digits=1,
370
+ suffix=" units"
371
+ ))
372
+
373
+ normal_std = mo.ui.anywidget(TangleSlider(
374
+ amount=1,
375
+ min_value=0.1,
376
+ max_value=3,
377
+ step=0.1,
378
+ digits=1,
379
+ suffix=" units"
380
+ ))
381
+
382
+ # Create a grid layout for the controls
383
+ controls1 = mo.vstack([
384
+ mo.md("### Interactive Normal Distribution"),
385
+ mo.hstack([
386
+ mo.md("Adjust the parameters to see how standard deviation affects the shape of the distribution:"),
387
+ ]),
388
+ mo.hstack([
389
+ mo.md("Mean (μ): "),
390
+ normal_mean,
391
+ mo.md(" Standard deviation (σ): "),
392
+ normal_std
393
+ ], justify="start"),
394
+ ])
395
+ return controls1, normal_mean, normal_std
396
+
397
+
398
+ @app.cell(hide_code=True)
399
+ def _(normal_mean, normal_std, np, plt, stats):
400
+ # data for normal distribution
401
+ _normal_x = np.linspace(-10, 10, 1000)
402
+ _normal_y = stats.norm.pdf(_normal_x, loc=normal_mean.amount, scale=normal_std.amount)
403
+
404
+ # ranges for standard deviation intervals
405
+ one_sigma_left = normal_mean.amount - normal_std.amount
406
+ one_sigma_right = normal_mean.amount + normal_std.amount
407
+ two_sigma_left = normal_mean.amount - 2 * normal_std.amount
408
+ two_sigma_right = normal_mean.amount + 2 * normal_std.amount
409
+ three_sigma_left = normal_mean.amount - 3 * normal_std.amount
410
+ three_sigma_right = normal_mean.amount + 3 * normal_std.amount
411
+
412
+ # Create the plot
413
+ normal_fig, normal_ax = plt.subplots(figsize=(10, 6))
414
+
415
+ # Plot the distribution
416
+ normal_ax.plot(_normal_x, _normal_y, 'b-', linewidth=2)
417
+
418
+ # stdev intervals
419
+ normal_ax.fill_between(_normal_x, 0, _normal_y, where=(_normal_x >= one_sigma_left) & (_normal_x <= one_sigma_right),
420
+ alpha=0.3, color='red', label='68% (±1σ)')
421
+ normal_ax.fill_between(_normal_x, 0, _normal_y, where=(_normal_x >= two_sigma_left) & (_normal_x <= two_sigma_right),
422
+ alpha=0.2, color='green', label='95% (±2σ)')
423
+ normal_ax.fill_between(_normal_x, 0, _normal_y, where=(_normal_x >= three_sigma_left) & (_normal_x <= three_sigma_right),
424
+ alpha=0.1, color='blue', label='99.7% (±3σ)')
425
+
426
+ # vertical lines for the mean and standard deviations
427
+ normal_ax.axvline(x=normal_mean.amount, color='black', linestyle='-', linewidth=1.5, label='Mean (μ)')
428
+ normal_ax.axvline(x=one_sigma_left, color='red', linestyle='--', linewidth=1)
429
+ normal_ax.axvline(x=one_sigma_right, color='red', linestyle='--', linewidth=1)
430
+ normal_ax.axvline(x=two_sigma_left, color='green', linestyle='--', linewidth=1)
431
+ normal_ax.axvline(x=two_sigma_right, color='green', linestyle='--', linewidth=1)
432
+
433
+ # annotations
434
+ normal_ax.annotate(f'μ = {normal_mean.amount:.2f}',
435
+ xy=(normal_mean.amount, max(_normal_y)*0.5),
436
+ xytext=(normal_mean.amount + 0.5, max(_normal_y)*0.8),
437
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
438
+
439
+ normal_ax.annotate(f'σ = {normal_std.amount:.2f}',
440
+ xy=(one_sigma_right, stats.norm.pdf(one_sigma_right, loc=normal_mean.amount, scale=normal_std.amount)),
441
+ xytext=(one_sigma_right + 0.5, max(_normal_y)*0.6),
442
+ arrowprops=dict(facecolor='red', shrink=0.05, width=1))
443
+
444
+ # labels and title
445
+ normal_ax.set_xlabel('Value')
446
+ normal_ax.set_ylabel('Probability Density')
447
+ normal_ax.set_title(f'Normal Distribution with μ = {normal_mean.amount:.2f} and σ = {normal_std.amount:.2f}')
448
+
449
+ # legend and grid
450
+ normal_ax.legend()
451
+ normal_ax.grid(alpha=0.3)
452
+
453
+ plt.tight_layout()
454
+ plt.gca()
455
+ return (
456
+ normal_ax,
457
+ normal_fig,
458
+ one_sigma_left,
459
+ one_sigma_right,
460
+ three_sigma_left,
461
+ three_sigma_right,
462
+ two_sigma_left,
463
+ two_sigma_right,
464
+ )
465
+
466
+
467
+ @app.cell(hide_code=True)
468
+ def _(mo):
469
+ mo.md(
470
+ r"""
471
+ /// tip
472
+ The interactive visualization above demonstrates how standard deviation (σ) affects the shape of a normal distribution:
473
+
474
+ - The **red region** covers μ ± 1σ, containing approximately 68% of the probability
475
+ - The **green region** covers μ ± 2σ, containing approximately 95% of the probability
476
+ - The **blue region** covers μ ± 3σ, containing approximately 99.7% of the probability
477
+
478
+ This is known as the "68-95-99.7 rule" or the "empirical rule" and is a useful heuristic for understanding the spread of data.
479
+ """
480
+ )
481
+ return
482
+
483
+
484
+ @app.cell(hide_code=True)
485
+ def _(mo):
486
+ mo.md(
487
+ r"""
488
+ ## 🤔 Test Your Understanding
489
+
490
+ Choose what you believe are the correct options in the questions below:
491
+
492
+ <details>
493
+ <summary>The variance of a random variable can be negative.</summary>
494
+ ❌ False! Variance is defined as an expected value of squared deviations, and squares are always non-negative.
495
+ </details>
496
+
497
+ <details>
498
+ <summary>If X and Y are independent random variables, then Var(X + Y) = Var(X) + Var(Y).</summary>
499
+ ✅ True! This is one of the key properties of variance for independent random variables.
500
+ </details>
501
+
502
+ <details>
503
+ <summary>Multiplying a random variable by 2 multiplies its variance by 2.</summary>
504
+ ❌ False! Multiplying a random variable by a constant a multiplies its variance by a². So multiplying by 2 multiplies variance by 4.
505
+ </details>
506
+
507
+ <details>
508
+ <summary>Standard deviation is always equal to the square root of variance.</summary>
509
+ ✅ True! By definition, standard deviation σ = √Var(X).
510
+ </details>
511
+
512
+ <details>
513
+ <summary>If Var(X) = 0, then X must be a constant.</summary>
514
+ ✅ True! Zero variance means there is no spread around the mean, so X can only take one value.
515
+ </details>
516
+ """
517
+ )
518
+ return
519
+
520
+
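The independence and scaling facts in the quiz above are easy to verify numerically. A minimal sketch (not part of the notebook; the distributions and sample sizes are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0, 3, size=1_000_000)     # Var(X) = 9
y = rng.exponential(2, size=1_000_000)   # Var(Y) = 4, independent of X

print(np.var(x + y), np.var(x) + np.var(y))  # both ~13: Var(X+Y) = Var(X) + Var(Y)
print(np.var(2 * x), 4 * np.var(x))          # both ~36: Var(2X) = 4 Var(X)
```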
521
+ @app.cell(hide_code=True)
522
+ def _(mo):
523
+ mo.md(
524
+ r"""
525
+ ## Key Takeaways
526
+
527
+ Variance gives us a way to measure how spread out a random variable is around its mean. It's like the "uncertainty" in our expectation - a high variance means individual outcomes can differ widely from what we expect on average.
528
+
529
+ Standard deviation brings this measure back to the original units, making it easier to interpret. For grades, a standard deviation of 10 points means typical grades fall within about 10 points of the average.
530
+
531
+ Variance pops up everywhere - from weather forecasts (how reliable is the predicted temperature?) to financial investments (how risky is this stock?) to quality control (how consistent is our manufacturing process?).
532
+
533
+ In our next notebook, we'll explore more properties of random variables and see how they combine to form more complex distributions.
534
+ """
535
+ )
536
+ return
537
+
538
+
539
+ @app.cell(hide_code=True)
540
+ def _(mo):
541
+ mo.md(r"""Appendix (containing helper code):""")
542
+ return
543
+
544
+
545
+ @app.cell(hide_code=True)
546
+ def _():
547
+ import marimo as mo
548
+ return (mo,)
549
+
550
+
551
+ @app.cell(hide_code=True)
552
+ def _():
553
+ import numpy as np
554
+ import scipy.stats as stats
555
+ import matplotlib.pyplot as plt
556
+ from wigglystuff import TangleSlider
557
+ return TangleSlider, np, plt, stats
558
+
559
+
560
+ @app.cell(hide_code=True)
561
+ def _(TangleSlider, mo):
562
+ # Create interactive elements using TangleSlider for a more inline experience
563
+ true_grade = mo.ui.anywidget(TangleSlider(
564
+ amount=70,
565
+ min_value=50,
566
+ max_value=90,
567
+ step=5,
568
+ digits=0,
569
+ suffix=" points"
570
+ ))
571
+
572
+ grader_a_spread = mo.ui.anywidget(TangleSlider(
573
+ amount=10,
574
+ min_value=5,
575
+ max_value=20,
576
+ step=1,
577
+ digits=0,
578
+ suffix=" points"
579
+ ))
580
+
581
+ grader_b_spread = mo.ui.anywidget(TangleSlider(
582
+ amount=2,
583
+ min_value=1,
584
+ max_value=5,
585
+ step=0.5,
586
+ digits=1,
587
+ suffix=" points"
588
+ ))
589
+
590
+ grader_c_mix = mo.ui.anywidget(TangleSlider(
591
+ amount=0.2,
592
+ min_value=0,
593
+ max_value=1,
594
+ step=0.05,
595
+ digits=2,
596
+ suffix=" proportion"
597
+ ))
598
+ return grader_a_spread, grader_b_spread, grader_c_mix, true_grade
599
+
600
+
601
+ @app.cell(hide_code=True)
602
+ def _(grader_a_spread, grader_b_spread, grader_c_mix, mo, true_grade):
603
+ # Create a grid layout for the interactive controls
604
+ controls = mo.vstack([
605
+ mo.md("### Adjust Parameters to See How Variance Changes"),
606
+ mo.hstack([
607
+ mo.md("**True grade:** The correct score that should be assigned is "),
608
+ true_grade,
609
+ mo.md(" out of 100.")
610
+ ], justify="start"),
611
+ mo.hstack([
612
+ mo.md("**Grader A:** Has a wide spread with standard deviation of "),
613
+ grader_a_spread,
614
+ mo.md(" points.")
615
+ ], justify="start"),
616
+ mo.hstack([
617
+ mo.md("**Grader B:** Has a narrow spread with standard deviation of "),
618
+ grader_b_spread,
619
+ mo.md(" points.")
620
+ ], justify="start"),
621
+ mo.hstack([
622
+ mo.md("**Grader C:** Has a mixture distribution with "),
623
+ grader_c_mix,
624
+ mo.md(" proportion of outliers.")
625
+ ], justify="start"),
626
+ ])
627
+ return (controls,)
628
+
629
+
630
+ if __name__ == "__main__":
631
+ app.run()
probability/13_bernoulli_distribution.py ADDED
@@ -0,0 +1,427 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "matplotlib==3.10.0",
6
+ # "numpy==2.2.3",
7
+ # "scipy==1.15.2",
8
+ # ]
9
+ # ///
10
+
11
+ import marimo
12
+
13
+ __generated_with = "0.11.22"
14
+ app = marimo.App(width="medium", app_title="Bernoulli Distribution")
15
+
16
+
17
+ @app.cell(hide_code=True)
18
+ def _(mo):
19
+ mo.md(
20
+ r"""
21
+ # Bernoulli Distribution
22
+
23
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/bernoulli/), by Stanford professor Chris Piech._
24
+
25
+ ## Parametric Random Variables
26
+
27
+ There are many classic and commonly-seen random variable abstractions that show up in the world of probability. At this point, we'll learn about several of the most significant parametric discrete distributions.
28
+
29
+ When solving problems, if you can recognize that a random variable fits one of these formats, then you can use its pre-derived Probability Mass Function (PMF), expectation, variance, and other properties. Random variables of this sort are called **parametric random variables**. If you can argue that a random variable falls under one of the studied parametric types, you simply need to provide parameters.
30
+
31
+ > A good analogy is a `class` in programming. Creating a parametric random variable is very similar to calling a constructor with input parameters.
32
+ """
33
+ )
34
+ return
35
+
36
+
37
+ @app.cell(hide_code=True)
38
+ def _(mo):
39
+ mo.md(
40
+ r"""
41
+ ## Bernoulli Random Variables
42
+
43
+ A **Bernoulli random variable** (also called a boolean or indicator random variable) is the simplest kind of parametric random variable. It can take on two values: 1 and 0.
44
+
45
+ It takes on a 1 if an experiment with probability $p$ resulted in success and a 0 otherwise.
46
+
47
+ Some example uses include:
48
+
49
+ - A coin flip (heads = 1, tails = 0)
50
+ - A random binary digit
51
+ - Whether a disk drive crashed
52
+ - Whether someone likes a Netflix movie
53
+
54
+ Here $p$ is the parameter, but different instances of Bernoulli random variables might have different values of $p$.
55
+ """
56
+ )
57
+ return
58
+
59
+
60
+ @app.cell(hide_code=True)
61
+ def _(mo):
62
+ mo.md(
63
+ r"""
64
+ ## Key Properties of a Bernoulli Random Variable
65
+
66
+ If $X$ is declared to be a Bernoulli random variable with parameter $p$, denoted $X \sim \text{Bern}(p)$, it has the following properties:
67
+ """
68
+ )
69
+ return
70
+
71
+
72
+ @app.cell
73
+ def _(stats):
74
+ # Define the Bernoulli distribution function
75
+ def Bern(p):
76
+ return stats.bernoulli(p)
77
+ return (Bern,)
78
+
79
+
80
+ @app.cell(hide_code=True)
81
+ def _(mo):
82
+ mo.md(
83
+ r"""
84
+ ## Bernoulli Distribution Properties
85
+
86
+ $\begin{array}{lll}
87
+ \text{Notation:} & X \sim \text{Bern}(p) \\
88
+ \text{Description:} & \text{A boolean variable that is 1 with probability } p \\
89
+ \text{Parameters:} & p, \text{ the probability that } X = 1 \\
90
+ \text{Support:} & x \text{ is either 0 or 1} \\
91
+ \text{PMF equation:} & P(X = x) =
92
+ \begin{cases}
93
+ p & \text{if }x = 1\\
94
+ 1-p & \text{if }x = 0
95
+ \end{cases} \\
96
+ \text{PMF (smooth):} & P(X = x) = p^x(1-p)^{1-x} \\
97
+ \text{Expectation:} & E[X] = p \\
98
+ \text{Variance:} & \text{Var}(X) = p(1-p) \\
99
+ \end{array}$
100
+ """
101
+ )
102
+ return
103
+
104
+
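The expectation and variance rows of this table can be cross-checked directly against `scipy.stats.bernoulli`; a minimal sketch with an arbitrary example value of $p$:

```python
from scipy import stats

p = 0.3                      # arbitrary example parameter
X = stats.bernoulli(p)
print(X.pmf(1), X.pmf(0))    # 0.3, 0.7
print(X.mean(), X.var())     # 0.3, 0.21, i.e. p and p*(1-p)
```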
105
+ @app.cell(hide_code=True)
106
+ def _(mo, p_slider):
107
+ # Visualization of the Bernoulli PMF
108
+ _p = p_slider.value
109
+
110
+ # Values for PMF
111
+ values = [0, 1]
112
+ probabilities = [1 - _p, _p]
113
+
114
+ # Relevant statistics
115
+ expected_value = _p
116
+ variance = _p * (1 - _p)
117
+
118
+ mo.md(f"""
119
+ ## PMF Graph for Bernoulli($p={_p:.2f}$)
120
+
121
+ Parameter $p$: {p_slider}
122
+
123
+ Expected value: $E[X] = {expected_value:.2f}$
124
+
125
+ Variance: $\\text{{Var}}(X) = {variance:.2f}$
126
+ """)
127
+ return expected_value, probabilities, values, variance
128
+
129
+
130
+ @app.cell(hide_code=True)
131
+ def _(expected_value, p_slider, plt, probabilities, values, variance):
132
+ # PMF
133
+ _p = p_slider.value
134
+ fig, ax = plt.subplots(figsize=(10, 6))
135
+
136
+ # Bar plot for PMF
137
+ ax.bar(values, probabilities, width=0.4, color='blue', alpha=0.7)
138
+
139
+ ax.set_xlabel('Values that X can take on')
140
+ ax.set_ylabel('Probability')
141
+ ax.set_title(f'PMF of Bernoulli Distribution with p = {_p:.2f}')
142
+
143
+ # x-axis limit
144
+ ax.set_xticks([0, 1])
145
+ ax.set_xlim(-0.5, 1.5)
146
+
147
+ # y-axis w/ some padding
148
+ ax.set_ylim(0, max(probabilities) * 1.1)
149
+
150
+ # Add expectation as vertical line
151
+ ax.axvline(x=expected_value, color='red', linestyle='--',
152
+ label=f'E[X] = {expected_value:.2f}')
153
+
154
+ # Add variance annotation
155
+ ax.text(0.5, max(probabilities) * 0.8,
156
+ f'Var(X) = {variance:.3f}',
157
+ horizontalalignment='center',
158
+ bbox=dict(facecolor='white', alpha=0.7))
159
+
160
+ ax.legend()
161
+ plt.tight_layout()
162
+ plt.gca()
163
+ return ax, fig
164
+
165
+
166
+ @app.cell(hide_code=True)
167
+ def _(mo):
168
+ mo.md(
169
+ r"""
170
+ ## Proof: Expectation of a Bernoulli
171
+
172
+ If $X$ is a Bernoulli with parameter $p$, $X \sim \text{Bern}(p)$:
173
+
174
+ \begin{align}
175
+ E[X] &= \sum_x x \cdot P(X=x) && \text{Definition of expectation} \\
176
+ &= 1 \cdot p + 0 \cdot (1-p) &&
177
+ X \text{ can take on values 0 and 1} \\
178
+ &= p && \text{Remove the 0 term}
179
+ \end{align}
180
+
181
+ ## Proof: Variance of a Bernoulli
182
+
183
+ If $X$ is a Bernoulli with parameter $p$, $X \sim \text{Bern}(p)$:
184
+
185
+ To compute variance, first compute $E[X^2]$:
186
+
187
+ \begin{align}
188
+ E[X^2]
189
+ &= \sum_x x^2 \cdot P(X=x) &&\text{LOTUS}\\
190
+ &= 0^2 \cdot (1-p) + 1^2 \cdot p\\
191
+ &= p
192
+ \end{align}
193
+
194
+ \begin{align}
195
+ \text{Var}(X)
196
+ &= E[X^2] - E[X]^2&& \text{Def of variance} \\
197
+ &= p - p^2 && \text{Substitute }E[X^2]=p, E[X] = p \\
198
+ &= p (1-p) && \text{Factor out }p
199
+ \end{align}
200
+ """
201
+ )
202
+ return
203
+
204
+
205
+ @app.cell(hide_code=True)
206
+ def _(mo):
207
+ mo.md(
208
+ r"""
209
+ ## Indicator Random Variable
210
+
211
+ > **Definition**: An indicator variable is a Bernoulli random variable which takes on the value 1 if an **underlying event occurs**, and 0 _otherwise_.
212
+
213
+ Indicator random variables are a convenient way to convert the "true/false" outcome of an event into a number. That number may be easier to incorporate into an equation.
214
+
215
+ A random variable $I$ is an indicator variable for an event $A$ if $I = 1$ when $A$ occurs and $I = 0$ if $A$ does not occur. Indicator random variables are Bernoulli random variables, with $p = P(A)$. $I_A$ is a common choice of name for an indicator random variable.
216
+
217
+ Here are some properties of indicator random variables:
218
+
219
+ - $P(I=1)=P(A)$
220
+ - $E[I]=P(A)$
221
+ """
222
+ )
223
+ return
224
+
225
+
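To make the indicator idea concrete, here is a small illustrative sketch (not from the notebook) where the event $A$ is "a fair die shows 5 or 6"; the average of the indicator estimates $P(A) = 1/3$:

```python
import numpy as np

rng = np.random.default_rng(1)
rolls = rng.integers(1, 7, size=100_000)   # fair six-sided die
indicator = (rolls >= 5).astype(int)       # I = 1 when A occurs, 0 otherwise
print(indicator.mean())                    # ~0.333, so E[I] is approximately P(A)
```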
226
+ @app.cell(hide_code=True)
227
+ def _(mo):
228
+ # Simulation of Bernoulli trials
229
+ mo.md(r"""
230
+ ## Simulation of Bernoulli Trials
231
+
232
+ Let's simulate Bernoulli trials to see the law of large numbers in action. We'll flip a biased coin repeatedly and observe how the proportion of successes approaches the true probability $p$.
233
+ """)
234
+
235
+ # UI element for simulation parameters
236
+ num_trials_slider = mo.ui.slider(10, 10000, value=1000, step=10, label="Number of trials")
237
+ p_sim_slider = mo.ui.slider(0.01, 0.99, value=0.65, step=0.01, label="Success probability (p)")
238
+ return num_trials_slider, p_sim_slider
239
+
240
+
241
+ @app.cell(hide_code=True)
242
+ def _(mo):
243
+ mo.md(r"""## Simulation""")
244
+ return
245
+
246
+
247
+ @app.cell(hide_code=True)
248
+ def _(mo, num_trials_slider, p_sim_slider):
249
+ mo.hstack([num_trials_slider, p_sim_slider], justify='space-around')
250
+ return
251
+
252
+
253
+ @app.cell(hide_code=True)
254
+ def _(np, num_trials_slider, p_sim_slider, plt):
255
+ # Bernoulli trials
256
+ _num_trials = num_trials_slider.value
257
+ p = p_sim_slider.value
258
+
259
+ # Random Bernoulli trials
260
+ trials = np.random.binomial(1, p, size=_num_trials)
261
+
262
+ # Cumulative proportion of successes
263
+ cumulative_mean = np.cumsum(trials) / np.arange(1, _num_trials + 1)
264
+
265
+ # Results
266
+ plt.figure(figsize=(10, 6))
267
+ plt.plot(range(1, _num_trials + 1), cumulative_mean, label='Proportion of successes')
268
+ plt.axhline(y=p, color='r', linestyle='--', label=f'True probability (p={p})')
269
+
270
+ plt.xscale('log') # Use log scale for better visualization
271
+ plt.xlabel('Number of trials')
272
+ plt.ylabel('Proportion of successes')
273
+ plt.title('Convergence of Sample Proportion to True Probability')
274
+ plt.legend()
275
+ plt.grid(True, alpha=0.3)
276
+
277
+ # Add annotation
278
+ plt.annotate('As the number of trials increases,\nthe proportion approaches p',
279
+ xy=(_num_trials, cumulative_mean[-1]),
280
+ xytext=(_num_trials/5, p + 0.1),
281
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
282
+
283
+ plt.tight_layout()
284
+ plt.gca()
285
+ return cumulative_mean, p, trials
286
+
287
+
288
+ @app.cell(hide_code=True)
289
+ def _(mo, np, trials):
290
+ # Calculate statistics from the simulation
291
+ num_successes = np.sum(trials)
292
+ num_trials = len(trials)
293
+ proportion = num_successes / num_trials
294
+
295
+ # Display the results
296
+ mo.md(f"""
297
+ ### Simulation Results
298
+
299
+ - Number of trials: {num_trials}
300
+ - Number of successes: {num_successes}
301
+ - Proportion of successes: {proportion:.4f}
302
+
303
+ This demonstrates how the sample proportion approaches the true probability $p$ as the number of trials increases.
304
+ """)
305
+ return num_successes, num_trials, proportion
306
+
307
+
308
+ @app.cell(hide_code=True)
309
+ def _(mo):
310
+ mo.md(
311
+ r"""
312
+ ## 🤔 Test Your Understanding
313
+
314
+ Pick which of these statements about Bernoulli random variables you think are correct:
315
+
316
+ /// details | The variance of a Bernoulli random variable is always less than or equal to 0.25
317
+ ✅ Correct! The variance $p(1-p)$ reaches its maximum value of 0.25 when $p = 0.5$.
318
+ ///
319
+
320
+ /// details | The expected value of a Bernoulli random variable must be either 0 or 1
321
+ ❌ Incorrect! The expected value is $p$, which can be any value between 0 and 1.
322
+ ///
323
+
324
+ /// details | If $X \sim \text{Bern}(0.3)$ and $Y \sim \text{Bern}(0.7)$, then $X$ and $Y$ have the same variance
325
+ ✅ Correct! $\text{Var}(X) = 0.3 \times 0.7 = 0.21$ and $\text{Var}(Y) = 0.7 \times 0.3 = 0.21$.
326
+ ///
327
+
328
+ /// details | Two independent coin flips can be modeled as the sum of two Bernoulli random variables
329
+ ✅ Correct! The sum would follow a Binomial distribution with $n=2$.
330
+ ///
331
+ """
332
+ )
333
+ return
334
+
335
+
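The first quiz item, that $p(1-p)$ never exceeds 0.25, can be eyeballed numerically with a one-line grid search; a tiny sketch:

```python
import numpy as np

p = np.linspace(0, 1, 1001)
var = p * (1 - p)
print(p[var.argmax()], var.max())  # 0.5 and 0.25
```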
336
+ @app.cell(hide_code=True)
337
+ def _(mo):
338
+ mo.md(
339
+ r"""
340
+ ## Applications of Bernoulli Random Variables
341
+
342
+ Bernoulli random variables are used in many real-world scenarios:
343
+
344
+ 1. **Quality Control**: Testing if a manufactured item is defective (1) or not (0)
345
+
346
+ 2. **A/B Testing**: Determining if a user clicks (1) or doesn't click (0) on a website button
347
+
348
+ 3. **Medical Testing**: Checking if a patient tests positive (1) or negative (0) for a disease
349
+
350
+ 4. **Election Modeling**: Modeling if a particular voter votes for candidate A (1) or not (0)
351
+
352
+ 5. **Financial Markets**: Modeling if a stock price goes up (1) or down (0) in a simplified model
353
+
354
+ Because Bernoulli random variables are parametric, as soon as you declare a random variable to be of type Bernoulli, you automatically know all of its pre-derived properties!
355
+ """
356
+ )
357
+ return
358
+
359
+
360
+ @app.cell(hide_code=True)
361
+ def _(mo):
362
+ mo.md(
363
+ r"""
364
+ ## Summary
365
+
366
+ And that's a wrap on Bernoulli distributions! We've learnt the simplest of all probability distributions — the one that only has two possible outcomes. Flip a coin, check if an email is spam, see if your blind date shows up — these are all Bernoulli trials with success probability $p$.
367
+
368
+ The beauty of Bernoulli is in its simplicity: just set $p$ (the probability of success) and you're good to go! The PMF gives us $P(X=1) = p$ and $P(X=0) = 1-p$, while expectation is simply $p$ and variance is $p(1-p)$. Oh, and when you're tracking whether specific events happen or not? That's an indicator random variable — just another Bernoulli in disguise!
369
+
370
+ Two key things to remember:
371
+
372
+ /// note
373
+ 💡 **Maximum Variance**: A Bernoulli's variance $p(1-p)$ reaches its maximum at $p=0.5$, making a fair coin the most "unpredictable" Bernoulli random variable.
374
+
375
+ 💡 **Instant Properties**: When you identify a random variable as Bernoulli, you instantly know all its properties—expectation, variance, PMF—without additional calculations.
376
+ ///
377
+
378
+ Next up: Binomial distribution—where we'll see what happens when we let Bernoulli trials have a party and add themselves together!
379
+ """
380
+ )
381
+ return
382
+
383
+
384
+ @app.cell(hide_code=True)
385
+ def _(mo):
386
+ mo.md(r"""#### Appendix (containing helper code for the notebook)""")
387
+ return
388
+
389
+
390
+ @app.cell
391
+ def _():
392
+ import marimo as mo
393
+ return (mo,)
394
+
395
+
396
+ @app.cell(hide_code=True)
397
+ def _():
398
+ from marimo import Html
399
+ return (Html,)
400
+
401
+
402
+ @app.cell(hide_code=True)
403
+ def _():
404
+ import numpy as np
405
+ import matplotlib.pyplot as plt
406
+ from scipy import stats
407
+ import math
408
+
409
+ # Set style for consistent visualizations
410
+ plt.style.use('seaborn-v0_8-whitegrid')
411
+ plt.rcParams['figure.figsize'] = [10, 6]
412
+ plt.rcParams['font.size'] = 12
413
+
414
+ # Set random seed for reproducibility
415
+ np.random.seed(42)
416
+ return math, np, plt, stats
417
+
418
+
419
+ @app.cell(hide_code=True)
420
+ def _(mo):
421
+ # Create a UI element for the parameter p
422
+ p_slider = mo.ui.slider(0.01, 0.99, value=0.65, step=0.01, label="Parameter p")
423
+ return (p_slider,)
424
+
425
+
426
+ if __name__ == "__main__":
427
+ app.run()
probability/14_binomial_distribution.py ADDED
@@ -0,0 +1,545 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "matplotlib==3.10.0",
6
+ # "numpy==2.2.4",
7
+ # "scipy==1.15.2",
8
+ # "altair==5.2.0",
9
+ # "wigglystuff==0.1.10",
10
+ # "pandas==2.2.3",
11
+ # ]
12
+ # ///
13
+
14
+ import marimo
15
+
16
+ __generated_with = "0.11.24"
17
+ app = marimo.App(width="medium", app_title="Binomial Distribution")
18
+
19
+
20
+ @app.cell(hide_code=True)
21
+ def _(mo):
22
+ mo.md(
23
+ r"""
24
+ # Binomial Distribution
25
+
26
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/binomial/), by Stanford professor Chris Piech._
27
+
28
+ In this section, we will discuss the binomial distribution. To start, imagine the following example:
29
+
30
+ Consider $n$ independent trials of an experiment where each trial is a "success" with probability $p$. Let $X$ be the number of successes in $n$ trials.
31
+
32
+ This situation is truly common in the natural world, and as such, there has been a lot of research into such phenomena. Random variables like $X$ are called **binomial random variables**. If you can identify that a process fits this description, you can inherit many already proved properties such as the PMF formula, expectation, and variance!
33
+ """
34
+ )
35
+ return
36
+
37
+
38
+ @app.cell(hide_code=True)
39
+ def _(mo):
40
+ mo.md(
41
+ r"""
42
+ ## Binomial Random Variable Definition
43
+
44
+ $X \sim \text{Bin}(n, p)$ represents a binomial random variable where:
45
+
46
+ - $X$ is our random variable (number of successes)
47
+ - $\text{Bin}$ indicates it follows a binomial distribution
48
+ - $n$ is the number of trials
49
+ - $p$ is the probability of success in each trial
50
+
51
+ ```
52
+ X ~ Bin(n, p)
53
+ ↑ ↑ ↑
54
+ | | +-- Probability of
55
+ | | success on each
56
+ | | trial
57
+ | +-- Number of trials
58
+ |
59
+ Our random variable
60
+ is distributed
61
+ as a Binomial
62
+ ```
63
+
64
+ Here are a few examples of binomial random variables:
65
+
66
+ - Number of heads in $n$ coin flips
67
+ - Number of 1's in a randomly generated length-$n$ bit string
68
+ - Number of disk drives that crash in a cluster of 1000 computers, assuming disks crash independently
69
+ """
70
+ )
71
+ return
72
+
73
+
74
+ @app.cell(hide_code=True)
75
+ def _(mo):
76
+ mo.md(
77
+ r"""
78
+ ## Properties of Binomial Distribution
79
+
80
+ | Property | Formula |
81
+ |----------|---------|
82
+ | Notation | $X \sim \text{Bin}(n, p)$ |
83
+ | Description | Number of "successes" in $n$ identical, independent experiments each with probability of success $p$ |
84
+ | Parameters | $n \in \{0, 1, \dots\}$, the number of experiments<br>$p \in [0, 1]$, the probability that a single experiment gives a "success" |
85
+ | Support | $x \in \{0, 1, \dots, n\}$ |
86
+ | PMF equation | $P(X=x) = {n \choose x}p^x(1-p)^{n-x}$ |
87
+ | Expectation | $E[X] = n \cdot p$ |
88
+ | Variance | $\text{Var}(X) = n \cdot p \cdot (1-p)$ |
89
+
90
+ Let's explore how the binomial distribution changes with different parameters.
91
+ """
92
+ )
93
+ return
94
+
95
+
96
+ @app.cell(hide_code=True)
97
+ def _(TangleSlider, mo):
98
+ # Interactive elements using TangleSlider
99
+ n_slider = mo.ui.anywidget(TangleSlider(
100
+ amount=10,
101
+ min_value=1,
102
+ max_value=30,
103
+ step=1,
104
+ digits=0,
105
+ suffix=" trials"
106
+ ))
107
+
108
+ p_slider = mo.ui.anywidget(TangleSlider(
109
+ amount=0.5,
110
+ min_value=0.01,
111
+ max_value=0.99,
112
+ step=0.01,
113
+ digits=2,
114
+ suffix=" probability"
115
+ ))
116
+
117
+ # Grid layout for the interactive controls
118
+ controls = mo.vstack([
119
+ mo.md("### Adjust Parameters to See How Binomial Distribution Changes"),
120
+ mo.hstack([
121
+ mo.md("**Number of trials (n):** "),
122
+ n_slider
123
+ ], justify="start"),
124
+ mo.hstack([
125
+ mo.md("**Probability of success (p):** "),
126
+ p_slider
127
+ ], justify="start"),
128
+ ])
129
+ return controls, n_slider, p_slider
130
+
131
+
132
+ @app.cell(hide_code=True)
133
+ def _(controls):
134
+ controls
135
+ return
136
+
137
+
138
+ @app.cell(hide_code=True)
139
+ def _(n_slider, np, p_slider, plt, stats):
140
+ # Parameters from sliders
141
+ _n = int(n_slider.amount)
142
+ _p = p_slider.amount
143
+
144
+ # Calculate PMF
145
+ _x = np.arange(0, _n + 1)
146
+ _pmf = stats.binom.pmf(_x, _n, _p)
147
+
148
+ # Relevant stats
149
+ _mean = _n * _p
150
+ _variance = _n * _p * (1 - _p)
151
+ _std_dev = np.sqrt(_variance)
152
+
153
+ _fig, _ax = plt.subplots(figsize=(10, 6))
154
+
155
+ # Plot PMF as bars
156
+ _ax.bar(_x, _pmf, color='royalblue', alpha=0.7, label=f'PMF: P(X=k)')
157
+
158
+ # Add a line
159
+ _ax.plot(_x, _pmf, 'ro-', alpha=0.6, label='PMF line')
160
+
161
+ # Add vertical lines
162
+ _ax.axvline(x=_mean, color='green', linestyle='--', linewidth=2,
163
+ label=f'Mean: {_mean:.2f}')
164
+
165
+ # Shade the stdev region
166
+ _ax.axvspan(_mean - _std_dev, _mean + _std_dev, alpha=0.2, color='green',
167
+ label=f'±1 Std Dev: {_std_dev:.2f}')
168
+
169
+ # Add labels and title
170
+ _ax.set_xlabel('Number of Successes (k)')
171
+ _ax.set_ylabel('Probability: P(X=k)')
172
+ _ax.set_title(f'Binomial Distribution with n={_n}, p={_p:.2f}')
173
+
174
+ # Annotations
175
+ _ax.annotate(f'E[X] = {_mean:.2f}',
176
+ xy=(_mean, stats.binom.pmf(int(_mean), _n, _p)),
177
+ xytext=(_mean + 1, max(_pmf) * 0.8),
178
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
179
+
180
+ _ax.annotate(f'Var(X) = {_variance:.2f}',
181
+ xy=(_mean, stats.binom.pmf(int(_mean), _n, _p) / 2),
182
+ xytext=(_mean + 1, max(_pmf) * 0.6),
183
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
184
+
185
+ # Grid and legend
186
+ _ax.grid(alpha=0.3)
187
+ _ax.legend()
188
+
189
+ plt.tight_layout()
190
+ plt.gca()
191
+ return
192
+
193
+
194
+ @app.cell(hide_code=True)
195
+ def _(mo):
196
+ mo.md(
197
+ r"""
198
+ ## Relationship to Bernoulli Random Variables
199
+
200
+ One way to think of the binomial is as the sum of $n$ Bernoulli variables. Say that $Y_i$ is an indicator Bernoulli random variable which is 1 if experiment $i$ is a success. Then if $X$ is the total number of successes in $n$ experiments, $X \sim \text{Bin}(n, p)$:
201
+
202
+ $$X = \sum_{i=1}^n Y_i$$
203
+
204
+ Recall that the outcome of $Y_i$ will be 1 or 0, so one way to think of $X$ is as the sum of those 1s and 0s.
205
+ """
206
+ )
207
+ return
208
+
209
+
210
+ @app.cell(hide_code=True)
211
+ def _(mo):
212
+ mo.md(
213
+ r"""
214
+ ## Binomial Probability Mass Function (PMF)
215
+
216
+ The most important property to know about a binomial is its [Probability Mass Function](https://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/10_probability_mass_function.py):
217
+
218
+ $$P(X=k) = {n \choose k}p^k(1-p)^{n-k}$$
219
+
220
+ ```
221
+ P(X = k) = (n) p^k(1-p)^(n-k)
222
+ ↑ (k)
223
+ | ↑
224
+ | +-- Binomial coefficient:
225
+ | number of ways to choose
226
+ | k successes from n trials
227
+ |
228
+ Probability that our
229
+ variable takes on the
230
+ value k
231
+ ```
232
+
233
+ Recall, we derived this formula in Part 1. There is a complete example on the probability of $k$ heads in $n$ coin flips, where each flip is heads with probability $p$.
234
+
235
+ To briefly review, if you think of each experiment as being distinct, then there are ${n \choose k}$ ways of permuting $k$ successes from $n$ experiments. For any of the mutually exclusive permutations, the probability of that permutation is $p^k \cdot (1-p)^{n-k}$.
236
+
237
+ The name binomial comes from the term ${n \choose k}$ which is formally called the binomial coefficient.
238
+ """
239
+ )
240
+ return
241
+
242
+
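The PMF formula above can be evaluated by hand with `math.comb` and compared against `scipy.stats.binom.pmf`; a small sketch using arbitrary example values of $n$, $p$, and $k$:

```python
import math
from scipy import stats

n, p, k = 10, 0.3, 4                                 # arbitrary example values
manual = math.comb(n, k) * p**k * (1 - p)**(n - k)
print(manual)                                        # ~0.2001
print(stats.binom.pmf(k, n, p))                      # same value from scipy
```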
243
+ @app.cell(hide_code=True)
244
+ def _(mo):
245
+ mo.md(
246
+ r"""
247
+ ## Expectation of Binomial
248
+
249
+ There is an easy way to calculate the expectation of a binomial and a hard way. The easy way is to leverage the fact that a binomial is the sum of Bernoulli indicator random variables $X = \sum_{i=1}^{n} Y_i$ where $Y_i$ is an indicator of whether the $i$-th experiment was a success: $Y_i \sim \text{Bernoulli}(p)$.
250
+
251
+ Since the [expectation of the sum](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/11_expectation.py) of random variables is the sum of expectations, we can add the expectation, $E[Y_i] = p$, of each of the Bernoulli's:
252
+
253
+ \begin{align}
254
+ E[X] &= E\Big[\sum_{i=1}^{n} Y_i\Big] && \text{Since }X = \sum_{i=1}^{n} Y_i \\
255
+ &= \sum_{i=1}^{n}E[ Y_i] && \text{Expectation of sum} \\
256
+ &= \sum_{i=1}^{n}p && \text{Expectation of Bernoulli} \\
257
+ &= n \cdot p && \text{Sum $n$ times}
258
+ \end{align}
259
+
260
+ The hard way is to use the definition of expectation:
261
+
262
+ \begin{align}
263
+ E[X] &= \sum_{i=0}^n i \cdot P(X = i) && \text{Def of expectation} \\
264
+ &= \sum_{i=0}^n i \cdot {n \choose i} p^i(1-p)^{n-i} && \text{Sub in PMF} \\
265
+ & \cdots && \text{Many steps later} \\
266
+ &= n \cdot p
267
+ \end{align}
268
+ """
269
+ )
270
+ return
271
+
272
+
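The "easy way" above, treating a binomial as a sum of Bernoulli indicators, is also visible in simulation; a minimal sketch (the parameters and sample size are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(2)
n, p = 12, 0.25                                # arbitrary example parameters
bernoullis = rng.random((200_000, n)) < p      # each column is a Bern(p) indicator
x = bernoullis.sum(axis=1)                     # each row sum is a Bin(n, p) draw
print(x.mean(), n * p)                         # ~3.0 and 3.0
```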
273
+ @app.cell(hide_code=True)
274
+ def _(mo):
275
+ mo.md(
276
+ r"""
277
+ ## Binomial Distribution in Python
278
+
279
+ As you might expect, you can use binomial distributions in code. The standardized library for binomials is `scipy.stats.binom`.
280
+
281
+ One of the most helpful methods that this package provides is a way to calculate the PMF. For example, say $n=5$, $p=0.6$ and you want to find $P(X=2)$, you could use the following code:
282
+ """
283
+ )
284
+ return
285
+
286
+
287
+ @app.cell
288
+ def _(stats):
289
+ # define variables for x, n, and p
290
+ _n = 5 # Integer value for n
291
+ _p = 0.6
292
+ _x = 2
293
+
294
+ # use scipy to compute the pmf
295
+ p_x = stats.binom.pmf(_x, _n, _p)
296
+
297
+ # use the probability for future work
298
+ print(f'P(X = {_x}) = {p_x:.4f}')
299
+ return (p_x,)
300
+
301
+
302
+ @app.cell(hide_code=True)
303
+ def _(mo):
304
+ mo.md(r"""Another particularly helpful function is the ability to generate a random sample from a binomial. For example, say $X$ represents the number of requests to a website. We can draw 100 samples from this distribution using the following code:""")
305
+ return
306
+
307
+
308
+ @app.cell
309
+ def _(n, p, stats):
310
+ n_int = int(n)
311
+
312
+ # samples from the binomial distribution
313
+ samples = stats.binom.rvs(n_int, p, size=100)
314
+
315
+ # Print the samples
316
+ print(samples)
317
+ return n_int, samples
318
+
319
+
320
+ @app.cell(hide_code=True)
321
+ def _(n_int, np, p, plt, samples, stats):
322
+ # Plot histogram of samples
323
+ plt.figure(figsize=(10, 5))
324
+ plt.hist(samples, bins=np.arange(-0.5, n_int+1.5, 1), alpha=0.7, color='royalblue',
325
+ edgecolor='black', density=True)
326
+
327
+ # Overlay the PMF
328
+ x_values = np.arange(0, n_int+1)
329
+ pmf_values = stats.binom.pmf(x_values, n_int, p)
330
+ plt.plot(x_values, pmf_values, 'ro-', ms=8, label='Theoretical PMF')
331
+
332
+ # Add labels and title
333
+ plt.xlabel('Number of Successes')
334
+ plt.ylabel('Relative Frequency / Probability')
335
+ plt.title(f'Histogram of 100 Samples from Bin({n_int}, {p})')
336
+ plt.legend()
337
+ plt.grid(alpha=0.3)
338
+
339
+ # Annotate
340
+ plt.annotate('Sample mean: %.2f' % np.mean(samples),
341
+ xy=(0.7, 0.9), xycoords='axes fraction',
342
+ bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3))
343
+ plt.annotate('Theoretical mean: %.2f' % (n_int*p),
344
+ xy=(0.7, 0.8), xycoords='axes fraction',
345
+ bbox=dict(boxstyle='round,pad=0.5', fc='lightgreen', alpha=0.3))
346
+
347
+ plt.tight_layout()
348
+ plt.gca()
349
+ return pmf_values, x_values
350
+
351
+
352
+ @app.cell(hide_code=True)
353
+ def _(mo):
354
+ mo.md(
355
+ r"""
356
+ You might be wondering what a random sample is! A random sample is a randomly chosen assignment for our random variable. Above we have 100 such assignments. The probability that value $k$ is chosen is given by the PMF: $P(X=k)$.
357
+
358
+ There are also functions for getting the mean, the variance, and more. You can read the [scipy.stats.binom documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.binom.html), especially the list of methods.
359
+ """
360
+ )
361
+ return
362
+
363
+
364
+ @app.cell(hide_code=True)
365
+ def _(mo):
366
+ mo.md(
367
+ r"""
368
+ ## Interactive Exploration of Binomial vs. Negative Binomial
369
+
370
+ The standard binomial distribution is a special case of a broader family of distributions. One related distribution is the negative binomial, which can model count data with overdispersion (where the variance is larger than the mean).
371
+
372
+ Below, you can explore how the negative binomial distribution compares to a Poisson distribution (which can be seen as a limiting case of the binomial as $n$ gets large and $p$ gets small, with $np$ held constant).
373
+
374
+ Adjust the sliders to see how the parameters affect the distribution:
375
+
376
+ *Note: The interactive visualization in this section was inspired by work from [liquidcarbon on GitHub](https://github.com/liquidcarbon).*
377
+ """
378
+ )
379
+ return
380
+
381
+
382
+ @app.cell(hide_code=True)
383
+ def _(alpha_slider, chart, equation, mo, mu_slider):
384
+ mo.vstack(
385
+ [
386
+ mo.md(f"## Negative Binomial Distribution (Poisson + Overdispersion)\n{equation}"),
387
+ mo.hstack([mu_slider, alpha_slider], justify="start"),
388
+ chart,
389
+ ], justify='space-around'
390
+ ).center()
391
+ return
392
+
393
+
394
+ @app.cell(hide_code=True)
395
+ def _(mo):
396
+ mo.md(
397
+ r"""
398
+ ## 🤔 Test Your Understanding
399
+ Pick which of these statements about binomial distributions you think are correct:
400
+
401
+ /// details | The variance of a binomial distribution is always equal to its mean
402
+ ❌ Incorrect! The variance is $np(1-p)$ while the mean is $np$. They're only equal when $p=0$ (a degenerate case in which both are zero).
403
+ ///
404
+
405
+ /// details | If $X \sim \text{Bin}(n, p)$ and $Y \sim \text{Bin}(n, 1-p)$, then $X$ and $Y$ have the same variance
406
+ ✅ Correct! $\text{Var}(X) = np(1-p)$ and $\text{Var}(Y) = n(1-p)p$, which are the same.
407
+ ///
408
+
409
+ /// details | As the number of trials increases, the binomial distribution approaches a normal distribution
410
+ ✅ Correct! For large $n$, the binomial distribution can be approximated by a normal distribution with the same mean and variance.
411
+ ///
412
+
413
+ /// details | The PMF of a binomial distribution is symmetric when $p = 0.5$
414
+ ✅ Correct! When $p = 0.5$, the PMF is symmetric around $n/2$.
415
+ ///
416
+
417
+ /// details | The sum of two independent binomial random variables with the same $p$ is also a binomial random variable
418
+ ✅ Correct! If $X \sim \text{Bin}(n_1, p)$ and $Y \sim \text{Bin}(n_2, p)$ are independent, then $X + Y \sim \text{Bin}(n_1 + n_2, p)$.
419
+ ///
420
+
421
+ /// details | The maximum value of the PMF for $\text{Bin}(n,p)$ always occurs at $k = np$
422
+ ❌ Incorrect! The mode (maximum value of PMF) is either $\lfloor (n+1)p \rfloor$ or $\lceil (n+1)p-1 \rceil$ depending on whether $(n+1)p$ is an integer.
423
+ ///
424
+ """
425
+ )
426
+ return
427
+
428
+
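As a numerical spot check of the last quiz item, the sketch below compares sample moments of a sum of two independent binomials against the $\text{Bin}(n_1+n_2, p)$ formulas (arbitrary parameters):

```python
import numpy as np

rng = np.random.default_rng(3)
n1, n2, p = 8, 5, 0.4                          # arbitrary example parameters
s = rng.binomial(n1, p, 500_000) + rng.binomial(n2, p, 500_000)
print(s.mean(), (n1 + n2) * p)                 # both ~5.2
print(s.var(), (n1 + n2) * p * (1 - p))        # both ~3.12
```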
429
+ @app.cell(hide_code=True)
430
+ def _(mo):
431
+ mo.md(
432
+ r"""
433
+ ## Summary
434
+
435
+ So we've explored the binomial distribution, and honestly, it's one of the most practical probability distributions you'll encounter. Think about it — anytime you're counting successes in a fixed number of trials (like those coin flips we discussed), this is your go-to distribution.
436
+
437
+ I find it fascinating how the expectation is simply $np$. Such a clean, intuitive formula! And remember that neat visualization we saw earlier? When we adjusted the parameters, you could actually see how the distribution shape changes—becoming more symmetric as $n$ increases.
438
+
439
+ The key things to take away:
440
+
441
+ - The binomial distribution models the number of successes in $n$ independent trials, each with probability $p$ of success
442
+
443
+ - Its PMF is given by the formula $P(X=k) = {n \choose k}p^k(1-p)^{n-k}$, which lets us calculate exactly how likely any specific number of successes is
444
+
445
+ - The expected value is $E[X] = np$ and the variance is $Var(X) = np(1-p)$
446
+
447
+ - It's related to other distributions: it's essentially a sum of Bernoulli random variables, and connects to both the negative binomial and Poisson distributions
448
+
449
+ - In Python, the `scipy.stats.binom` module makes working with binomial distributions straightforward—you can generate random samples and calculate probabilities with just a few lines of code
450
+
451
+ You'll see the binomial distribution pop up everywhere—from computer science to quality control, epidemiology, and data science. Any time you have scenarios with binary outcomes over multiple trials, this distribution has you covered.
452
+ """
453
+ )
454
+ return
455
+
456
+
457
+ @app.cell(hide_code=True)
458
+ def _(mo):
459
+ mo.md(r"""Appendix code (helper functions, variables, etc.):""")
460
+ return
461
+
462
+
463
+ @app.cell
464
+ def _():
465
+ import marimo as mo
466
+ return (mo,)
467
+
468
+
469
+ @app.cell(hide_code=True)
470
+ def _():
471
+ import numpy as np
472
+ import matplotlib.pyplot as plt
473
+ import scipy.stats as stats
474
+ import pandas as pd
475
+ import altair as alt
476
+ from wigglystuff import TangleSlider
477
+ return TangleSlider, alt, np, pd, plt, stats
478
+
479
+
480
+ @app.cell(hide_code=True)
481
+ def _(mo):
482
+ alpha_slider = mo.ui.slider(
483
+ value=0.1,
484
+ steps=[0, 0.01, 0.02, 0.03, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 1],
485
+ label="α (overdispersion)",
486
+ show_value=True,
487
+ )
488
+ mu_slider = mo.ui.slider(
489
+ value=100, start=1, stop=100, step=1, label="μ (mean)", show_value=True
490
+ )
491
+ return alpha_slider, mu_slider
492
+
493
+
494
+ @app.cell(hide_code=True)
495
+ def _():
496
+ equation = """
497
+ $$
498
+ P(X = k) = \\frac{\\Gamma(k + \\frac{1}{\\alpha})}{\\Gamma(k + 1) \\Gamma(\\frac{1}{\\alpha})} \\left( \\frac{1}{\\mu \\alpha + 1} \\right)^{\\frac{1}{\\alpha}} \\left( \\frac{\\mu \\alpha}{\\mu \\alpha + 1} \\right)^k
499
+ $$
500
+
501
+ $$
502
+ \\sigma^2 = \\mu + \\alpha \\mu^2
503
+ $$
504
+ """
505
+ return (equation,)
506
+
507
+
508
+ @app.cell(hide_code=True)
509
+ def _(alpha_slider, alt, mu_slider, np, pd, stats):
510
+ mu = mu_slider.value
511
+ alpha = alpha_slider.value
512
+ n = 1000 - mu if alpha == 0 else 1 / alpha
513
+ p = n / (mu + n)
514
+ x = np.arange(0, mu * 3 + 1, 1)
515
+ df = pd.DataFrame(
516
+ {
517
+ "x": x,
518
+ "y": stats.nbinom.pmf(x, n, p),
519
+ "y_poi": stats.nbinom.pmf(x, 1000 - mu, 1 - mu / 1000),
520
+ }
521
+ )
522
+ r1k = stats.nbinom.rvs(n, p, size=1000)
523
+ df["in 95% CI"] = df["x"].between(*np.percentile(r1k, q=[2.5, 97.5]))
524
+ base = alt.Chart(df)
525
+
526
+ chart_poi = base.mark_bar(
527
+ fillOpacity=0.25, width=100 / mu, fill="magenta"
528
+ ).encode(
529
+ x=alt.X("x").scale(domain=(-0.4, x.max() + 0.4), nice=False),
530
+ y=alt.Y("y_poi").scale(domain=(0, df.y_poi.max() * 1.1)).title(None),
531
+ )
532
+ chart_nb = base.mark_bar(fillOpacity=0.75, width=100 / mu).encode(
533
+ x="x",
534
+ y="y",
535
+ fill=alt.Fill("in 95% CI")
536
+ .scale(domain=[False, True], range=["#aaa", "#7c7"])
537
+ .legend(orient="bottom-right"),
538
+ )
539
+
540
+ chart = (chart_poi + chart_nb).configure_view(continuousWidth=450)
541
+ return alpha, base, chart, chart_nb, chart_poi, df, mu, n, p, r1k, x
542
+
543
+
544
+ if __name__ == "__main__":
545
+ app.run()
probability/15_poisson_distribution.py ADDED
@@ -0,0 +1,805 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "matplotlib==3.10.0",
6
+ # "numpy==2.2.4",
7
+ # "scipy==1.15.2",
8
+ # "altair==5.2.0",
9
+ # "wigglystuff==0.1.10",
10
+ # "pandas==2.2.3",
11
+ # ]
12
+ # ///
13
+
14
+ import marimo
15
+
16
+ __generated_with = "0.11.25"
17
+ app = marimo.App(width="medium", app_title="Poisson Distribution")
18
+
19
+
20
+ @app.cell(hide_code=True)
21
+ def _(mo):
22
+ mo.md(
23
+ r"""
24
+ # Poisson Distribution
25
+
26
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/poisson/), by Stanford professor Chris Piech._
27
+
28
+ A Poisson random variable gives the probability of a given number of events in a fixed interval of time (or space). It makes the Poisson assumption that events occur with a known constant mean rate and independently of the time since the last event.
29
+ """
30
+ )
31
+ return
32
+
33
+
34
+ @app.cell(hide_code=True)
35
+ def _(mo):
36
+ mo.md(
37
+ r"""
38
+ ## Poisson Random Variable Definition
39
+
40
+ $X \sim \text{Poisson}(\lambda)$ represents a Poisson random variable where:
41
+
42
+ - $X$ is our random variable (number of events)
43
+ - $\text{Poisson}$ indicates it follows a Poisson distribution
44
+ - $\lambda$ is the rate parameter (average number of events per time interval)
45
+
46
+ ```
47
+ X ~ Poisson(λ)
48
+ ↑ ↑ ↑
49
+ | | +-- Rate parameter:
50
+ | | average number of
51
+ | | events per interval
52
+ | +-- Indicates Poisson
53
+ | distribution
54
+ |
55
+ Our random variable
56
+ counting number of events
57
+ ```
58
+
59
+ The Poisson distribution is particularly useful when:
60
+
61
+ 1. Events occur independently of each other
62
+ 2. The average rate of occurrence is constant
63
+ 3. Two events cannot occur at exactly the same instant
64
+ 4. The probability of an event is proportional to the length of the time interval
65
+ """
66
+ )
67
+ return
68
+
69
+
70
+ @app.cell(hide_code=True)
71
+ def _(mo):
72
+ mo.md(
73
+ r"""
74
+ ## Properties of Poisson Distribution
75
+
76
+ | Property | Formula |
77
+ |----------|---------|
78
+ | Notation | $X \sim \text{Poisson}(\lambda)$ |
79
+ | Description | Number of events in a fixed time frame if (a) events occur with a constant mean rate and (b) they occur independently of time since last event |
80
+ | Parameters | $\lambda \in \mathbb{R}^{+}$, the constant average rate |
81
+ | Support | $x \in \{0, 1, \dots\}$ |
82
+ | PMF equation | $P(X=x) = \frac{\lambda^x e^{-\lambda}}{x!}$ |
83
+ | Expectation | $E[X] = \lambda$ |
84
+ | Variance | $\text{Var}(X) = \lambda$ |
85
+
86
+ Note that unlike many other distributions, the Poisson distribution's mean and variance are equal, both being $\lambda$.
87
+
88
+ Let's explore how the Poisson distribution changes with different rate parameters.
89
+ """
90
+ )
91
+ return
92
+
93
+
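The unusual property that the mean and variance are both $\lambda$ can be checked in one line with scipy; a minimal sketch with an arbitrary example rate:

```python
from scipy import stats

lam = 5.0                                               # arbitrary example rate
print(stats.poisson.mean(lam), stats.poisson.var(lam))  # 5.0 and 5.0
print(stats.poisson.pmf(3, lam))                        # P(X = 3) ~ 0.1404
```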
94
+ @app.cell(hide_code=True)
95
+ def _(TangleSlider, mo):
96
+ # interactive elements using TangleSlider
97
+ lambda_slider = mo.ui.anywidget(TangleSlider(
98
+ amount=5,
99
+ min_value=0.1,
100
+ max_value=20,
101
+ step=0.1,
102
+ digits=1,
103
+ suffix=" events"
104
+ ))
105
+
106
+ # interactive controls
107
+ _controls = mo.vstack([
108
+ mo.md("### Adjust the Rate Parameter to See How Poisson Distribution Changes"),
109
+ mo.hstack([
110
+ mo.md("**Rate parameter (λ):** "),
111
+ lambda_slider,
112
+ mo.md("**events per interval.** Higher values shift the distribution rightward and make it more spread out.")
113
+ ], justify="start"),
114
+ ])
115
+ _controls
116
+ return (lambda_slider,)
117
+
118
+
119
+ @app.cell(hide_code=True)
120
+ def _(lambda_slider, np, plt, stats):
121
+ def create_poisson_pmf_plot(lambda_value):
122
+ """Create a visualization of Poisson PMF with annotations for mean and variance."""
123
+ # PMF for values
124
+ max_x = max(20, int(lambda_value * 3)) # Show at least up to 3*lambda
125
+ x = np.arange(0, max_x + 1)
126
+ pmf = stats.poisson.pmf(x, lambda_value)
127
+
128
+ # Relevant key statistics
129
+ mean = lambda_value # For Poisson, mean = lambda
130
+ variance = lambda_value # For Poisson, variance = lambda
131
+ std_dev = np.sqrt(variance)
132
+
133
+ # plot
134
+ fig, ax = plt.subplots(figsize=(10, 6))
135
+
136
+ # PMF as bars
137
+ ax.bar(x, pmf, color='royalblue', alpha=0.7, label=f'PMF: P(X=k)')
138
+
139
+ # for the PMF values
140
+ ax.plot(x, pmf, 'ro-', alpha=0.6, label='PMF line')
141
+
142
+ # Vertical lines - mean and key values
143
+ ax.axvline(x=mean, color='green', linestyle='--', linewidth=2,
144
+ label=f'Mean: {mean:.2f}')
145
+
146
+ # Stdev region
147
+ ax.axvspan(mean - std_dev, mean + std_dev, alpha=0.2, color='green',
148
+ label=f'±1 Std Dev: {std_dev:.2f}')
149
+
150
+ ax.set_xlabel('Number of Events (k)')
151
+ ax.set_ylabel('Probability: P(X=k)')
152
+ ax.set_title(f'Poisson Distribution with λ={lambda_value:.1f}')
153
+
154
+ # annotations
155
+ ax.annotate(f'E[X] = {mean:.2f}',
156
+ xy=(mean, stats.poisson.pmf(int(mean), lambda_value)),
157
+ xytext=(mean + 1, max(pmf) * 0.8),
158
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
159
+
160
+ ax.annotate(f'Var(X) = {variance:.2f}',
161
+ xy=(mean, stats.poisson.pmf(int(mean), lambda_value) / 2),
162
+ xytext=(mean + 1, max(pmf) * 0.6),
163
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
164
+
165
+ ax.grid(alpha=0.3)
166
+ ax.legend()
167
+
168
+ plt.tight_layout()
169
+ return plt.gca()
170
+
171
+ # Get parameter from slider and create plot
172
+ _lambda = lambda_slider.amount
173
+ create_poisson_pmf_plot(_lambda)
174
+ return (create_poisson_pmf_plot,)
175
+
176
+
177
+ @app.cell(hide_code=True)
178
+ def _(mo):
179
+ mo.md(
180
+ r"""
181
+ ## Poisson Intuition: Relation to Binomial Distribution
182
+
183
+ The Poisson distribution can be derived as a limiting case of the [binomial distribution](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/14_binomial_distribution.py).
184
+
185
+ Let's work on a practical example: predicting the number of ride-sharing requests in a specific area over a one-minute interval. From historical data, we know that the average number of requests per minute is $\lambda = 5$.
186
+
187
+ We could approximate this using a binomial distribution by dividing our minute into smaller intervals. For example, we can divide a minute into 60 seconds and treat each second as a [Bernoulli trial](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/13_bernoulli_distribution.py) - either there's a request (success) or there isn't (failure).
188
+
189
+ Let's visualize this concept:
190
+ """
191
+ )
192
+ return
193
+
194
+
195
+ @app.cell(hide_code=True)
196
+ def _(fig_to_image, mo, plt):
197
+ def create_time_division_visualization():
198
+ # visualization of dividing a minute into 60 seconds
199
+ fig, ax = plt.subplots(figsize=(12, 2))
200
+
201
+ # Example events hardcoded at 2.75s and 7.12s
202
+ events = [2.75, 7.12]
203
+
204
+ # array of 60 rectangles
205
+ for i in range(60):
206
+ color = 'royalblue' if any(i <= e < i+1 for e in events) else 'lightgray'
207
+ ax.add_patch(plt.Rectangle((i, 0), 0.9, 1, color=color))
208
+
209
+ # markers for events
210
+ for e in events:
211
+ ax.plot(e, 0.5, 'ro', markersize=10)
212
+
213
+ # labels
214
+ ax.set_xlim(0, 60)
215
+ ax.set_ylim(0, 1)
216
+ ax.set_yticks([])
217
+ ax.set_xticks([0, 15, 30, 45, 60])
218
+ ax.set_xticklabels(['0s', '15s', '30s', '45s', '60s'])
219
+ ax.set_xlabel('Time (seconds)')
220
+ ax.set_title('One Minute Divided into 60 Second Intervals')
221
+
222
+ plt.tight_layout()
223
+ plt.gca()
224
+ return fig, events, i
225
+
226
+ # Create visualization and convert to image
227
+ _fig, _events, i = create_time_division_visualization()
228
+ _img = mo.image(fig_to_image(_fig), width="100%")
229
+
230
+ # explanation
231
+ _explanation = mo.md(
232
+ r"""
233
+ In this visualization:
234
+
235
+ - Each rectangle represents a 1-second interval
236
+ - Blue rectangles indicate intervals where an event occurred
237
+ - Red dots show the actual event times (2.75s and 7.12s)
238
+
239
+ If we treat this as a binomial experiment with 60 trials (seconds), we can calculate probabilities using the binomial PMF. But there's a problem: what if multiple events occur within the same second? To address this, we can divide our minute into smaller intervals.
240
+ """
241
+ )
242
+ mo.vstack([_fig, _explanation])
243
+ return create_time_division_visualization, i
244
+
245
+
246
+ @app.cell(hide_code=True)
247
+ def _(mo):
248
+ mo.md(
249
+ r"""
250
+ The total number of requests received over the minute can be approximated as the sum of the sixty indicator variables, which conveniently matches the description of a binomial — a sum of Bernoullis.
251
+
252
+ Specifically, if we define $X$ to be the number of requests in a minute, $X$ is a binomial with $n=60$ trials. What is the probability, $p$, of a success on a single trial? To make the expectation of $X$ equal the observed historical average $\lambda$, we should choose $p$ so that:
253
+
254
+ \begin{align}
255
+ \lambda &= E[X] && \text{Expectation matches historical average} \\
256
+ \lambda &= n \cdot p && \text{Expectation of a Binomial is } n \cdot p \\
257
+ p &= \frac{\lambda}{n} && \text{Solving for $p$}
258
+ \end{align}
259
+
260
+ In this case, since $\lambda=5$ and $n=60$, we should choose $p=\frac{5}{60}=\frac{1}{12}$ and state that $X \sim \text{Bin}(n=60, p=\frac{5}{60})$. Now we can calculate the probability of different numbers of requests using the binomial PMF:
261
+
262
+ $P(X = x) = {n \choose x} p^x (1-p)^{n-x}$
263
+
264
+ For example:
265
+
266
+ \begin{align}
267
+ P(X=1) &= {60 \choose 1} (5/60)^1 (55/60)^{60-1} \approx 0.0295 \\
268
+ P(X=2) &= {60 \choose 2} (5/60)^2 (55/60)^{60-2} \approx 0.0790 \\
269
+ P(X=3) &= {60 \choose 3} (5/60)^3 (55/60)^{60-3} \approx 0.1389
270
+ \end{align}
271
+
272
+ This is a good approximation, but it doesn't account for the possibility of multiple events in a single second. One solution is to divide our minute into even more fine-grained intervals. Let's try 600 deciseconds (tenths of a second):
273
+ """
274
+ )
275
+ return
276
+
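The three approximate probabilities quoted above follow directly from the binomial PMF with $n=60$ and $p=5/60$; a quick sketch to reproduce them:

```python
from scipy import stats

n, p = 60, 5 / 60
for k in (1, 2, 3):
    print(k, stats.binom.pmf(k, n, p))
# ~0.0295, ~0.0790, ~0.1389, matching the values above
```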
277
+
278
+ @app.cell(hide_code=True)
279
+ def _(fig_to_image, mo, plt):
280
+ def create_decisecond_visualization(e_value):
281
+ # (Just showing the first 100 for clarity)
282
+ fig, ax = plt.subplots(figsize=(12, 2))
283
+
284
+ # Example events at 2.75s and 7.12s (convert to deciseconds)
285
+ events = [27.5, 71.2]
286
+
287
+ for i in range(100):
288
+ color = 'royalblue' if any(i <= event_val < i + 1 for event_val in events) else 'lightgray'
289
+ ax.add_patch(plt.Rectangle((i, 0), 0.9, 1, color=color))
290
+
291
+ # Markers for events
292
+ for event in events:
293
+ if event < 100: # Only show events in our visible range
294
+ ax.plot(event/10, 0.5, 'ro', markersize=10) # Divide by 10 to convert to deciseconds
295
+
296
+ # Add labels
297
+ ax.set_xlim(0, 100)
298
+ ax.set_ylim(0, 1)
299
+ ax.set_yticks([])
300
+ ax.set_xticks([0, 20, 40, 60, 80, 100])
301
+ ax.set_xticklabels(['0s', '2s', '4s', '6s', '8s', '10s'])
302
+ ax.set_xlabel('Time (first 10 seconds shown)')
303
+ ax.set_title('One Minute Divided into 600 Decisecond Intervals (first 100 shown)')
304
+
305
+ plt.tight_layout()
306
+ plt.gca()
307
+ return fig
308
+
309
+ # Create viz and convert to image
310
+ _fig = create_decisecond_visualization(e_value=5)
311
+ _img = mo.image(fig_to_image(_fig), width="100%")
312
+
313
+ # Explanation
314
+ _explanation = mo.md(
315
+ r"""
316
+ With $n=600$ and $p=\frac{5}{600}=\frac{1}{120}$, we can recalculate our probabilities:
317
+
318
+ \begin{align}
319
+ P(X=1) &= {600 \choose 1} (5/600)^1 (595/600)^{600-1} \approx 0.0333 \\
320
+ P(X=2) &= {600 \choose 2} (5/600)^2 (595/600)^{600-2} \approx 0.0837 \\
321
+ P(X=3) &= {600 \choose 3} (5/600)^3 (595/600)^{600-3} \approx 0.1402
322
+ \end{align}
323
+
324
+ As we make our intervals smaller (increasing $n$), our approximation becomes more accurate.
325
+ """
326
+ )
327
+ mo.vstack([_fig, _explanation])
328
+ return (create_decisecond_visualization,)
329
+
330
+
331
+ @app.cell(hide_code=True)
332
+ def _(mo):
333
+ mo.md(
334
+ r"""
335
+ ## The Binomial Distribution in the Limit
336
+
337
+ What happens if we continue dividing our time interval into smaller and smaller pieces? Let's explore how the probabilities change as we increase the number of intervals:
338
+ """
339
+ )
340
+ return
341
+
342
+
343
+ @app.cell(hide_code=True)
344
+ def _(mo):
345
+ intervals_slider = mo.ui.slider(
346
+ start = 60,
347
+ stop = 10000,
348
+ step=100,
349
+ value=600,
350
+ label="Number of intervals to divide a minute")
351
+ return (intervals_slider,)
352
+
353
+
354
+ @app.cell(hide_code=True)
355
+ def _(intervals_slider):
356
+ intervals_slider
357
+ return
358
+
359
+
360
+ @app.cell(hide_code=True)
361
+ def _(intervals_slider, np, pd, plt, stats):
362
+ def create_comparison_plot(n, lambda_value):
363
+ # Calculate probability
364
+ p = lambda_value / n
365
+
366
+ # Binomial probabilities
367
+ x_values = np.arange(0, 15)
368
+ binom_pmf = stats.binom.pmf(x_values, n, p)
369
+
370
+ # True Poisson probabilities
371
+ poisson_pmf = stats.poisson.pmf(x_values, lambda_value)
372
+
373
+ # DF for comparison
374
+ df = pd.DataFrame({
375
+ 'Events': x_values,
376
+ f'Binomial(n={n}, p={p:.6f})': binom_pmf,
377
+ f'Poisson(λ=5)': poisson_pmf,
378
+ 'Difference': np.abs(binom_pmf - poisson_pmf)
379
+ })
380
+
381
+ # Plot both PMFs
382
+ fig, ax = plt.subplots(figsize=(10, 6))
383
+
384
+ # Bar plot for the binomial
385
+ ax.bar(x_values - 0.2, binom_pmf, width=0.4, alpha=0.7,
386
+ color='royalblue', label=f'Binomial(n={n}, p={p:.6f})')
387
+
388
+ # Bar plot for the Poisson
389
+ ax.bar(x_values + 0.2, poisson_pmf, width=0.4, alpha=0.7,
390
+ color='crimson', label='Poisson(λ=5)')
391
+
392
+ # Labels and title
393
+ ax.set_xlabel('Number of Events (k)')
394
+ ax.set_ylabel('Probability')
395
+ ax.set_title(f'Comparison of Binomial and Poisson PMFs with n={n}')
396
+ ax.legend()
397
+ ax.set_xticks(x_values)
398
+ ax.grid(alpha=0.3)
399
+
400
+ plt.tight_layout()
401
+ return df, fig, n, p
402
+
403
+ # Number of intervals from the slider
404
+ n = intervals_slider.value
405
+ _lambda = 5 # Fixed lambda for our example
406
+
407
+ # Comparison plot
408
+ df, fig, n, p = create_comparison_plot(n, _lambda)
409
+ return create_comparison_plot, df, fig, n, p
410
+
411
+
412
+ @app.cell(hide_code=True)
413
+ def _(df, fig, fig_to_image, mo, n, p):
414
+ # table of values
415
+ _styled_df = df.style.format({
416
+ f'Binomial(n={n}, p={p:.6f})': '{:.6f}',
417
+ f'Poisson(λ=5)': '{:.6f}',
418
+ 'Difference': '{:.6f}'
419
+ })
420
+
421
+ # Calculate the max absolute difference
422
+ _max_diff = df['Difference'].max()
423
+
424
+ # output
425
+ _chart = mo.image(fig_to_image(fig), width="100%")
426
+ _explanation = mo.md(f"**Maximum absolute difference between distributions: {_max_diff:.6f}**")
427
+ _table = mo.ui.table(df)
428
+
429
+ mo.vstack([_chart, _explanation, _table])
430
+ return
431
+
432
+
433
+ @app.cell(hide_code=True)
434
+ def _(mo):
435
+ mo.md(
436
+ r"""
437
+ As you can see from the interactive comparison above, as the number of intervals increases, the binomial distribution approaches the Poisson distribution! This is not a coincidence - the Poisson distribution is actually the limiting case of the binomial distribution when:
438
+
439
+ - The number of trials $n$ approaches infinity
440
+ - The probability of success $p$ approaches zero
441
+ - The product $np = \lambda$ remains constant
442
+
443
+ This relationship is why the Poisson distribution is so useful - it's easier to work with than a binomial with a very large number of trials and a very small probability of success.
444
+
445
+ ## Derivation of the Poisson PMF
446
+
447
+ Let's derive the Poisson PMF by taking the limit of the binomial PMF as $n \to \infty$. We start with:
448
+
449
+ $P(X=x) = \lim_{n \rightarrow \infty} {n \choose x} (\lambda / n)^x(1-\lambda/n)^{n-x}$
450
+
451
+ While this expression looks intimidating, it simplifies nicely:
452
+
453
+ \begin{align}
454
+ P(X=x)
455
+ &= \lim_{n \rightarrow \infty} {n \choose x} (\lambda / n)^x(1-\lambda/n)^{n-x}
456
+ && \text{Start: binomial in the limit}\\
457
+ &= \lim_{n \rightarrow \infty}
458
+ {n \choose x} \cdot
459
+ \frac{\lambda^x}{n^x} \cdot
460
+ \frac{(1-\lambda/n)^{n}}{(1-\lambda/n)^{x}}
461
+ && \text{Expanding the power terms} \\
462
+ &= \lim_{n \rightarrow \infty}
463
+ \frac{n!}{(n-x)!x!} \cdot
464
+ \frac{\lambda^x}{n^x} \cdot
465
+ \frac{(1-\lambda/n)^{n}}{(1-\lambda/n)^{x}}
466
+ && \text{Expanding the binomial term} \\
467
+ &= \lim_{n \rightarrow \infty}
468
+ \frac{n!}{(n-x)!x!} \cdot
469
+ \frac{\lambda^x}{n^x} \cdot
470
+ \frac{e^{-\lambda}}{(1-\lambda/n)^{x}}
471
+ && \text{Using limit rule } \lim_{n \rightarrow \infty}(1-\lambda/n)^{n} = e^{-\lambda}\\
472
+ &= \lim_{n \rightarrow \infty}
473
+ \frac{n!}{(n-x)!x!} \cdot
474
+ \frac{\lambda^x}{n^x} \cdot
475
+ \frac{e^{-\lambda}}{1}
476
+ && \text{As } n \to \infty \text{, } \lambda/n \to 0\\
477
+ &= \lim_{n \rightarrow \infty}
478
+ \frac{n!}{(n-x)!} \cdot
479
+ \frac{1}{x!} \cdot
480
+ \frac{\lambda^x}{n^x} \cdot
481
+ e^{-\lambda}
482
+ && \text{Rearranging terms}\\
483
+ &= \lim_{n \rightarrow \infty}
484
+ \frac{n^x}{1} \cdot
485
+ \frac{1}{x!} \cdot
486
+ \frac{\lambda^x}{n^x} \cdot
487
+ e^{-\lambda}
488
+ && \text{As } n \to \infty \text{, } \frac{n!}{(n-x)!} \approx n^x\\
489
+ &= \lim_{n \rightarrow \infty}
490
+ \frac{\lambda^x}{x!} \cdot
491
+ e^{-\lambda}
492
+ && \text{Canceling } n^x\\
493
+ &=
494
+ \frac{\lambda^x \cdot e^{-\lambda}}{x!}
495
+ && \text{Simplifying}\\
496
+ \end{align}
497
+
498
+ This gives us our elegant Poisson PMF formula: $P(X=x) = \frac{\lambda^x \cdot e^{-\lambda}}{x!}$
499
+ """
500
+ )
501
+ return
502
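A quick numerical cross-check of the derived formula (an editorial sketch, not part of the committed notebook; it only uses `scipy`, which the notebook already declares as a dependency):

```python
# Check P(X=x) = λ^x e^{-λ} / x! against scipy's Poisson PMF for λ = 5.
import math

import scipy.stats as stats

lam = 5
for x in range(4):
    by_formula = lam**x * math.exp(-lam) / math.factorial(x)
    by_scipy = stats.poisson.pmf(x, lam)
    print(f"P(X={x}): formula={by_formula:.5f}, scipy={by_scipy:.5f}")
```

The two columns agree, and match the P(X=1), P(X=2), P(X=3) values printed by the scipy example in the next section.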
+
503
+
504
+ @app.cell(hide_code=True)
505
+ def _(mo):
506
+ mo.md(
507
+ r"""
508
+ ## Poisson Distribution in Python
509
+
510
+ Python's `scipy.stats` module provides functions to work with the Poisson distribution. Let's see how to calculate probabilities and generate random samples.
511
+
512
+ First, let's calculate some probabilities for our ride-sharing example with $\lambda = 5$:
513
+ """
514
+ )
515
+ return
516
+
517
+
518
+ @app.cell
519
+ def _(stats):
520
+ _lambda = 5
521
+
522
+ # Calculate probabilities for X = 1, 2, 3
523
+ p_1 = stats.poisson.pmf(1, _lambda)
524
+ p_2 = stats.poisson.pmf(2, _lambda)
525
+ p_3 = stats.poisson.pmf(3, _lambda)
526
+
527
+ print(f"P(X=1) = {p_1:.5f}")
528
+ print(f"P(X=2) = {p_2:.5f}")
529
+ print(f"P(X=3) = {p_3:.5f}")
530
+
531
+ # Calculate cumulative probability P(X ≤ 3)
532
+ p_leq_3 = stats.poisson.cdf(3, _lambda)
533
+ print(f"P(X≤3) = {p_leq_3:.5f}")
534
+
535
+ # Calculate probability P(X > 10)
536
+ p_gt_10 = 1 - stats.poisson.cdf(10, _lambda)
537
+ print(f"P(X>10) = {p_gt_10:.5f}")
538
+ return p_1, p_2, p_3, p_gt_10, p_leq_3
539
+
540
+
541
+ @app.cell(hide_code=True)
542
+ def _(mo):
543
+ mo.md(r"""We can also generate random samples from a Poisson distribution and visualize their distribution:""")
544
+ return
545
+
546
+
547
+ @app.cell(hide_code=True)
548
+ def _(np, plt, stats):
549
+ def create_samples_plot(lambda_value, sample_size=1000):
550
+ # Random samples
551
+ samples = stats.poisson.rvs(lambda_value, size=sample_size)
552
+
553
+ # theoretical PMF
554
+ x_values = np.arange(0, max(samples) + 1)
555
+ pmf_values = stats.poisson.pmf(x_values, lambda_value)
556
+
557
+ # histograms to compare
558
+ fig, ax = plt.subplots(figsize=(10, 6))
559
+
560
+ # samples as a histogram
561
+ ax.hist(samples, bins=np.arange(-0.5, max(samples) + 1.5, 1),
562
+ alpha=0.7, density=True, label='Random Samples')
563
+
564
+ # theoretical PMF
565
+ ax.plot(x_values, pmf_values, 'ro-', label='Theoretical PMF')
566
+
567
+ # labels and title
568
+ ax.set_xlabel('Number of Events')
569
+ ax.set_ylabel('Relative Frequency / Probability')
570
+ ax.set_title(f'{sample_size} Random Samples from Poisson(λ={lambda_value})')
571
+ ax.legend()
572
+ ax.grid(alpha=0.3)
573
+
574
+ # annotations
575
+ ax.annotate(f'Sample Mean: {np.mean(samples):.2f}',
576
+ xy=(0.7, 0.9), xycoords='axes fraction',
577
+ bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3))
578
+ ax.annotate(f'Theoretical Mean: {lambda_value:.2f}',
579
+ xy=(0.7, 0.8), xycoords='axes fraction',
580
+ bbox=dict(boxstyle='round,pad=0.5', fc='lightgreen', alpha=0.3))
581
+
582
+ plt.tight_layout()
583
+ return plt.gca()
584
+
585
+ # Use a lambda value of 5 for this example
586
+ _lambda = 5
587
+ create_samples_plot(_lambda)
588
+ return (create_samples_plot,)
589
+
590
+
591
+ @app.cell(hide_code=True)
592
+ def _(mo):
593
+ mo.md(
594
+ r"""
595
+ ## Changing Time Frames
596
+
597
+ One important property of the Poisson distribution is that the rate parameter $\lambda$ scales linearly with the time interval. If events occur at a rate of $\lambda$ per unit time, then over a period of $t$ units, the rate parameter becomes $\lambda \cdot t$.
598
+
599
+ For example, if a website receives an average of 5 requests per minute, what is the distribution of requests over a 20-minute period?
600
+
601
+ The rate parameter for the 20-minute period would be $\lambda = 5 \cdot 20 = 100$ requests.
602
+ """
603
+ )
604
+ return
605
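A minimal sketch of this scaling (an editorial addition, assuming only the same `scipy.stats` import used elsewhere in the notebook): the website example becomes Poisson with λ = 5 · 20 = 100 over the 20-minute window.

```python
# Requests arrive at 5 per minute; over 20 minutes the count is Poisson(λ = 5 * 20).
import scipy.stats as stats

rate_per_minute = 5
minutes = 20
lam = rate_per_minute * minutes  # 100

print(f"Expected requests in {minutes} minutes: {lam}")
print(f"P(X <= 90)  = {stats.poisson.cdf(90, lam):.4f}")
print(f"P(X >= 110) = {1 - stats.poisson.cdf(109, lam):.4f}")
```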
+
606
+
607
+ @app.cell(hide_code=True)
608
+ def _(mo):
609
+ rate_slider = mo.ui.slider(
610
+ start = 0.1,
611
+ stop = 10,
612
+ step=0.1,
613
+ value=5,
614
+ label="Rate per unit time (λ)"
615
+ )
616
+
617
+ time_slider = mo.ui.slider(
618
+ start = 1,
619
+ stop = 60,
620
+ step=1,
621
+ value=20,
622
+ label="Time period (t units)"
623
+ )
624
+
625
+ controls = mo.vstack([
626
+ mo.md("### Adjust Parameters to See How Time Scaling Works"),
627
+ mo.hstack([rate_slider, time_slider], justify="space-between")
628
+ ])
629
+ return controls, rate_slider, time_slider
630
+
631
+
632
+ @app.cell
633
+ def _(controls):
634
+ controls.center()
635
+ return
636
+
637
+
638
+ @app.cell(hide_code=True)
639
+ def _(mo, np, plt, rate_slider, stats, time_slider):
640
+ def create_time_scaling_plot(rate, time_period):
641
+ # scaled rate parameter
642
+ lambda_value = rate * time_period
643
+
644
+ # PMF for values
645
+ max_x = max(30, int(lambda_value * 1.5))
646
+ x = np.arange(0, max_x + 1)
647
+ pmf = stats.poisson.pmf(x, lambda_value)
648
+
649
+ # plot
650
+ fig, ax = plt.subplots(figsize=(10, 6))
651
+
652
+ # PMF as bars
653
+ ax.bar(x, pmf, color='royalblue', alpha=0.7,
654
+ label=f'PMF: Poisson(λ={lambda_value:.1f})')
655
+
656
+ # vertical line for mean
657
+ ax.axvline(x=lambda_value, color='red', linestyle='--', linewidth=2,
658
+ label=f'Mean = {lambda_value:.1f}')
659
+
660
+ # labels and title
661
+ ax.set_xlabel('Number of Events')
662
+ ax.set_ylabel('Probability')
663
+ ax.set_title(f'Poisson Distribution Over {time_period} Units (Rate = {rate}/unit)')
664
+
665
+ # better visualization if lambda is large
666
+ if lambda_value > 10:
667
+ ax.set_xlim(lambda_value - 4*np.sqrt(lambda_value),
668
+ lambda_value + 4*np.sqrt(lambda_value))
669
+
670
+ ax.legend()
671
+ ax.grid(alpha=0.3)
672
+
673
+ plt.tight_layout()
674
+
675
+ # Create relevant info markdown
676
+ info_text = f"""
677
+ When the rate is **{rate}** events per unit time and we observe for **{time_period}** units:
678
+
679
+ - The expected number of events is **{lambda_value:.1f}**
680
+ - The variance is also **{lambda_value:.1f}**
681
+ - The standard deviation is **{np.sqrt(lambda_value):.2f}**
682
+ - P(X=0) = {stats.poisson.pmf(0, lambda_value):.4f} (probability of no events)
683
+ - P(X≥10) = {1 - stats.poisson.cdf(9, lambda_value):.4f} (probability of 10 or more events)
684
+ """
685
+
686
+ return plt.gca(), info_text
687
+
688
+ # parameters from sliders
689
+ _rate = rate_slider.value
690
+ _time = time_slider.value
691
+
692
+ # store
693
+ _plot, _info_text = create_time_scaling_plot(_rate, _time)
694
+
695
+ # Display info as markdown
696
+ info = mo.md(_info_text)
697
+
698
+ mo.vstack([_plot, info], justify="center")
699
+ return create_time_scaling_plot, info
700
+
701
+
702
+ @app.cell(hide_code=True)
703
+ def _(mo):
704
+ mo.md(
705
+ r"""
706
+ ## 🤔 Test Your Understanding
707
+ Pick which of these statements about Poisson distributions you think are correct:
708
+
709
+ /// details | The variance of a Poisson distribution is always equal to its mean
710
+ ✅ Correct! For a Poisson distribution with parameter $\lambda$, both the mean and variance equal $\lambda$.
711
+ ///
712
+
713
+ /// details | The Poisson distribution can be used to model the number of successes in a fixed number of trials
714
+ ❌ Incorrect! That's the binomial distribution. The Poisson distribution models the number of events in a fixed interval of time or space, not a fixed number of trials.
715
+ ///
716
+
717
+ /// details | If $X \sim \text{Poisson}(\lambda_1)$ and $Y \sim \text{Poisson}(\lambda_2)$ are independent, then $X + Y \sim \text{Poisson}(\lambda_1 + \lambda_2)$
718
+ ✅ Correct! The sum of independent Poisson random variables is also a Poisson random variable with parameter equal to the sum of the individual parameters.
719
+ ///
720
+
721
+ /// details | As $\lambda$ increases, the Poisson distribution approaches a normal distribution
722
+ ✅ Correct! For large values of $\lambda$ (generally $\lambda > 10$), the Poisson distribution is approximately normal with mean $\lambda$ and variance $\lambda$.
723
+ ///
724
+
725
+ /// details | The probability of zero events in a Poisson process is always less than the probability of one event
726
+ ❌ Incorrect! For $\lambda < 1$, the probability of zero events ($e^{-\lambda}$) is actually greater than the probability of one event ($\lambda e^{-\lambda}$).
727
+ ///
728
+
729
+ /// details | The Poisson distribution has a single parameter $\lambda$, which always equals the average number of events per time period
730
+ ✅ Correct! The parameter $\lambda$ represents the average rate of events, and it uniquely defines the distribution.
731
+ ///
732
+ """
733
+ )
734
+ return
735
+
736
+
737
+ @app.cell(hide_code=True)
738
+ def _(mo):
739
+ mo.md(
740
+ r"""
741
+ ## Summary
742
+
743
+ The Poisson distribution is one of those incredibly useful tools that shows up all over the place. I've always found it fascinating how such a simple formula can model so many real-world phenomena - from website traffic to radioactive decay.
744
+
745
+ What makes the Poisson really cool is that it emerges naturally as we try to model rare events occurring over a continuous interval. Remember that visualization where we kept dividing time into smaller and smaller chunks? As we showed, when you take a binomial distribution and let the number of trials approach infinity while keeping the expected value constant, you end up with the elegant Poisson formula.
746
+
747
+ The key things to remember about the Poisson distribution:
748
+
749
+ - It models the number of events occurring in a fixed interval of time or space, assuming events happen at a constant average rate and independently of each other
750
+
751
+ - Its PMF is given by the elegantly simple formula $P(X=k) = \frac{\lambda^k e^{-\lambda}}{k!}$
752
+
753
+ - Both the mean and variance equal the parameter $\lambda$, which represents the average number of events per interval
754
+
755
+ - It's related to the binomial distribution as a limiting case when $n \to \infty$, $p \to 0$, and $np = \lambda$ remains constant
756
+
757
+ - The rate parameter scales linearly with the length of the interval - if events occur at rate $\lambda$ per unit time, then over $t$ units, the parameter becomes $\lambda t$
758
+
759
+ From modeling website traffic and customer arrivals to defects in manufacturing and radioactive decay, the Poisson distribution provides a powerful and mathematically elegant way to understand random occurrences in our world.
760
+ """
761
+ )
762
+ return
763
+
764
+
765
+ @app.cell(hide_code=True)
766
+ def _(mo):
767
+ mo.md(r"""Appendix code (helper functions, variables, etc.):""")
768
+ return
769
+
770
+
771
+ @app.cell
772
+ def _():
773
+ import marimo as mo
774
+ return (mo,)
775
+
776
+
777
+ @app.cell(hide_code=True)
778
+ def _():
779
+ import numpy as np
780
+ import matplotlib.pyplot as plt
781
+ import scipy.stats as stats
782
+ import pandas as pd
783
+ import altair as alt
784
+ from wigglystuff import TangleSlider
785
+ return TangleSlider, alt, np, pd, plt, stats
786
+
787
+
788
+ @app.cell(hide_code=True)
789
+ def _():
790
+ import io
791
+ import base64
792
+ from matplotlib.figure import Figure
793
+
794
+ # Helper function to convert mpl figure to an image format mo.image can handle
795
+ def fig_to_image(fig):
796
+ buf = io.BytesIO()
797
+ fig.savefig(buf, format='png')
798
+ buf.seek(0)
799
+ data = f"data:image/png;base64,{base64.b64encode(buf.read()).decode('utf-8')}"
800
+ return data
801
+ return Figure, base64, fig_to_image, io
802
+
803
+
804
+ if __name__ == "__main__":
805
+ app.run()
probability/16_continuous_distribution.py ADDED
@@ -0,0 +1,979 @@
1
+ # /// script
2
+ # requires-python = ">=3.11"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "altair==5.5.0",
6
+ # "matplotlib==3.10.1",
7
+ # "numpy==2.2.4",
8
+ # "scipy==1.15.2",
9
+ # "sympy==1.13.3",
10
+ # "wigglystuff==0.1.10",
11
+ # "polars==1.26.0",
12
+ # ]
13
+ # ///
14
+
15
+ import marimo
16
+
17
+ __generated_with = "0.11.26"
18
+ app = marimo.App(width="medium")
19
+
20
+
21
+ @app.cell(hide_code=True)
22
+ def _(mo):
23
+ mo.md(
24
+ r"""
25
+ # Continuous Distributions
26
+
27
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/continuous/), by Stanford professor Chris Piech._
28
+
29
+ So far, all the random variables we've explored have been discrete, taking on only specific values (usually integers). Now we'll move into the world of **continuous random variables**, which can take on any real number value. Continuous random variables are used to model measurements with arbitrary precision like height, weight, time, and many natural phenomena.
30
+ """
31
+ )
32
+ return
33
+
34
+
35
+ @app.cell(hide_code=True)
36
+ def _(mo):
37
+ mo.md(
38
+ r"""
39
+ ## From Discrete to Continuous
40
+
41
+ To make the transition from discrete to continuous random variables, let's start with a thought experiment:
42
+
43
+ > Imagine you're running to catch a bus. You know you'll arrive at 2:15pm, but you don't know exactly when the bus will arrive. You want to model the bus arrival time (in minutes past 2pm) as a random variable $T$ so you can calculate the probability that you'll wait more than five minutes: $P(15 < T < 20)$.
44
+
45
+ This immediately highlights a key difference from discrete distributions. For discrete distributions, we described the probability that a random variable takes on exact values. But this doesn't make sense for continuous values like time.
46
+
47
+ For example:
48
+
49
+ - What's the probability the bus arrives at exactly 2:17pm and 12.12333911102389234 seconds?
50
+ - What's the probability of a child being born weighing exactly 3.523112342234 kilograms?
51
+
52
+ These questions don't have meaningful answers because real-world measurements can have infinite precision. The probability of a continuous random variable taking on any specific exact value is actually zero!
53
+
54
+ ### Visualizing the Transition
55
+
56
+ Let's visualize this transition from discrete to continuous:
57
+ """
58
+ )
59
+ return
60
+
61
+
62
+ @app.cell(hide_code=True)
63
+ def _(fig_to_image, mo, np, plt):
64
+ def create_discretization_plot():
65
+ fig, axs = plt.subplots(1, 3, figsize=(15, 5))
66
+
67
+ # values from 0 to 30 minutes
68
+ x = np.linspace(0, 30, 1000)
69
+
70
+ # Triangular distribution (peaked at 15 minutes)
71
+ y = np.where(x <= 15, x/15, (30-x)/15)
72
+ y = y / np.trapezoid(y, x) # Normalize
73
+
74
+ # 5-minute chunks (first plot)
75
+ bins = np.arange(0, 31, 5)
76
+ hist, _ = np.histogram(x, bins=bins, weights=y)
77
+ width = bins[1] - bins[0]
78
+ axs[0].bar(bins[:-1], hist * width, width=width, alpha=0.7,
79
+ color='royalblue', edgecolor='black')
80
+ axs[0].set_xlim(0, 30)
81
+ axs[0].set_title('5-Minute Intervals')
82
+ axs[0].set_xlabel('Minutes past 2pm')
83
+ axs[0].set_ylabel('Probability')
84
+
85
+ # 15-20 minute range more prominent
86
+ axs[0].bar([15], hist[3] * width, width=width, alpha=0.7,
87
+ color='darkorange', edgecolor='black')
88
+
89
+ # 2.5-minute chunks (second plot)
90
+ bins = np.arange(0, 31, 2.5)
91
+ hist, _ = np.histogram(x, bins=bins, weights=y)
92
+ width = bins[1] - bins[0]
93
+ axs[1].bar(bins[:-1], hist * width, width=width, alpha=0.7,
94
+ color='royalblue', edgecolor='black')
95
+ axs[1].set_xlim(0, 30)
96
+ axs[1].set_title('2.5-Minute Intervals')
97
+ axs[1].set_xlabel('Minutes past 2pm')
98
+
99
+ # Make 15-20 minute range more prominent
100
+ highlight_indices = [6, 7]
101
+ for idx in highlight_indices:
102
+ axs[1].bar([bins[idx]], hist[idx] * width, width=width, alpha=0.7,
103
+ color='darkorange', edgecolor='black')
104
+
105
+ # Continuous distribution (third plot)
106
+ axs[2].plot(x, y, 'royalblue', linewidth=2)
107
+ axs[2].set_xlim(0, 30)
108
+ axs[2].set_title('Continuous Distribution')
109
+ axs[2].set_xlabel('Minutes past 2pm')
110
+ axs[2].set_ylabel('Probability Density')
111
+
112
+ # Highlight the AUC between 15 and 20
113
+ mask = (x >= 15) & (x <= 20)
114
+ axs[2].fill_between(x[mask], y[mask], color='darkorange', alpha=0.7)
115
+
116
+ # Mark 15-20 minute interval
117
+ for ax in axs:
118
+ ax.axvline(x=15, color='red', linestyle='--', alpha=0.5)
119
+ ax.axvline(x=20, color='red', linestyle='--', alpha=0.5)
120
+ ax.set_xticks([0, 5, 10, 15, 20, 25, 30])
121
+ ax.grid(alpha=0.3)
122
+
123
+ plt.tight_layout()
124
+ plt.gca()
125
+ return fig
126
+
127
+ # Plot creation & conversion
128
+ _fig = create_discretization_plot()
129
+ _img = mo.image(fig_to_image(_fig), width="100%")
130
+
131
+ _explanation = mo.md(
132
+ r"""
133
+ The figure above illustrates our transition from discrete to continuous thinking:
134
+
135
+ - **Left**: Time divided into 5-minute chunks, where the probability of the bus arriving between 15-20 minutes (highlighted in orange) is a single value.
136
+ - **Center**: Time divided into finer 2.5-minute chunks, where the 15-20 minute range consists of two chunks.
137
+ - **Right**: In the limit, we get a continuous probability density function where the probability is the area under the curve between 15 and 20 minutes.
138
+
139
+ As we make our chunks smaller and smaller, we eventually arrive at a smooth function that gives us the probability density at each point.
140
+ """
141
+ )
142
+
143
+ mo.vstack([_img, _explanation])
144
+ return (create_discretization_plot,)
145
+
146
+
147
+ @app.cell(hide_code=True)
148
+ def _(mo):
149
+ mo.md(
150
+ r"""
151
+ ## Probability Density Functions
152
+
153
+ In the world of discrete random variables, we used **Probability Mass Functions (PMFs)** to describe the probability of a random variable taking on specific values. In the continuous world, we need a different approach.
154
+
155
+ For continuous random variables, we use a **Probability Density Function (PDF)** which defines the relative likelihood that a random variable takes on a particular value. We traditionally denote the PDF with the symbol $f$ and write it as:
156
+
157
+ $$f(X=x) \quad \text{or simply} \quad f(x)$$
158
+
159
+ Where the lowercase $x$ implies that we're talking about the relative likelihood of a continuous random variable which is the uppercase $X$.
160
+
161
+ ### Key Properties of PDFs
162
+
163
+ A **Probability Density Function (PDF)** $f(x)$ for a continuous random variable $X$ has these key properties:
164
+
165
+ 1. The probability that $X$ takes a value in the interval $[a, b]$ is:
166
+
167
+ $$P(a \leq X \leq b) = \int_a^b f(x) \, dx$$
168
+
169
+ 2. The PDF must be non-negative everywhere:
170
+
171
+ $$f(x) \geq 0 \text{ for all } x$$
172
+
173
+ 3. The total probability must sum to 1:
174
+
175
+ $$\int_{-\infty}^{\infty} f(x) \, dx = 1$$
176
+
177
+ 4. The probability that $X$ takes any specific exact value is 0:
178
+
179
+ $$P(X = a) = \int_a^a f(x) \, dx = 0$$
180
+
181
+ This last property highlights a key difference from discrete distributions: the probability of a continuous random variable taking on an exact value is always 0. Probabilities only make sense when talking about ranges of values.
182
+
183
+ ### Caution: Density ≠ Probability
184
+
185
+ A common misconception is to think of $f(x)$ as a probability. It is instead a **probability density**, representing probability per unit of $x$. The values of $f(x)$ can actually exceed 1, as long as the total area under the curve equals 1.
186
+
187
+ The interpretation of $f(x)$ is only meaningful when:
188
+
189
+ 1. We integrate over a range to get a probability, or
190
+ 2. We compare densities at different points to determine relative likelihoods.
191
+ """
192
+ )
193
+ return
194
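These properties can also be checked numerically. The sketch below is an editorial aside (it uses `scipy.integrate`, not anything defined in the notebook) and applies the four properties to the exponential density f(x) = 0.5·e^{-0.5x} that appears later in this notebook:

```python
# Numerically verify the PDF properties for f(x) = 0.5 * exp(-0.5 x) on [0, ∞).
import numpy as np
from scipy import integrate

def pdf(x):
    return 0.5 * np.exp(-0.5 * x)

total, _ = integrate.quad(pdf, 0, np.inf)   # property 3: total area is 1
p_1_3, _ = integrate.quad(pdf, 1, 3)        # property 1: P(1 <= X <= 3) is an area
p_exact, _ = integrate.quad(pdf, 2, 2)      # property 4: P(X = 2) = 0

print(f"total area     = {total:.4f}")
print(f"P(1 <= X <= 3) = {p_1_3:.4f}")
print(f"P(X = 2)       = {p_exact:.4f}")
```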
+
195
+
196
+ @app.cell(hide_code=True)
197
+ def _(TangleSlider, mo):
198
+ # Create sliders for a and b
199
+ a_slider = mo.ui.anywidget(TangleSlider(
200
+ amount=1,
201
+ min_value=0,
202
+ max_value=5,
203
+ step=0.1,
204
+ digits=1
205
+ ))
206
+
207
+ b_slider = mo.ui.anywidget(TangleSlider(
208
+ amount=3,
209
+ min_value=0,
210
+ max_value=5,
211
+ step=0.1,
212
+ digits=1
213
+ ))
214
+
215
+ # Distribution selector
216
+ distribution_radio = mo.ui.radio(
217
+ options=["uniform", "triangular", "exponential"],
218
+ value="uniform",
219
+ label="Distribution Type"
220
+ )
221
+
222
+ # Controls layout
223
+ _controls = mo.vstack([
224
+ mo.md("### Visualizing Probability as Area Under the PDF Curve"),
225
+ mo.md("Adjust sliders to change the interval $[a, b]$ and see how the probability changes:"),
226
+ mo.hstack([
227
+ mo.md("Lower bound (a):"),
228
+ a_slider,
229
+ mo.md("Upper bound (b):"),
230
+ b_slider
231
+ ], justify="start"),
232
+ distribution_radio
233
+ ])
234
+ _controls
235
+ return a_slider, b_slider, distribution_radio
236
+
237
+
238
+ @app.cell(hide_code=True)
239
+ def _(
240
+ a_slider,
241
+ b_slider,
242
+ create_pdf_visualization,
243
+ distribution_radio,
244
+ fig_to_image,
245
+ mo,
246
+ ):
247
+ a = a_slider.amount
248
+ b = b_slider.amount
249
+ distribution = distribution_radio.value
250
+
251
+ # Ensure a < b
252
+ if a > b:
253
+ a, b = b, a
254
+
255
+ # visualization
256
+ _fig, _probability = create_pdf_visualization(a, b, distribution)
257
+
258
+ # Display visualization
259
+ _img = mo.image(fig_to_image(_fig), width="100%")
260
+
261
+ # Add appropriate explanation
262
+ if distribution == "uniform":
263
+ _explanation = mo.md(
264
+ f"""
265
+ In the **uniform distribution**, all values between 0 and 5 are equally likely.
266
+ The probability density is constant at 0.2 (which is 1/5, ensuring the total area is 1).
267
+ For a uniform distribution, the probability that $X$ is in the interval $[{a:.1f}, {b:.1f}]$
268
+ is simply proportional to the width of the interval: $P({a:.1f} \leq X \leq {b:.1f}) = {_probability:.4f}$
269
+ Note that while the PDF has a constant value of 0.2, this is not a probability but a density!
270
+ """
271
+ )
272
+ elif distribution == "triangular":
273
+ _explanation = mo.md(
274
+ f"""
275
+ In this **triangular distribution**, the probability density increases linearly from 0 to 2.5,
276
+ then decreases linearly from 2.5 to 5.
277
+ The distribution's peak is at x = 2.5, where the value is highest.
278
+ The orange shaded area representing $P({a:.1f} \leq X \leq {b:.1f}) = {_probability:.4f}$
279
+ is calculated by integrating the PDF over the interval.
280
+ """
281
+ )
282
+ else:
283
+ _explanation = mo.md(
284
+ f"""
285
+ The **exponential distribution** (with λ = 0.5) models the time between events in a Poisson process.
286
+ Unlike the uniform and triangular distributions, the exponential distribution has infinite support
287
+ (extends from 0 to infinity). The probability density decreases exponentially as x increases.
288
+ The orange shaded area representing $P({a:.1f} \leq X \leq {b:.1f}) = {_probability:.4f}$
289
+ is calculated by integrating $f(x) = 0.5e^{{-0.5x}}$ over the interval.
290
+ """
291
+ )
292
+ mo.vstack([_img, _explanation])
293
+ return a, b, distribution
294
+
295
+
296
+ @app.cell(hide_code=True)
297
+ def _(mo):
298
+ mo.md(
299
+ r"""
300
+ ## Cumulative Distribution Function
301
+
302
+ Since working with PDFs requires solving integrals to find probabilities, we often use the **Cumulative Distribution Function (CDF)** as a more convenient tool.
303
+
304
+ The CDF $F(x)$ for a continuous random variable $X$ is defined as:
305
+
306
+ $$F(x) = P(X \leq x) = \int_{-\infty}^{x} f(t)\,dt$$
307
+
308
+ where $f(t)$ is the PDF of $X$.
309
+
310
+ ### Properties of CDFs
311
+
312
+ A CDF $F(x)$ has these key properties:
313
+
314
+ 1. $F(x)$ is always non-decreasing: if $a < b$, then $F(a) \leq F(b)$
315
+ 2. $\lim_{x \to -\infty} F(x) = 0$ and $\lim_{x \to \infty} F(x) = 1$
316
+ 3. $F(x)$ is right-continuous: $\lim_{h \to 0^+} F(x+h) = F(x)$
317
+
318
+ ### Using the CDF to Calculate Probabilities
319
+
320
+ The CDF is extremely useful because it allows us to calculate various probabilities without having to perform integrals each time:
321
+
322
+ | Probability Query | Solution | Explanation |
323
+ |-------------------|----------|-------------|
324
+ | $P(X < a)$ | $F(a)$ | Definition of the CDF |
325
+ | $P(X \leq a)$ | $F(a)$ | For continuous distributions, $P(X = a) = 0$ |
326
+ | $P(X > a)$ | $1 - F(a)$ | Since $P(X \leq a) + P(X > a) = 1$ |
327
+ | $P(a < X < b)$ | $F(b) - F(a)$ | Since $F(a) + P(a < X < b) = F(b)$ |
328
+ | $P(a \leq X \leq b)$ | $F(b) - F(a)$ | Since $P(X = a) = P(X = b) = 0$ |
329
+
330
+ For discrete random variables, the CDF is also defined but it's less commonly used:
331
+
332
+ $$F_X(a) = \sum_{i \leq a} P(X = i)$$
333
+
334
+ The CDF for discrete distributions is a step function, increasing at each point in the support of the random variable.
335
+ """
336
+ )
337
+ return
338
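A short sketch of the lookup table above, using scipy's exponential distribution (λ = 0.5 corresponds to `scale=2`). This is an editorial example rather than a cell from the notebook:

```python
# Read probabilities directly off the CDF instead of integrating the PDF each time.
import scipy.stats as stats

dist = stats.expon(scale=2)  # exponential with rate λ = 0.5
a, b = 1, 3

print(f"P(X < {a})       = {dist.cdf(a):.4f}")                # F(a)
print(f"P(X > {a})       = {1 - dist.cdf(a):.4f}")            # 1 - F(a)
print(f"P({a} < X < {b}) = {dist.cdf(b) - dist.cdf(a):.4f}")  # F(b) - F(a)
```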
+
339
+
340
+ @app.cell(hide_code=True)
341
+ def _(fig_to_image, mo, np, plt):
342
+ def create_pdf_cdf_comparison():
343
+ fig, axs = plt.subplots(3, 2, figsize=(12, 10))
344
+
345
+ # x-values
346
+ x = np.linspace(-1, 6, 1000)
347
+
348
+ # 1. Uniform Distribution
349
+ # PDF
350
+ pdf_uniform = np.where((x >= 0) & (x <= 5), 0.2, 0)
351
+ axs[0, 0].plot(x, pdf_uniform, 'b-', linewidth=2)
352
+ axs[0, 0].set_title('Uniform PDF')
353
+ axs[0, 0].set_ylabel('Density')
354
+ axs[0, 0].grid(alpha=0.3)
355
+
356
+ # CDF
357
+ cdf_uniform = np.zeros_like(x)
358
+ for i, val in enumerate(x):
359
+ if val < 0:
360
+ cdf_uniform[i] = 0
361
+ elif val > 5:
362
+ cdf_uniform[i] = 1
363
+ else:
364
+ cdf_uniform[i] = val / 5
365
+
366
+ axs[0, 1].plot(x, cdf_uniform, 'r-', linewidth=2)
367
+ axs[0, 1].set_title('Uniform CDF')
368
+ axs[0, 1].set_ylabel('Probability')
369
+ axs[0, 1].grid(alpha=0.3)
370
+
371
+ # 2. Triangular Distribution
372
+ # PDF
373
+ pdf_triangular = np.where(x <= 2.5, x/6.25, (5-x)/6.25)
374
+ pdf_triangular = np.where((x < 0) | (x > 5), 0, pdf_triangular)
375
+
376
+ axs[1, 0].plot(x, pdf_triangular, 'b-', linewidth=2)
377
+ axs[1, 0].set_title('Triangular PDF')
378
+ axs[1, 0].set_ylabel('Density')
379
+ axs[1, 0].grid(alpha=0.3)
380
+
381
+ # CDF
382
+ cdf_triangular = np.zeros_like(x)
383
+ for i, val in enumerate(x):
384
+ if val <= 0:
385
+ cdf_triangular[i] = 0
386
+ elif val >= 5:
387
+ cdf_triangular[i] = 1
388
+ else:
389
+ # For x ≤ 2.5: CDF = x²/(2 * 6.25)
390
+ # For x > 2.5: CDF = 1 - (5 - x)²/(2 * 6.25)
391
+ if val <= 2.5:
392
+ cdf_triangular[i] = (val**2) / (2 * 6.25)
393
+ else:
394
+ cdf_triangular[i] = 1 - ((5 - val)**2) / (2 * 6.25)
395
+
396
+ axs[1, 1].plot(x, cdf_triangular, 'r-', linewidth=2)
397
+ axs[1, 1].set_title('Triangular CDF')
398
+ axs[1, 1].set_ylabel('Probability')
399
+ axs[1, 1].grid(alpha=0.3)
400
+
401
+ # 3. Exponential Distribution
402
+ # PDF
403
+ lambda_param = 0.5
404
+ pdf_exponential = np.where(x >= 0, lambda_param * np.exp(-lambda_param * x), 0)
405
+
406
+ axs[2, 0].plot(x, pdf_exponential, 'b-', linewidth=2)
407
+ axs[2, 0].set_title('Exponential PDF (λ=0.5)')
408
+ axs[2, 0].set_xlabel('x')
409
+ axs[2, 0].set_ylabel('Density')
410
+ axs[2, 0].grid(alpha=0.3)
411
+
412
+ # CDF
413
+ cdf_exponential = np.where(x < 0, 0, 1 - np.exp(-lambda_param * x))
414
+
415
+ axs[2, 1].plot(x, cdf_exponential, 'r-', linewidth=2)
416
+ axs[2, 1].set_title('Exponential CDF (λ=0.5)')
417
+ axs[2, 1].set_xlabel('x')
418
+ axs[2, 1].set_ylabel('Probability')
419
+ axs[2, 1].grid(alpha=0.3)
420
+
421
+ # Common x-limits
422
+ for ax in axs.flatten():
423
+ ax.set_xlim(-0.5, 5.5)
424
+ if ax in axs[:, 0]: # PDF plots
425
+ ax.set_ylim(-0.05, max(0.5, max(pdf_triangular)*1.1))
426
+ else: # CDF plots
427
+ ax.set_ylim(-0.05, 1.05)
428
+
429
+ plt.tight_layout()
430
+ plt.gca()
431
+ return fig
432
+
433
+ # Create visualization
434
+ _fig = create_pdf_cdf_comparison()
435
+ _img = mo.image(fig_to_image(_fig), width="100%")
436
+
437
+ _explanation = mo.md(
438
+ r"""
439
+ The figure above compares the Probability Density Functions (PDFs) on the left with their corresponding Cumulative Distribution Functions (CDFs) on the right for three common distributions:
440
+
441
+ 1. **Uniform Distribution**:
442
+
443
+ - PDF: Constant value (0.2) across the support range [0, 5]
444
+ - CDF: Linear increase from 0 to 1 across the support range
445
+
446
+ 2. **Triangular Distribution**:
447
+
448
+ - PDF: Linearly increases then decreases, forming a triangle shape
449
+ - CDF: Increases quadratically up to the peak, then approaches 1 quadratically
450
+
451
+ 3. **Exponential Distribution**:
452
+
453
+ - PDF: Starts at λ=0.5 and decreases exponentially
454
+ - CDF: Starts at 0 and approaches 1 exponentially (never quite reaching 1)
455
+
456
+ /// NOTE
457
+ The common properties of all CDFs:
458
+
459
+ - They are non-decreasing functions
460
+ - They start at 0 (for x = -∞) and approach or reach 1 (for x = ∞)
461
+ - The slope of the CDF at any point equals the PDF value at that point
462
+ """
463
+ )
464
+
465
+ mo.vstack([_img, _explanation])
466
+ return (create_pdf_cdf_comparison,)
467
+
468
+
469
+ @app.cell(hide_code=True)
470
+ def _(mo):
471
+ mo.md(
472
+ r"""
473
+ ## Solving for Constants in PDFs
474
+
475
+ Many PDFs contain a constant that needs to be determined to ensure the total probability equals 1. Let's work through an example to understand how to solve for these constants.
476
+
477
+ ### Example: Finding the Constant $C$
478
+
479
+ Let $X$ be a continuous random variable with PDF:
480
+
481
+ $$f(x) = \begin{cases}
482
+ C(4x - 2x^2) & \text{when } 0 < x < 2 \\
483
+ 0 & \text{otherwise}
484
+ \end{cases}$$
485
+
486
+ In this function, $C$ is a constant we need to determine. Since we know the PDF must integrate to 1:
487
+
488
+ \begin{align}
489
+ &\int_0^2 C(4x - 2x^2) \, dx = 1 \\
490
+ &C\left(2x^2 - \frac{2x^3}{3}\right)\bigg|_0^2 = 1 \\
491
+ &C\left[\left(8 - \frac{16}{3}\right) - 0 \right] = 1 \\
492
+ &C\left(\frac{24 - 16}{3}\right) = 1 \\
493
+ &C\left(\frac{8}{3}\right) = 1 \\
494
+ &C = \frac{3}{8}
495
+ \end{align}
496
+
497
+ Now that we know $C = \frac{3}{8}$, we can compute probabilities. For example, what is $P(X > 1)$?
498
+
499
+ \begin{align}
500
+ P(X > 1)
501
+ &= \int_1^{\infty}f(x) \, dx \\
502
+ &= \int_1^2 \frac{3}{8}(4x - 2x^2) \, dx \\
503
+ &= \frac{3}{8}\left(2x^2 - \frac{2x^3}{3}\right)\bigg|_1^2 \\
504
+ &= \frac{3}{8}\left[\left(8 - \frac{16}{3}\right) - \left(2 - \frac{2}{3}\right)\right] \\
505
+ &= \frac{3}{8}\left[\left(8 - \frac{16}{3}\right) - \left(\frac{6 - 2}{3}\right)\right] \\
506
+ &= \frac{3}{8}\left[\left(\frac{24 - 16}{3}\right) - \left(\frac{4}{3}\right)\right] \\
507
+ &= \frac{3}{8}\left[\left(\frac{8}{3}\right) - \left(\frac{4}{3}\right)\right] \\
508
+ &= \frac{3}{8} \cdot \frac{4}{3} \\
509
+ &= \frac{1}{2}
510
+ \end{align}
511
+
512
+ Let's visualize this distribution and verify our results:
513
+ """
514
+ )
515
+ return
516
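The same derivation can be cross-checked symbolically. This is an editorial sketch along the lines of the `symbolic_calculation` helper in the notebook's appendix, except that it solves for C instead of assuming it:

```python
# Solve ∫₀² C(4x - 2x²) dx = 1 for C, then recompute P(X > 1).
import sympy

x, C = sympy.symbols('x C', positive=True)
pdf = C * (4*x - 2*x**2)

c_value = sympy.solve(sympy.Eq(sympy.integrate(pdf, (x, 0, 2)), 1), C)[0]
print(c_value)                                           # 3/8
print(sympy.integrate(pdf.subs(C, c_value), (x, 1, 2)))  # 1/2
```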
+
517
+
518
+ @app.cell(hide_code=True)
519
+ def _(
520
+ create_example_pdf_visualization,
521
+ fig_to_image,
522
+ mo,
523
+ symbolic_calculation,
524
+ ):
525
+ # Create visualization
526
+ _fig = create_example_pdf_visualization()
527
+ _img = mo.image(fig_to_image(_fig), width="100%")
528
+
529
+ # Symbolic calculation
530
+ _sympy_verification = mo.md(symbolic_calculation())
531
+
532
+ _explanation = mo.md(
533
+ r"""
534
+ The figure above shows:
535
+
536
+ 1. **Left**: The PDF $f(x) = \frac{3}{8}(4x - 2x^2)$ for $0 < x < 2$, with the area representing P(X > 1) shaded in orange.
537
+ 2. **Right**: The corresponding CDF, showing F(1) = 0.5 and thus P(X > 1) = 1 - F(1) = 0.5.
538
+
539
+ Notice how we:
540
+
541
+ 1. First determined the constant C = 3/8 by ensuring the total area under the PDF equals 1
542
+ 2. Used this value to calculate specific probabilities like P(X > 1)
543
+ 3. Verified our results both graphically and symbolically
544
+ """
545
+ )
546
+ mo.vstack([_img, _sympy_verification, _explanation])
547
+ return
548
+
549
+
550
+ @app.cell(hide_code=True)
551
+ def _(mo):
552
+ mo.md(
553
+ r"""
554
+ ## Expectation and Variance of Continuous Random Variables
555
+
556
+ Just as with discrete random variables, we can calculate the expectation and variance of continuous random variables. The main difference is that we use integrals instead of sums.
557
+
558
+ ### Expectation (Mean)
559
+
560
+ For a continuous random variable $X$ with PDF $f(x)$, the expectation is:
561
+
562
+ $$E[X] = \int_{-\infty}^{\infty} x \cdot f(x) \, dx$$
563
+
564
+ More generally, for any function $g(X)$:
565
+
566
+ $$E[g(X)] = \int_{-\infty}^{\infty} g(x) \cdot f(x) \, dx$$
567
+
568
+ ### Variance
569
+
570
+ The variance is defined the same way as for discrete random variables:
571
+
572
+ $$\text{Var}(X) = E[(X - \mu)^2] = E[X^2] - (E[X])^2$$
573
+
574
+ where $\mu = E[X]$ is the mean of $X$.
575
+
576
+ To calculate $E[X^2]$, we use:
577
+
578
+ $$E[X^2] = \int_{-\infty}^{\infty} x^2 \cdot f(x) \, dx$$
579
+
580
+ ### Properties
581
+
582
+ The following properties hold for both continuous and discrete random variables:
583
+
584
+ 1. $E[aX + b] = aE[X] + b$ for constants $a$ and $b$
585
+ 2. $\text{Var}(aX + b) = a^2 \text{Var}(X)$ for constants $a$ and $b$
586
+
587
+ Let's calculate the expectation and variance for our example PDF:
588
+ """
589
+ )
590
+ return
591
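Before the symbolic treatment in the next cell, here is a purely numerical cross-check (an editorial sketch using `scipy.integrate`, not part of the commit) for the example PDF f(x) = (3/8)(4x − 2x²) on (0, 2):

```python
# E[X] and Var(X) for f(x) = (3/8)(4x - 2x²) on (0, 2) via numerical integration.
from scipy import integrate

def pdf(x):
    return 3 / 8 * (4 * x - 2 * x**2)

e_x, _ = integrate.quad(lambda t: t * pdf(t), 0, 2)      # E[X]
e_x2, _ = integrate.quad(lambda t: t**2 * pdf(t), 0, 2)  # E[X²]
var_x = e_x2 - e_x**2

print(f"E[X]   = {e_x:.4f}")    # 1.0
print(f"Var(X) = {var_x:.4f}")  # 0.2
```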
+
592
+
593
+ @app.cell(hide_code=True)
594
+ def _(fig_to_image, mo, np, plt, sympy):
595
+ # Symbolic calculation of expectation and variance
596
+ def symbolic_stats_calc():
597
+ x = sympy.symbols('x')
598
+ C = sympy.Rational(3, 8)
599
+
600
+ # Define the PDF
601
+ pdf_expr = C * (4*x - 2*x**2)
602
+
603
+ # Calculate expectation
604
+ E_X = sympy.integrate(x * pdf_expr, (x, 0, 2))
605
+
606
+ # Calculate E[X²]
607
+ E_X2 = sympy.integrate(x**2 * pdf_expr, (x, 0, 2))
608
+
609
+ # Calculate variance
610
+ Var_X = E_X2 - E_X**2
611
+
612
+ # Calculate standard deviation
613
+ Std_X = sympy.sqrt(Var_X)
614
+
615
+ return E_X, E_X2, Var_X, Std_X
616
+
617
+ # Get symbolic results
618
+ E_X, E_X2, Var_X, Std_X = symbolic_stats_calc()
619
+
620
+ # Numerical values for plotting
621
+ E_X_val = float(E_X)
622
+ Var_X_val = float(Var_X)
623
+ Std_X_val = float(Std_X)
624
+
625
+ def create_expectation_variance_vis():
626
+ """Create visualization showing mean and variance for the example PDF."""
627
+ fig, ax = plt.subplots(figsize=(10, 6))
628
+
629
+ # x-values
630
+ x = np.linspace(-0.5, 2.5, 1000)
631
+
632
+ # PDF function
633
+ C = 3/8
634
+ pdf = np.where((x > 0) & (x < 2), C * (4*x - 2*x**2), 0)
635
+
636
+ # Plot the PDF
637
+ ax.plot(x, pdf, 'b-', linewidth=2, label='PDF')
638
+ ax.fill_between(x, pdf, where=(x > 0) & (x < 2), alpha=0.3, color='blue')
639
+
640
+ # Mark the mean
641
+ ax.axvline(x=E_X_val, color='r', linestyle='--', linewidth=2,
642
+ label=f'Mean (E[X] = {E_X_val:.3f})')
643
+
644
+ # Mark the standard deviation range
645
+ ax.axvspan(E_X_val - Std_X_val, E_X_val + Std_X_val, alpha=0.2, color='green',
646
+ label=f'±1 Std Dev ({Std_X_val:.3f})')
647
+
648
+ # Add labels and title
649
+ ax.set_xlabel('x')
650
+ ax.set_ylabel('Probability Density')
651
+ ax.set_title('PDF with Mean and Variance')
652
+ ax.legend()
653
+ ax.grid(alpha=0.3)
654
+
655
+ # Set x-limits
656
+ ax.set_xlim(-0.25, 2.25)
657
+
658
+ plt.tight_layout()
659
+ return fig
660
+
661
+ # Create the visualization
662
+ _fig = create_expectation_variance_vis()
663
+ _img = mo.image(fig_to_image(_fig), width="100%")
664
+
665
+ # Detailed calculations for our example
666
+ _calculations = mo.md(
667
+ f"""
668
+ ### Calculating Expectation and Variance for Our Example
669
+
670
+ Let's calculate the expectation and variance for the PDF:
671
+
672
+ $$f(x) = \\begin{{cases}}
673
+ \\frac{{3}}{{8}}(4x - 2x^2) & \\text{{when }} 0 < x < 2 \\\\
674
+ 0 & \\text{{otherwise}}
675
+ \\end{{cases}}$$
676
+
677
+ #### Expectation Calculation
678
+
679
+ $$E[X] = \\int_{{-\\infty}}^{{\\infty}} x \\cdot f(x) \\, dx = \\int_0^2 x \\cdot \\frac{{3}}{{8}}(4x - 2x^2) \\, dx$$
680
+
681
+ $$E[X] = \\frac{{3}}{{8}} \\int_0^2 (4x^2 - 2x^3) \\, dx = \\frac{{3}}{{8}} \\left[ \\frac{{4x^3}}{{3}} - \\frac{{2x^4}}{{4}} \\right]_0^2$$
682
+
683
+ $$E[X] = \\frac{{3}}{{8}} \\left[ \\frac{{4 \\cdot 2^3}}{{3}} - \\frac{{2 \\cdot 2^4}}{{4}} - 0 \\right] = \\frac{{3}}{{8}} \\left[ \\frac{{32}}{{3}} - 8 \\right]$$
684
+
685
+ $$E[X] = \\frac{{3}}{{8}} \\cdot \\frac{{32 - 24}}{{3}} = \\frac{{3}}{{8}} \\cdot \\frac{{8}}{{3}} = {E_X}$$
686
+
687
+ #### Variance Calculation
688
+
689
+ First, we need $E[X^2]$:
690
+
691
+ $$E[X^2] = \\int_{{-\\infty}}^{{\\infty}} x^2 \\cdot f(x) \\, dx = \\int_0^2 x^2 \\cdot \\frac{{3}}{{8}}(4x - 2x^2) \\, dx$$
692
+
693
+ $$E[X^2] = \\frac{{3}}{{8}} \\int_0^2 (4x^3 - 2x^4) \\, dx = \\frac{{3}}{{8}} \\left[ \\frac{{4x^4}}{{4}} - \\frac{{2x^5}}{{5}} \\right]_0^2$$
694
+
695
+ $$E[X^2] = \\frac{{3}}{{8}} \\left[ 16 - \\frac{{2 \\cdot 32}}{{5}} - 0 \\right] = \\frac{{3}}{{8}} \\left[ 16 - \\frac{{64}}{{5}} \\right]$$
696
+
697
+ $$E[X^2] = \\frac{{3}}{{8}} \\cdot \\frac{{80 - 64}}{{5}} = \\frac{{3}}{{8}} \\cdot \\frac{{16}}{{5}} = {E_X2}$$
698
+
699
+ Now we can calculate the variance:
700
+
701
+ $$\\text{{Var}}(X) = E[X^2] - (E[X])^2 = {E_X2} - ({E_X})^2 = {Var_X}$$
702
+
703
+ Therefore, the standard deviation is $\\sqrt{{\\text{{Var}}(X)}} = {Std_X}$.
704
+ """
705
+ )
706
+ mo.vstack([_img, _calculations])
707
+ return (
708
+ E_X,
709
+ E_X2,
710
+ E_X_val,
711
+ Std_X,
712
+ Std_X_val,
713
+ Var_X,
714
+ Var_X_val,
715
+ create_expectation_variance_vis,
716
+ symbolic_stats_calc,
717
+ )
718
+
719
+
720
+ @app.cell(hide_code=True)
721
+ def _(mo):
722
+ mo.md(
723
+ r"""
724
+ ## 🤔 Test Your Understanding
725
+
726
+ Select which of these statements about continuous distributions you think are correct:
727
+
728
+ /// details | The PDF of a continuous random variable can have values greater than 1
729
+ ✅ Correct! Since the PDF represents density (not probability), it can exceed 1 as long as the total area under the curve equals 1.
730
+ ///
731
+
732
+ /// details | For a continuous distribution, $P(X = a) > 0$ for any value $a$ in the support
733
+ ❌ Incorrect! For continuous random variables, the probability of the random variable taking any specific exact value is always 0. That is, $P(X = a) = 0$ for any value $a$.
734
+ ///
735
+
736
+ /// details | The area under a PDF curve between $a$ and $b$ equals the probability $P(a \leq X \leq b)$
737
+ ✅ Correct! The area under the PDF curve over an interval gives the probability that the random variable falls within that interval.
738
+ ///
739
+
740
+ /// details | The CDF function $F(x)$ is always equal to $\int_{-\infty}^{x} f(t) \, dt$
741
+ ✅ Correct! The CDF at point $x$ is the integral of the PDF from negative infinity to $x$.
742
+ ///
743
+
744
+ /// details | For a continuous random variable, $F(x)$ ranges from 0 to the maximum value in the support of the random variable
745
+ ❌ Incorrect! The CDF $F(x)$ ranges from 0 to 1, representing probabilities. It approaches 1 (not the maximum value in the support) as $x$ approaches infinity.
746
+ ///
747
+
748
+ /// details | To calculate the variance of a continuous random variable, we use the formula $\text{Var}(X) = E[X^2] - (E[X])^2$
749
+ ✅ Correct! This formula applies to both discrete and continuous random variables.
750
+ ///
751
+ """
752
+ )
753
+ return
754
+
755
+
756
+ @app.cell(hide_code=True)
757
+ def _(mo):
758
+ mo.md(
759
+ r"""
760
+ ## Summary
761
+
762
+ Moving from discrete to continuous thinking is a big conceptual leap, but it opens up powerful ways to model real-world phenomena.
763
+
764
+ In this notebook, we've seen how continuous random variables let us model quantities that can take any real value. Instead of dealing with probabilities at specific points (which are actually zero!), we work with probability density functions (PDFs) and find probabilities by calculating areas under curves.
765
+
766
+ Some key points to remember:
767
+
768
+ - PDFs give us relative likelihood, not actual probabilities - that's why they can exceed 1
769
+ - The probability between two points is the area under the PDF curve
770
+ - CDFs offer a convenient shortcut to find probabilities without integrating
771
+ - Expectation and variance work similarly to discrete variables, just with integrals instead of sums
772
+ - Constants in PDFs are determined by ensuring the total probability equals 1
773
+
774
+ This foundation will serve you well as we explore specific continuous distributions like normal, exponential, and beta in future notebooks. These distributions are the workhorses of probability theory and statistics, appearing everywhere from quality control to financial modeling.
775
+
776
+ One final thought: continuous distributions are beautiful mathematical objects, but remember they're just models. Real-world data is often discrete at some level, but continuous distributions provide elegant approximations that make calculations more tractable.
777
+ """
778
+ )
779
+ return
780
+
781
+
782
+ @app.cell
783
+ def _(mo):
784
+ mo.md(r"""Appendix code (helper functions, variables, etc.):""")
785
+ return
786
+
787
+
788
+ @app.cell
789
+ def _():
790
+ import marimo as mo
791
+ return (mo,)
792
+
793
+
794
+ @app.cell(hide_code=True)
795
+ def _():
796
+ import numpy as np
797
+ import matplotlib.pyplot as plt
798
+ import scipy.stats as stats
799
+ import sympy
800
+ from scipy import integrate
801
+ import polars as pl
802
+ import altair as alt
803
+ from wigglystuff import TangleSlider
804
+ return TangleSlider, alt, np, pl, plt, integrate, stats, sympy
805
+
806
+
807
+ @app.cell(hide_code=True)
808
+ def _():
809
+ import io
810
+ import base64
811
+ from matplotlib.figure import Figure
812
+
813
+ # Helper function to convert mpl figure to an image format mo.image can handle
814
+ def fig_to_image(fig):
815
+ buf = io.BytesIO()
816
+ fig.savefig(buf, format='png')
817
+ buf.seek(0)
818
+ data = f"data:image/png;base64,{base64.b64encode(buf.read()).decode('utf-8')}"
819
+ return data
820
+ return Figure, base64, fig_to_image, io
821
+
822
+
823
+ @app.cell(hide_code=True)
824
+ def _(np, plt):
825
+ def create_pdf_visualization(a, b, distribution='uniform'):
826
+ fig, ax = plt.subplots(figsize=(10, 6))
827
+
828
+ # x-values
829
+ x = np.linspace(-0.5, 5.5, 1000)
830
+
831
+ # Various PDFs to visualize
832
+ if distribution == 'uniform':
833
+ # Uniform distribution from 0 to 5
834
+ y = np.where((x >= 0) & (x <= 5), 0.2, 0)
835
+ title = f"Uniform PDF from 0 to 5"
836
+
837
+ elif distribution == 'triangular':
838
+ # Triangular distribution peaked at 2.5
839
+ y = np.where(x <= 2.5, x/6.25, (5-x)/6.25) # peak at 2.5
840
+ y = np.where((x < 0) | (x > 5), 0, y)
841
+ title = f"Triangular PDF from 0 to 5"
842
+
843
+ elif distribution == 'exponential':
844
+ lambda_param = 0.5
845
+ y = np.where(x >= 0, lambda_param * np.exp(-lambda_param * x), 0)
846
+ title = f"Exponential PDF with λ = {lambda_param}"
847
+
848
+ # Plot PDF
849
+ ax.plot(x, y, 'b-', linewidth=2, label='PDF $f(x)$')
850
+
851
+ # Shade the area for the probability P(a ≤ X ≤ b)
852
+ mask = (x >= a) & (x <= b)
853
+ ax.fill_between(x[mask], y[mask], color='orange', alpha=0.5)
854
+
855
+ # Calculate the probability
856
+ dx = x[1] - x[0]
857
+ probability = np.sum(y[mask]) * dx
858
+
859
+ # vertical lines at a and b
860
+ ax.axvline(x=a, color='r', linestyle='--', alpha=0.7,
861
+ label=f'a = {a:.1f}')
862
+ ax.axvline(x=b, color='g', linestyle='--', alpha=0.7,
863
+ label=f'b = {b:.1f}')
864
+
865
+ # horizontal line at y=0
866
+ ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)
867
+
868
+ # labels and title
869
+ ax.set_xlabel('x')
870
+ ax.set_ylabel('Probability Density $f(x)$')
871
+ ax.set_title(title)
872
+ ax.legend(loc='upper right')
873
+
874
+ # relevant annotations
875
+ ax.annotate(rf'$P({a:.1f} \leq X \leq {b:.1f}) = {probability:.4f}$',
876
+ xy=(0.5, 0.9), xycoords='axes fraction',
877
+ bbox=dict(boxstyle='round,pad=0.5', facecolor='white', alpha=0.8),
878
+ horizontalalignment='center', fontsize=12)
879
+
880
+ plt.grid(alpha=0.3)
881
+ plt.tight_layout()
882
+ plt.gca()
883
+ return fig, probability
884
+ return (create_pdf_visualization,)
885
+
886
+
887
+ @app.cell(hide_code=True)
888
+ def _(np, plt, sympy):
889
+ def create_example_pdf_visualization():
890
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
891
+
892
+ # x-values
893
+ x = np.linspace(-0.5, 2.5, 1000)
894
+
895
+ # PDF function
896
+ C = 3/8
897
+ pdf = np.where((x > 0) & (x < 2), C * (4*x - 2*x**2), 0)
898
+
899
+ # CDF
900
+ cdf = np.zeros_like(x)
901
+ for i, val in enumerate(x):
902
+ if val <= 0:
903
+ cdf[i] = 0
904
+ elif val >= 2:
905
+ cdf[i] = 1
906
+ else:
907
+ # Analytical form: C*(2x^2 - 2x^3/3)
908
+ cdf[i] = C * (2*val**2 - (2*val**3)/3)
909
+
910
+ # PDF Plot
911
+ ax1.plot(x, pdf, 'b-', linewidth=2)
912
+ ax1.set_title('PDF: $f(x) = \\frac{3}{8}(4x - 2x^2)$ for $0 < x < 2$')
913
+ ax1.set_xlabel('x')
914
+ ax1.set_ylabel('Probability Density')
915
+ ax1.grid(alpha=0.3)
916
+
917
+ # Highlight the area for P(X > 1)
918
+ mask = (x > 1) & (x < 2)
919
+ ax1.fill_between(x[mask], pdf[mask], color='orange', alpha=0.5,
920
+ label='P(X > 1) = 0.5')
921
+
922
+ # Add vertical line at x=1
923
+ ax1.axvline(x=1, color='r', linestyle='--', alpha=0.7)
924
+ ax1.legend()
925
+
926
+ # CDF Plot
927
+ ax2.plot(x, cdf, 'r-', linewidth=2)
928
+ ax2.set_title('CDF: $F(x)$ for the Example Distribution')
929
+ ax2.set_xlabel('x')
930
+ ax2.set_ylabel('Cumulative Probability')
931
+ ax2.grid(alpha=0.3)
932
+
933
+ # Mark appropriate (F(1) & F(2)) points)
934
+ ax2.plot(1, cdf[np.abs(x-1).argmin()], 'ro', markersize=8)
935
+ ax2.plot(2, cdf[np.abs(x-2).argmin()], 'ro', markersize=8)
936
+
937
+ # annotations
938
+ F_1 = C * (2*1**2 - (2*1**3)/3) # F(1)
939
+ ax2.annotate(f'F(1) = {F_1:.3f}', xy=(1, F_1), xytext=(1.1, 0.4),
940
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
941
+
942
+ ax2.annotate(f'F(2) = 1', xy=(2, 1), xytext=(1.7, 0.8),
943
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
944
+
945
+ ax2.annotate(f'P(X > 1) = 1 - F(1) = {1-F_1:.3f}', xy=(1.5, 0.7),
946
+ bbox=dict(boxstyle='round,pad=0.5', facecolor='orange', alpha=0.2))
947
+
948
+ # common x-limits
949
+ for ax in [ax1, ax2]:
950
+ ax.set_xlim(-0.25, 2.25)
951
+
952
+ plt.tight_layout()
953
+ plt.gca()
954
+ return fig
955
+
956
+ def symbolic_calculation():
957
+ x = sympy.symbols('x')
958
+ C = sympy.Rational(3, 8)
959
+
960
+ # PDF defn
961
+ pdf_expr = C * (4*x - 2*x**2)
962
+
963
+ # Verify PDF integrates to 1
964
+ total_prob = sympy.integrate(pdf_expr, (x, 0, 2))
965
+
966
+ # Calculate P(X > 1)
967
+ prob_gt_1 = sympy.integrate(pdf_expr, (x, 1, 2))
968
+
969
+ return f"""Symbolic calculation verification:
970
+
971
+ 1. Total probability: ∫₀² {C}(4x - 2x²) dx = {total_prob}
972
+ 2. P(X > 1): ∫₁² {C}(4x - 2x²) dx = {prob_gt_1}
973
+ """
974
+
975
+ return create_example_pdf_visualization, symbolic_calculation
976
+
977
+
978
+ if __name__ == "__main__":
979
+ app.run()
probability/17_normal_distribution.py ADDED
@@ -0,0 +1,1127 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "matplotlib==3.10.1",
6
+ # "scipy==1.15.2",
7
+ # "wigglystuff==0.1.10",
8
+ # "numpy==2.2.4",
9
+ # ]
10
+ # ///
11
+
12
+ import marimo
13
+
14
+ __generated_with = "0.11.26"
15
+ app = marimo.App(width="medium", app_title="Normal Distribution")
16
+
17
+
18
+ @app.cell(hide_code=True)
19
+ def _(mo):
20
+ mo.md(
21
+ r"""
22
+ # Normal Distribution
23
+
24
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/normal/), by Stanford professor Chris Piech._
25
+
26
+ The Normal (also known as Gaussian) distribution is one of the most important probability distributions in statistics and data science. It's characterized by a symmetric bell-shaped curve and is fully defined by two parameters: mean (μ) and variance (σ²).
27
+ """
28
+ )
29
+ return
30
+
31
+
32
+ @app.cell(hide_code=True)
33
+ def _(mo):
34
+ mo.md(
35
+ r"""
36
+ ## Normal Random Variable Definition
37
+
38
+ The Normal (or Gaussian) random variable is denoted as:
39
+
40
+ $$X \sim \mathcal{N}(\mu, \sigma^2)$$
41
+
42
+ Where:
43
+
44
+ - $X$ is our random variable
45
+ - $\mathcal{N}$ indicates it follows a Normal distribution
46
+ - $\mu$ is the mean parameter
47
+ - $\sigma^2$ is the variance parameter (sometimes written as $\sigma$ for standard deviation)
48
+
49
+ ```
50
+ X ~ N(μ, σ²)
51
+ ↑ ↑ ↑ ↑
52
+ | | | +-- Variance (spread)
53
+ | | | of the distribution
54
+ | | +-- Mean (center)
55
+ | | of the distribution
56
+ | +-- Indicates Normal
57
+ | distribution
58
+ |
59
+ Our random variable
60
+ ```
61
+
62
+ The Normal distribution is particularly important for many reasons:
63
+
64
+ 1. It arises naturally from the sum of independent random variables (Central Limit Theorem)
65
+ 2. It appears frequently in natural phenomena
66
+ 3. It is the maximum entropy distribution given a fixed mean and variance
67
+ 4. It simplifies many mathematical calculations in statistics and probability
68
+ """
69
+ )
70
+ return
71
+
72
+
73
+ @app.cell(hide_code=True)
74
+ def _(mo):
75
+ mo.md(
76
+ r"""
77
+ ## Properties of Normal Distribution
78
+
79
+ | Property | Formula |
80
+ |----------|---------|
81
+ | Notation | $X \sim \mathcal{N}(\mu, \sigma^2)$ |
82
+ | Description | A common, naturally occurring distribution |
83
+ | Parameters | $\mu \in \mathbb{R}$, the mean<br>$\sigma^2 \in \mathbb{R}^+$, the variance |
84
+ | Support | $x \in \mathbb{R}$ |
85
+ | PDF equation | $f(x) = \frac{1}{\sigma\sqrt{2\pi}}e^{-\frac{1}{2}(\frac{x-\mu}{\sigma})^2}$ |
86
+ | CDF equation | $F(x) = \Phi(\frac{x-\mu}{\sigma})$ where $\Phi$ is the CDF of the standard normal |
87
+ | Expectation | $E[X] = \mu$ |
88
+ | Variance | $\text{Var}(X) = \sigma^2$ |
89
+
90
+ The PDF (Probability Density Function) reaches its maximum value at $x = \mu$, where the exponent becomes zero and $e^0 = 1$.
91
+ """
92
+ )
93
+ return
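As a quick numerical check of the PDF formula in the table above, here is a small sketch (added for illustration, not part of this diff) comparing the closed-form expression with `scipy.stats.norm.pdf`; it assumes only the NumPy/SciPy dependencies already pinned in the script header, and the parameter values are arbitrary:

```python
import numpy as np
from scipy import stats

mu, sigma = 3.0, 4.0
x = np.linspace(mu - 3 * sigma, mu + 3 * sigma, 5)

# PDF from the closed-form expression in the table above
pdf_manual = np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# PDF from scipy; the arguments are (x, mean, standard deviation)
pdf_scipy = stats.norm.pdf(x, mu, sigma)

print(np.allclose(pdf_manual, pdf_scipy))  # True
print(x[pdf_manual.argmax()])              # 3.0 -- the density peaks at x = mu
```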
94
+
95
+
96
+ @app.cell(hide_code=True)
97
+ def _(mean_slider, mo, std_slider):
98
+ mo.md(
99
+ f"""
100
+ The figure below shows a comparison between:
101
+
102
+ - The **Standard Normal Distribution** (purple curve): N(0, 1)
103
+ - A **Normal Distribution** with the parameters you selected (blue curve)
104
+
105
+ Adjust the mean (μ) {mean_slider} and standard deviation (σ) {std_slider} below to see how the normal distribution changes shape.
106
+
107
+ """
108
+ )
109
+ return
110
+
111
+
112
+ @app.cell(hide_code=True)
113
+ def _(
114
+ create_distribution_comparison,
115
+ fig_to_image,
116
+ mean_slider,
117
+ mo,
118
+ std_slider,
119
+ ):
120
+ # values from the sliders
121
+ current_mu = mean_slider.amount
122
+ current_sigma = std_slider.amount
123
+
124
+ # Create plot
125
+ comparison_fig = create_distribution_comparison(current_mu, current_sigma)
126
+
127
+ # Call, convert and display
128
+ comp_image = mo.image(fig_to_image(comparison_fig), width="100%")
129
+ comp_image
130
+ return comp_image, comparison_fig, current_mu, current_sigma
131
+
132
+
133
+ @app.cell(hide_code=True)
134
+ def _(mean_slider, mo, std_slider):
135
+ mo.md(
136
+ f"""
137
+ ## Interactive Normal Distribution Visualization
138
+
139
+ The shape of a normal distribution is determined by two key parameters:
140
+
141
+ - The **mean (μ):** {mean_slider} controls the center of the distribution.
142
+
143
+ - The **standard deviation (σ):** {std_slider} controls the spread (width) of the distribution.
144
+
145
+ Try adjusting these parameters to see how they affect the shape of the distribution below:
146
+
147
+ """
148
+ )
149
+ return
150
+
151
+
152
+ @app.cell(hide_code=True)
153
+ def _(create_normal_pdf_plot, fig_to_image, mean_slider, mo, std_slider):
154
+ # value from widgets
155
+ _current_mu = mean_slider.amount
156
+ _current_sigma = std_slider.amount
157
+
158
+ # Create visualization
159
+ pdf_fig = create_normal_pdf_plot(_current_mu, _current_sigma)
160
+
161
+ # Display plot
162
+ pdf_image = mo.image(fig_to_image(pdf_fig), width="100%")
163
+
164
+ pdf_explanation = mo.md(
165
+ r"""
166
+ **Understanding the Normal Distribution Visualization:**
167
+
168
+ - **PDF (top)**: The probability density function shows the relative likelihood of different values.
169
+ The highest point occurs at the mean (μ).
170
+
171
+ - **Shaded regions**: The green shaded areas represent:
172
+ - μ ± 1σ: Contains approximately 68.3% of the probability
173
+ - μ ± 2σ: Contains approximately 95.5% of the probability
174
+ - μ ± 3σ: Contains approximately 99.7% of the probability (the "68-95-99.7 rule")
175
+
176
+ - **CDF (bottom)**: The cumulative distribution function shows the probability that X is less than or equal to a given value.
177
+ - At x = μ, the CDF equals 0.5 (50% probability)
178
+ - At x = μ + σ, the CDF equals approximately 0.84 (84% probability)
179
+ - At x = μ - σ, the CDF equals approximately 0.16 (16% probability)
180
+ """
181
+ )
182
+
183
+ mo.vstack([pdf_image, pdf_explanation])
184
+ return pdf_explanation, pdf_fig, pdf_image
185
+
186
+
187
+ @app.cell(hide_code=True)
188
+ def _(mo):
189
+ mo.md(
190
+ r"""
191
+ ## Standard Normal Distribution
192
+
193
+ The **Standard Normal Distribution** is a special case of the normal distribution where $\mu = 0$ and $\sigma = 1$. We denote it as:
194
+
195
+ $$Z \sim \mathcal{N}(0, 1)$$
196
+
197
+ This distribution is particularly important because:
198
+
199
+ 1. Any normal distribution can be transformed into the standard normal
200
+ 2. Statistical tables and calculations often use the standard normal as a reference
201
+
202
+ ### Standardizing a Normal Random Variable
203
+
204
+ For any normal random variable $X \sim \mathcal{N}(\mu, \sigma^2)$, we can transform it to the standard normal $Z$ using:
205
+
206
+ $$Z = \frac{X - \mu}{\sigma}$$
207
+
208
+ Let's see the mathematical derivation:
209
+
210
+ \begin{align*}
211
+ W &= \frac{X -\mu}{\sigma} && \text{Subtract by $\mu$ and divide by $\sigma$} \\
212
+ &= \frac{1}{\sigma}X - \frac{\mu}{\sigma} && \text{Use algebra to rewrite the equation}\\
213
+ &= aX + b && \text{Linear transform where $a = \frac{1}{\sigma}$, $b = -\frac{\mu}{\sigma}$}\\
214
+ &\sim \mathcal{N}(a\mu + b, a^2\sigma^2) && \text{The linear transform of a Normal is another Normal}\\
215
+ &\sim \mathcal{N}\left(\frac{\mu}{\sigma} - \frac{\mu}{\sigma}, \frac{\sigma^2}{\sigma^2}\right) && \text{Substitute values for $a$ and $b$}\\
216
+ &\sim \mathcal{N}(0, 1) && \text{The standard normal}
217
+ \end{align*}
218
+
219
+ This transformation is the foundation for many statistical tests and probability calculations.
220
+ """
221
+ )
222
+ return
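The standardization above can also be checked numerically. A minimal sketch (not part of the original notebook) comparing the CDF of X ~ N(μ, σ²) with Φ evaluated at the z-score, using the SciPy dependency already declared; the particular values of μ, σ, and x are arbitrary:

```python
from scipy import stats

mu, sigma, x = 3.0, 4.0, 0.0

direct = stats.norm.cdf(x, mu, sigma)            # P(X <= x) for X ~ N(mu, sigma^2)
standardized = stats.norm.cdf((x - mu) / sigma)  # Phi(z) with z = (x - mu) / sigma

print(direct, standardized)  # both ≈ 0.2266
```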
223
+
224
+
225
+ @app.cell(hide_code=True)
226
+ def _(create_standardization_plot, fig_to_image, mo):
227
+ # Create and display visualization
228
+ stand_fig = create_standardization_plot()
229
+
230
+ # Display
231
+ stand_image = mo.image(fig_to_image(stand_fig), width="100%")
232
+
233
+ stand_explanation = mo.md(
234
+ r"""
235
+ **Standardizing a Normal Distribution: A Two-Step Process**
236
+
237
+ The visualization above shows the process of transforming any normal distribution to the standard normal:
238
+
239
+ 1. **Shift the distribution** (left plot): First, we subtract the mean (μ) from X, centering the distribution at 0.
240
+
241
+ 2. **Scale the distribution** (right plot): Next, we divide by the standard deviation (σ), which adjusts the spread to match the standard normal.
242
+
243
+ The resulting standard normal distribution Z ~ N(0,1) has a mean of 0 and a variance of 1.
244
+
245
+ This transformation allows us to use standardized tables and calculations for any normal distribution.
246
+ """
247
+ )
248
+
249
+ mo.vstack([stand_image, stand_explanation])
250
+ return stand_explanation, stand_fig, stand_image
251
+
252
+
253
+ @app.cell(hide_code=True)
254
+ def _(mo):
255
+ mo.md(
256
+ r"""
257
+ ## Linear Transformations of Normal Variables
258
+
259
+ One useful property of the normal distribution is that linear transformations of normal random variables remain normal.
260
+
261
+ If $X \sim \mathcal{N}(\mu, \sigma^2)$ and $Y = aX + b$ (where $a$ and $b$ are constants), then:
262
+
263
+ $$Y \sim \mathcal{N}(a\mu + b, a^2\sigma^2)$$
264
+
265
+ This means:
266
+
267
+ - The mean is transformed by $a\mu + b$
268
+ - The variance is transformed by $a^2\sigma^2$
269
+
270
+ This property is extremely useful in statistics and probability calculations, as it allows us to easily determine the _distribution_ of transformed variables.
271
+ """
272
+ )
273
+ return
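A brief simulation sketch (added for illustration, not part of this diff; it uses only NumPy, which the notebook already depends on, and arbitrary constants) showing that the transformed mean and variance match aμ + b and a²σ²:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 2.0, 3.0
a, b = -1.5, 4.0

x = rng.normal(mu, sigma, size=1_000_000)  # X ~ N(2, 9)
y = a * x + b                              # Y = aX + b

print(y.mean(), a * mu + b)        # both ≈ 1.0
print(y.var(), a**2 * sigma**2)    # both ≈ 20.25
```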
274
+
275
+
276
+ @app.cell(hide_code=True)
277
+ def _(mo):
278
+ mo.md(
279
+ r"""
280
+ ## Calculating Probabilities with the Normal CDF
281
+
282
+ Unlike many other distributions, the normal distribution does not have a closed-form expression for its CDF. However, we can use the standard normal CDF (denoted as $\Phi$) to calculate probabilities.
283
+
284
+ For any normal random variable $X \sim \mathcal{N}(\mu, \sigma^2)$, the CDF is:
285
+
286
+ $$F_X(x) = P(X \leq x) = \Phi\left(\frac{x - \mu}{\sigma}\right)$$
287
+
288
+ Where $\Phi$ is the CDF of the standard normal distribution.
289
+
290
+ ### Derivation
291
+
292
+ \begin{align*}
293
+ F_X(x) &= P(X \leq x) \\
294
+ &= P\left(\frac{X - \mu}{\sigma} \leq \frac{x - \mu}{\sigma}\right) \\
295
+ &= P\left(Z \leq \frac{x - \mu}{\sigma}\right) \\
296
+ &= \Phi\left(\frac{x - \mu}{\sigma}\right)
297
+ \end{align*}
298
+
299
+ Let's look at some examples of calculating probabilities with normal distributions.
300
+ """
301
+ )
302
+ return
303
+
304
+
305
+ @app.cell(hide_code=True)
306
+ def _(mo):
307
+ mo.md("""## Examples of Normal Distributions""")
308
+ return
309
+
310
+
311
+ @app.cell(hide_code=True)
312
+ def _(create_probability_example, fig_to_image, mo):
313
+ # Create visualization
314
+ default_mu = 3
315
+ default_sigma = 4
316
+ default_query = 0
317
+
318
+ prob_fig, prob_value, ex_z_score = create_probability_example(default_mu, default_sigma, default_query)
319
+
320
+ # Display
321
+ prob_image = mo.image(fig_to_image(prob_fig), width="100%")
322
+
323
+ prob_explanation = mo.md(
324
+ f"""
325
+ **Example: Let X ~ N(3, 16), what is P(X > 0)?**
326
+
327
+ To solve this probability question:
328
+
329
+ 1. First, we standardize the query value:
330
+ Z = (x - μ) / σ = (0 - 3) / 4 = -0.75
331
+
332
+ 2. Then we calculate using the standard normal CDF:
333
+ P(X > 0) = P(Z > -0.75) = 1 - P(Z ≤ -0.75) = 1 - Φ(-0.75)
334
+
335
+ 3. Because the standard normal is symmetric:
336
+ 1 - Φ(-0.75) = Φ(0.75) = {prob_value:.3f}
337
+
338
+ The shaded orange area in the graph represents this probability of approximately {prob_value:.3f}.
339
+ """
340
+ )
341
+
342
+ mo.vstack([prob_image, prob_explanation])
343
+ return (
344
+ default_mu,
345
+ default_query,
346
+ default_sigma,
347
+ ex_z_score,
348
+ prob_explanation,
349
+ prob_fig,
350
+ prob_image,
351
+ prob_value,
352
+ )
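The worked example above can be reproduced in a couple of lines with SciPy (a sketch added here for illustration; it relies only on the `scipy` dependency already in the notebook):

```python
from scipy import stats

# X ~ N(3, 16), so the standard deviation is 4
p_direct = 1 - stats.norm.cdf(0, 3, 4)  # P(X > 0)
p_via_z = stats.norm.cdf(0.75)          # Phi(0.75) after standardizing

print(round(p_direct, 3), round(p_via_z, 3))  # 0.773 0.773
```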
353
+
354
+
355
+ @app.cell(hide_code=True)
356
+ def _(create_range_probability_example, fig_to_image, mo, stats):
357
+ # Create visualization
358
+ default_range_mu = 3
359
+ default_range_sigma = 4
360
+ default_range_lower = 2
361
+ default_range_upper = 5
362
+
363
+ range_fig, range_prob, range_z_lower, range_z_upper = create_range_probability_example(
364
+ default_range_mu, default_range_sigma, default_range_lower, default_range_upper)
365
+
366
+ # Display
367
+ range_image = mo.image(fig_to_image(range_fig), width="100%")
368
+
369
+ range_explanation = mo.md(
370
+ f"""
371
+ **Example: Let X ~ N(3, 16), what is P(2 < X < 5)?**
372
+
373
+ To solve this range probability question:
374
+
375
+ 1. First, we standardize both bounds:
376
+ Z_lower = (lower - μ) / σ = (2 - 3) / 4 = -0.25
377
+ Z_upper = (upper - μ) / σ = (5 - 3) / 4 = 0.5
378
+
379
+ 2. Then we calculate using the standard normal CDF:
380
+ P(2 < X < 5) = P(-0.25 < Z < 0.5)
381
+ = Φ(0.5) - Φ(-0.25)
382
+ = Φ(0.5) - (1 - Φ(0.25))
383
+ = Φ(0.5) + Φ(0.25) - 1
384
+
385
+ 3. Computing these values:
386
+ = {stats.norm.cdf(0.5):.3f} + {stats.norm.cdf(0.25):.3f} - 1
387
+ = {range_prob:.3f}
388
+
389
+ The shaded orange area in the graph represents this probability of approximately {range_prob:.3f}.
390
+ """
391
+ )
392
+
393
+ mo.vstack([range_image, range_explanation])
394
+ return (
395
+ default_range_lower,
396
+ default_range_mu,
397
+ default_range_sigma,
398
+ default_range_upper,
399
+ range_explanation,
400
+ range_fig,
401
+ range_image,
402
+ range_prob,
403
+ range_z_lower,
404
+ range_z_upper,
405
+ )
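Likewise, the range probability can be verified directly (a short sketch, not part of the diff, using the same SciPy dependency):

```python
from scipy import stats

# X ~ N(3, 16), so the standard deviation is 4
p_range = stats.norm.cdf(5, 3, 4) - stats.norm.cdf(2, 3, 4)  # P(2 < X < 5)
print(round(p_range, 3))  # ≈ 0.290
```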
406
+
407
+
408
+ @app.cell(hide_code=True)
409
+ def _(create_voltage_example_visualization, fig_to_image, mo):
410
+ # Create visualization
411
+ voltage_fig, voltage_error_prob = create_voltage_example_visualization()
412
+
413
+ # Display
414
+ voltage_image = mo.image(fig_to_image(voltage_fig), width="100%")
415
+
416
+ voltage_explanation = mo.md(
417
+ r"""
418
+ **Example: Signal Transmission with Noise**
419
+
420
+ In this example, we're sending digital signals over a wire:
421
+
422
+ - We send voltage 2 to represent a binary "1"
423
+ - We send voltage -2 to represent a binary "0"
424
+
425
+ The received signal R is the sum of the transmitted voltage (X) and random noise (Y):
426
+ R = X + Y, where Y ~ N(0, 1)
427
+
428
+ When decoding, we use a threshold of 0.5:
429
+
430
+ - If R ≥ 0.5, we interpret it as "1"
431
+ - If R < 0.5, we interpret it as "0"
432
+
433
+ Let's calculate the probability of error when sending a "1" (voltage = 2):
434
+
435
+ \begin{align*}
436
+ P(\text{Error when sending "1"}) &= P(X + Y < 0.5) \\
437
+ &= P(2 + Y < 0.5) \\
438
+ &= P(Y < -1.5) \\
439
+ &= \Phi(-1.5) \\
440
+ &\approx 0.067
441
+ \end{align*}
442
+
443
+ Therefore, the probability of incorrectly decoding a transmitted "1" as "0" is approximately 6.7%.
444
+
445
+ The orange shaded area in the plot represents this error probability.
446
+ """
447
+ )
448
+
449
+ mo.vstack([voltage_image, voltage_explanation])
450
+ return voltage_error_prob, voltage_explanation, voltage_fig, voltage_image
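The 6.7% error probability quoted above is just Φ(-1.5); here is a one-line check with SciPy (a sketch added for illustration, not part of the notebook code):

```python
from scipy import stats

# Sending "1": R = 2 + Y with Y ~ N(0, 1); decoding fails when R < 0.5, i.e. Y < -1.5
p_error = stats.norm.cdf(-1.5)
print(round(p_error, 3))  # ≈ 0.067
```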
451
+
452
+
453
+ @app.cell(hide_code=True)
454
+ def empirical_rule(mo):
455
+ mo.md(
456
+ r"""
457
+ ## The 68-95-99.7 Rule (Empirical Rule)
458
+
459
+ One of the most useful properties of the normal distribution is the "[68-95-99.7 rule](https://en.wikipedia.org/wiki/68-95-99.7_rule)," which states that:
460
+
461
+ - Approximately 68% of the data falls within 1 standard deviation of the mean
462
+ - Approximately 95% of the data falls within 2 standard deviations of the mean
463
+ - Approximately 99.7% of the data falls within 3 standard deviations of the mean
464
+
465
+ Let's verify this with a calculation for the 68% rule:
466
+
467
+ \begin{align}
468
+ P(\mu - \sigma < X < \mu + \sigma)
469
+ &= P(X < \mu + \sigma) - P(X < \mu - \sigma) \\
470
+ &= \Phi\left(\frac{(\mu + \sigma)-\mu}{\sigma}\right) - \Phi\left(\frac{(\mu - \sigma)-\mu}{\sigma}\right) \\
471
+ &= \Phi\left(\frac{\sigma}{\sigma}\right) - \Phi\left(\frac{-\sigma}{\sigma}\right) \\
472
+ &= \Phi(1) - \Phi(-1) \\
473
+ &\approx 0.8413 - 0.1587 \\
474
+ &\approx 0.6826 \approx 68.3\%
475
+ \end{align}
476
+
477
+ This calculation works for any normal distribution, regardless of the values of $\mu$ and $\sigma$!
478
+ """
479
+ )
480
+ return
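The same three percentages can be verified with a short loop over the standard normal CDF (a sketch for illustration, using the SciPy dependency already declared):

```python
from scipy import stats

for k in (1, 2, 3):
    p = stats.norm.cdf(k) - stats.norm.cdf(-k)  # P(mu - k*sigma < X < mu + k*sigma)
    print(k, round(p, 4))
# 1 0.6827
# 2 0.9545
# 3 0.9973
```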
481
+
482
+
483
+ @app.cell(hide_code=True)
484
+ def _(mo):
485
+ mo.md(r"""The Cumulative Distribution Function (CDF) gives the probability that a random variable is less than or equal to a specific value. Use the interactive calculator below to compute CDF values for a normal distribution.""")
486
+ return
487
+
488
+
489
+ @app.cell(hide_code=True)
490
+ def _(mo, mu_slider, sigma_slider, x_slider):
491
+ mo.md(
492
+ f"""
493
+ ## Interactive Normal CDF Calculator
494
+
495
+ Use the sliders below to explore different probability calculations:
496
+
497
+ **Query value (x):** {x_slider} — The value at which to evaluate F(x) = P(X ≤ x)
498
+
499
+ **Mean (μ):** {mu_slider} — The center of the distribution
500
+
501
+ **Standard deviation (σ):** {sigma_slider} — The spread of the distribution (larger σ means more spread)
502
+ """
503
+ )
504
+ return
505
+
506
+
507
+ @app.cell(hide_code=True)
508
+ def _(
509
+ create_cdf_calculator_plot,
510
+ fig_to_image,
511
+ mo,
512
+ mu_slider,
513
+ sigma_slider,
514
+ x_slider,
515
+ ):
516
+ # Values from widgets
517
+ calc_x = x_slider.amount
518
+ calc_mu = mu_slider.amount
519
+ calc_sigma = sigma_slider.amount
520
+
521
+ # Create visualization
522
+ calc_fig, cdf_value = create_cdf_calculator_plot(calc_x, calc_mu, calc_sigma)
523
+
524
+ # Standardized z-score
525
+ calc_z_score = (calc_x - calc_mu) / calc_sigma
526
+
527
+ # Display
528
+ calc_image = mo.image(fig_to_image(calc_fig), width="100%")
529
+
530
+ calc_result = mo.md(
531
+ f"""
532
+ ### Results:
533
+
534
+ For a Normal distribution with parameters μ = {calc_mu:.1f} and σ = {calc_sigma:.1f}:
535
+
536
+ - The value x = {calc_x:.1f} corresponds to a z-score of z = {calc_z_score:.3f}
537
+ - The CDF value F({calc_x:.1f}) = P(X ≤ {calc_x:.1f}) = {cdf_value:.3f}
538
+ - This means the probability that X is less than or equal to {calc_x:.1f} is {cdf_value*100:.1f}%
539
+
540
+ **Computing this in Python:**
541
+ ```python
542
+ from scipy import stats
543
+
544
+ # Using the one-line method
545
+ p = stats.norm.cdf({calc_x:.1f}, {calc_mu:.1f}, {calc_sigma:.1f})
546
+
547
+ # OR using the two-line method
548
+ X = stats.norm({calc_mu:.1f}, {calc_sigma:.1f})
549
+ p = X.cdf({calc_x:.1f})
550
+ ```
551
+
552
+ **Note:** In SciPy's `stats.norm`, the second parameter is the standard deviation (σ), not the variance (σ²).
553
+ """
554
+ )
555
+
556
+ mo.vstack([calc_image, calc_result])
557
+ return (
558
+ calc_fig,
559
+ calc_image,
560
+ calc_mu,
561
+ calc_result,
562
+ calc_sigma,
563
+ calc_x,
564
+ calc_z_score,
565
+ cdf_value,
566
+ )
567
+
568
+
569
+ @app.cell(hide_code=True)
570
+ def _(mo):
571
+ mo.md(
572
+ r"""
573
+ ## 🤔 Test Your Understanding
574
+
575
+ Test your knowledge with these true/false questions about normal distributions:
576
+
577
+ /// details | For a normal random variable X ~ N(μ, σ²), the probability that X takes on exactly the value μ is highest among all possible values.
578
+
579
+ **✅ True**
580
+
581
+ While the PDF is indeed highest at x = μ, making this the most likely value in terms of density, remember that for continuous random variables, the probability of any exact value is zero. The statement refers to the density function being maximized at the mean.
582
+ ///
583
+
584
+ /// details | The probability that a normal random variable X equals any specific exact value (e.g., P(X = 3)) is always zero.
585
+
586
+ **✅ True**
587
+
588
+ For continuous random variables including the normal, the probability of any exact value is zero. Probabilities only make sense for ranges of values, which is why we integrate the PDF over intervals.
589
+ ///
590
+
591
+ /// details | If X ~ N(μ, σ²), then aX + b ~ N(aμ + b, a²σ²) for any constants a and b.
592
+
593
+ **✅ True**
594
+
595
+ Linear transformations of normal random variables remain normal, with the given transformation of the parameters. This is a key property that makes normal distributions particularly useful.
596
+ ///
597
+
598
+ /// details | If X ~ N(5, 9) and Y ~ N(3, 4) are independent, then X + Y ~ N(8, 5).
599
+
600
+ **❌ False**
601
+
602
+ While the mean of the sum is indeed the sum of the means (5 + 3 = 8), the variance of the sum is the sum of the variances (9 + 4 = 13), not 5. The correct distribution would be X + Y ~ N(8, 13).
603
+ ///
604
+ """
605
+ )
606
+ return
607
+
608
+
609
+ @app.cell(hide_code=True)
610
+ def _(mo):
611
+ mo.md(
612
+ r"""
613
+ ## Summary
614
+
615
+ We've taken a tour of the Normal distribution, probably the most famous probability distribution you'll encounter in statistics. It's that familiar bell-shaped curve that shows up everywhere, from heights and weights to measurement errors and stock returns.
616
+
617
+ The Normal distribution isn't just pretty — it's incredibly practical. With just two parameters (mean and standard deviation), you can describe complex phenomena and make powerful predictions. Plus, thanks to the Central Limit Theorem, many random processes naturally converge to this distribution, which is why it's so prevalent.
618
+
619
+ **What we covered:**
620
+
621
+ - The mathematical definition and key properties of Normal random variables
622
+
623
+ - How to transform any Normal distribution to the standard Normal
624
+
625
+ - Calculating probabilities using the CDF (no more looking up values in those tiny tables in the back of textbooks or Clark's table!)
626
+
627
+ Whether you're analyzing data, designing experiments, or building ML models, the concepts we explored provide a solid foundation for working with this fundamental distribution.
628
+ """
629
+ )
630
+ return
631
+
632
+
633
+ @app.cell(hide_code=True)
634
+ def _(mo):
635
+ mo.md(r"""Appendix (helper code and functions)""")
636
+ return
637
+
638
+
639
+ @app.cell
640
+ def _():
641
+ import marimo as mo
642
+ return (mo,)
643
+
644
+
645
+ @app.cell(hide_code=True)
646
+ def _():
647
+ from wigglystuff import TangleSlider
648
+ return (TangleSlider,)
649
+
650
+
651
+ @app.cell(hide_code=True)
652
+ def _(np, plt, stats):
653
+ def create_normal_pdf_plot(mu, sigma):
654
+
655
+ # Range for x values (show μ ± 4σ)
656
+ x = np.linspace(mu - 4*sigma, mu + 4*sigma, 1000)
657
+ pdf = stats.norm.pdf(x, mu, sigma)
658
+
659
+ # Calculate CDF values
660
+ cdf = stats.norm.cdf(x, mu, sigma)
661
+
662
+ # Create plot with two subplots for (PDF and CDF)
663
+ pdf_fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))
664
+
665
+ # PDF plot
666
+ ax1.plot(x, pdf, color='royalblue', linewidth=2, label='PDF')
667
+ ax1.fill_between(x, pdf, color='royalblue', alpha=0.2)
668
+
669
+ # Vertical line at mean
670
+ ax1.axvline(x=mu, color='red', linestyle='--', linewidth=1.5,
671
+ label=f'Mean: μ = {mu:.1f}')
672
+
673
+ # Stdev regions
674
+ for i in range(1, 4):
675
+ alpha = 0.1 if i > 1 else 0.2
676
+ percentage = 100*stats.norm.cdf(i) - 100*stats.norm.cdf(-i)
677
+ label = f'μ ± {i}σ: {percentage:.1f}%' if i == 1 else None
678
+ ax1.axvspan(mu - i*sigma, mu + i*sigma, alpha=alpha, color='green',
679
+ label=label)
680
+
681
+ # Annotations
682
+ ax1.annotate(f'μ = {mu:.1f}', xy=(mu, max(pdf)*0.15), xytext=(mu+0.5*sigma, max(pdf)*0.4),
683
+ arrowprops=dict(facecolor='black', width=1, shrink=0.05))
684
+
685
+ ax1.annotate(f'σ = {sigma:.1f}',
686
+ xy=(mu+sigma, stats.norm.pdf(mu+sigma, mu, sigma)),
687
+ xytext=(mu+1.5*sigma, stats.norm.pdf(mu+sigma, mu, sigma)*1.5),
688
+ arrowprops=dict(facecolor='black', width=1, shrink=0.05))
689
+
690
+ # some styling
691
+ ax1.set_title(f'Normal Distribution PDF: N({mu:.1f}, {sigma:.1f}²)')
692
+ ax1.set_xlabel('x')
693
+ ax1.set_ylabel('Probability Density: f(x)')
694
+ ax1.legend(loc='upper right')
695
+ ax1.grid(alpha=0.3)
696
+
697
+ # CDF plot
698
+ ax2.plot(x, cdf, color='darkorange', linewidth=2, label='CDF')
699
+
700
+ # key CDF values mark
701
+ key_points = [
702
+ (mu-sigma, stats.norm.cdf(mu-sigma, mu, sigma), "16%"),
703
+ (mu, 0.5, "50%"),
704
+ (mu+sigma, stats.norm.cdf(mu+sigma, mu, sigma), "84%")
705
+ ]
706
+
707
+ for point, value, label in key_points:
708
+ ax2.plot(point, value, 'ro')
709
+ ax2.annotate(f'{label}',
710
+ xy=(point, value),
711
+ xytext=(point+0.2*sigma, value-0.1),
712
+ arrowprops=dict(facecolor='black', width=1, shrink=0.05))
713
+
714
+ # CDF styling
715
+ ax2.set_title(f'Normal Distribution CDF: N({mu:.1f}, {sigma:.1f}²)')
716
+ ax2.set_xlabel('x')
717
+ ax2.set_ylabel('Cumulative Probability: F(x)')
718
+ ax2.grid(alpha=0.3)
719
+
720
+ plt.tight_layout()
721
+ return pdf_fig
722
+ return (create_normal_pdf_plot,)
723
+
724
+
725
+ @app.cell(hide_code=True)
726
+ def _(base64, io):
727
+ from matplotlib.figure import Figure
728
+
729
+ # convert matplotlib figures to images (helper code)
730
+ def fig_to_image(fig):
731
+ buf = io.BytesIO()
732
+ fig.savefig(buf, format='png', bbox_inches='tight')
733
+ buf.seek(0)
734
+ img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
735
+ return f"data:image/png;base64,{img_str}"
736
+ return Figure, fig_to_image
737
+
738
+
739
+ @app.cell(hide_code=True)
740
+ def _():
741
+ # Import libraries
742
+ import numpy as np
743
+ import matplotlib.pyplot as plt
744
+ from scipy import stats
745
+ import io
746
+ import base64
747
+ return base64, io, np, plt, stats
748
+
749
+
750
+ @app.cell(hide_code=True)
751
+ def _(TangleSlider, mo):
752
+ mean_slider = mo.ui.anywidget(TangleSlider(
753
+ amount=0,
754
+ min_value=-5,
755
+ max_value=5,
756
+ step=0.1,
757
+ digits=1
758
+ ))
759
+
760
+ std_slider = mo.ui.anywidget(TangleSlider(
761
+ amount=1,
762
+ min_value=0.1,
763
+ max_value=3,
764
+ step=0.1,
765
+ digits=1
766
+ ))
767
+ return mean_slider, std_slider
768
+
769
+
770
+ @app.cell(hide_code=True)
771
+ def _(TangleSlider, mo):
772
+ x_slider = mo.ui.anywidget(TangleSlider(
773
+ amount=0,
774
+ min_value=-5,
775
+ max_value=5,
776
+ step=0.1,
777
+ digits=1
778
+ ))
779
+
780
+ mu_slider = mo.ui.anywidget(TangleSlider(
781
+ amount=0,
782
+ min_value=-5,
783
+ max_value=5,
784
+ step=0.1,
785
+ digits=1
786
+ ))
787
+
788
+ sigma_slider = mo.ui.anywidget(TangleSlider(
789
+ amount=1,
790
+ min_value=0.1,
791
+ max_value=3,
792
+ step=0.1,
793
+ digits=1
794
+ ))
795
+ return mu_slider, sigma_slider, x_slider
796
+
797
+
798
+ @app.cell(hide_code=True)
799
+ def _(np, plt, stats):
800
+ def create_distribution_comparison(mu=5, sigma=6):
801
+
802
+ # Create figure and axis
803
+ comparison_fig, ax = plt.subplots(figsize=(10, 6))
804
+
805
+ # X range for plotting
806
+ x = np.linspace(-10, 20, 1000)
807
+
808
+ # Standard normal
809
+ std_normal = stats.norm.pdf(x, 0, 1)
810
+
811
+ # Our example normal
812
+ example_normal = stats.norm.pdf(x, mu, sigma)
813
+
814
+ # Plot both distributions
815
+ ax.plot(x, std_normal, 'darkviolet', linewidth=2, label='Standard Normal')
816
+ ax.plot(x, example_normal, 'blue', linewidth=2, label=f'X ~ N({mu}, {sigma}²)')
817
+
818
+ # format the plot
819
+ ax.set_xlim(-10, 20)
820
+ ax.set_ylim(0, 0.45)
821
+ ax.set_xlabel('x')
822
+ ax.set_ylabel('Probability Density')
823
+ ax.grid(True, alpha=0.3)
824
+ ax.legend()
825
+
826
+ # Decorative text box for parameters
827
+ props = dict(boxstyle='round', facecolor='white', alpha=0.9)
828
+ textstr = '\n'.join((
829
+ r'Normal (aka Gaussian) Random Variable',
830
+ r'',
831
+ rf'Parameter $\mu$: {mu}',
832
+ rf'Parameter $\sigma$: {sigma}'
833
+ ))
834
+ ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
835
+ verticalalignment='top', bbox=props)
836
+
837
+ return comparison_fig
838
+ return (create_distribution_comparison,)
839
+
840
+
841
+ @app.cell(hide_code=True)
842
+ def _(np, plt, stats):
843
+ def create_voltage_example_visualization():
844
+
845
+ # Create data for plotting
846
+ x = np.linspace(-4, 4, 1000)
847
+
848
+ # Signal without noise (X = 2)
849
+ signal_value = 2
850
+
851
+ # Noise distribution (Y ~ N(0, 1))
852
+ noise_pdf = stats.norm.pdf(x, 0, 1)
853
+
854
+ # Signal + Noise distribution (R = X + Y ~ N(2, 1))
855
+ received_pdf = stats.norm.pdf(x, signal_value, 1)
856
+
857
+ # Create figure
858
+ voltage_fig, ax = plt.subplots(figsize=(10, 6))
859
+
860
+ # Plot the noise distribution
861
+ ax.plot(x, noise_pdf, 'blue', linewidth=1.5, alpha=0.6,
862
+ label='Noise: Y ~ N(0, 1)')
863
+
864
+ # received signal distribution
865
+ ax.plot(x, received_pdf, 'red', linewidth=2,
866
+ label=f'Received: R ~ N({signal_value}, 1)')
867
+
868
+ # vertical line at the decision boundary (0.5)
869
+ threshold = 0.5
870
+ ax.axvline(x=threshold, color='green', linestyle='--', linewidth=2,
871
+ label=f'Decision threshold: {threshold}')
872
+
873
+ # Shade the error region
874
+ mask = x < threshold
875
+ error_prob = stats.norm.cdf(threshold, signal_value, 1)
876
+ ax.fill_between(x[mask], received_pdf[mask], color='darkorange', alpha=0.5,
877
+ label=f'Error probability: {error_prob:.3f}')
878
+
879
+ # Styling
880
+ ax.set_title('Voltage Transmission Example: Probability of Error')
881
+ ax.set_xlabel('Voltage')
882
+ ax.set_ylabel('Probability Density')
883
+ ax.legend(loc='upper left')
884
+ ax.grid(alpha=0.3)
885
+
886
+ # Add explanatory annotations
887
+ ax.text(1.5, 0.1, 'When sending "1" (voltage=2),\nthis area represents\nthe error probability',
888
+ bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1))
889
+
890
+ plt.tight_layout()
891
+ plt.gca()
892
+ return voltage_fig, error_prob
893
+ return (create_voltage_example_visualization,)
894
+
895
+
896
+ @app.cell(hide_code=True)
897
+ def _(np, plt, stats):
898
+ def create_cdf_calculator_plot(calc_x, calc_mu, calc_sigma):
899
+
900
+ # Data range for plotting
901
+ x_range = np.linspace(calc_mu - 4*calc_sigma, calc_mu + 4*calc_sigma, 1000)
902
+ pdf = stats.norm.pdf(x_range, calc_mu, calc_sigma)
903
+ cdf = stats.norm.cdf(x_range, calc_mu, calc_sigma)
904
+
905
+ # Calculate the CDF at x
906
+ cdf_at_x = stats.norm.cdf(calc_x, calc_mu, calc_sigma)
907
+
908
+ # Create figure with two subplots
909
+ calc_fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))
910
+
911
+ # Plot PDF on top subplot
912
+ ax1.plot(x_range, pdf, color='royalblue', linewidth=2, label='PDF')
913
+
914
+ # area shade for P(X ≤ x)
915
+ mask = x_range <= calc_x
916
+ ax1.fill_between(x_range[mask], pdf[mask], color='darkorange', alpha=0.6)
917
+
918
+ # Vertical line at x
919
+ ax1.axvline(x=calc_x, color='red', linestyle='--', linewidth=1.5)
920
+
921
+ # PDF labels and styling
922
+ ax1.set_title(f'Normal PDF with Area P(X ≤ {calc_x:.1f}) Highlighted')
923
+ ax1.set_xlabel('x')
924
+ ax1.set_ylabel('Probability Density')
925
+ ax1.annotate(f'x = {calc_x:.1f}', xy=(calc_x, 0), xytext=(calc_x, -0.01),
926
+ horizontalalignment='center', color='red')
927
+ ax1.grid(alpha=0.3)
928
+
929
+ # CDF on bottom subplot
930
+ ax2.plot(x_range, cdf, color='green', linewidth=2, label='CDF')
931
+
932
+ # Mark the point (x, CDF(x))
933
+ ax2.plot(calc_x, cdf_at_x, 'ro', markersize=8)
934
+
935
+ # CDF labels and styling
936
+ ax2.set_title(f'Normal CDF: F({calc_x:.1f}) = {cdf_at_x:.3f}')
937
+ ax2.set_xlabel('x')
938
+ ax2.set_ylabel('Cumulative Probability')
939
+ ax2.annotate(f'F({calc_x:.1f}) = {cdf_at_x:.3f}',
940
+ xy=(calc_x, cdf_at_x),
941
+ xytext=(calc_x + 0.5*calc_sigma, cdf_at_x - 0.1),
942
+ arrowprops=dict(facecolor='black', width=1, shrink=0.05),
943
+ bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1))
944
+ ax2.grid(alpha=0.3)
945
+
946
+ plt.tight_layout()
947
+ plt.gca()
948
+ return calc_fig, cdf_at_x
949
+ return (create_cdf_calculator_plot,)
950
+
951
+
952
+ @app.cell(hide_code=True)
953
+ def _(np, plt, stats):
954
+ def create_standardization_plot():
955
+
956
+ x = np.linspace(-6, 6, 1000)
957
+
958
+ # Original distribution N(2, 1.5²)
959
+ mu_original, sigma_original = 2, 1.5
960
+ pdf_original = stats.norm.pdf(x, mu_original, sigma_original)
961
+
962
+ # shifted distribution N(0, 1.5²)
963
+ mu_shifted, sigma_shifted = 0, 1.5
964
+ pdf_shifted = stats.norm.pdf(x, mu_shifted, sigma_shifted)
965
+
966
+ # Standard normal N(0, 1)
967
+ mu_standard, sigma_standard = 0, 1
968
+ pdf_standard = stats.norm.pdf(x, mu_standard, sigma_standard)
969
+
970
+ # Create visualization
971
+ stand_fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
972
+
973
+ # Plot on left: Original and shifted distributions
974
+ ax1.plot(x, pdf_original, 'royalblue', linewidth=2,
975
+ label=f'Original: N({mu_original}, {sigma_original}²)')
976
+ ax1.plot(x, pdf_shifted, 'darkorange', linewidth=2,
977
+ label=f'Shifted: N({mu_shifted}, {sigma_shifted}²)')
978
+
979
+ # Add arrow to show the shift
980
+ shift_x1, shift_y1 = mu_original, stats.norm.pdf(mu_original, mu_original, sigma_original)*0.6
981
+ shift_x2, shift_y2 = mu_shifted, stats.norm.pdf(mu_shifted, mu_shifted, sigma_shifted)*0.6
982
+ ax1.annotate('', xy=(shift_x2, shift_y2), xytext=(shift_x1, shift_y1),
983
+ arrowprops=dict(facecolor='black', width=1.5, shrink=0.05))
984
+ ax1.text(0.8, 0.28, 'Subtract μ', transform=ax1.transAxes)
985
+
986
+ # Plot on right: Shifted and standard normal
987
+ ax2.plot(x, pdf_shifted, 'darkorange', linewidth=2,
988
+ label=f'Shifted: N({mu_shifted}, {sigma_shifted}²)')
989
+ ax2.plot(x, pdf_standard, 'green', linewidth=2,
990
+ label=f'Standard: N({mu_standard}, {sigma_standard}²)')
991
+
992
+ # Add arrow to show the scaling
993
+ scale_x1, scale_y1 = 2*sigma_shifted, stats.norm.pdf(2*sigma_shifted, mu_shifted, sigma_shifted)*0.8
994
+ scale_x2, scale_y2 = 2*sigma_standard, stats.norm.pdf(2*sigma_standard, mu_standard, sigma_standard)*0.8
995
+ ax2.annotate('', xy=(scale_x2, scale_y2), xytext=(scale_x1, scale_y1),
996
+ arrowprops=dict(facecolor='black', width=1.5, shrink=0.05))
997
+ ax2.text(0.75, 0.5, 'Divide by σ', transform=ax2.transAxes)
998
+
999
+ # some styling
1000
+ for ax in (ax1, ax2):
1001
+ ax.set_xlabel('x')
1002
+ ax.set_ylabel('Probability Density')
1003
+ ax.grid(alpha=0.3)
1004
+ ax.legend()
1005
+
1006
+ ax1.set_title('Step 1: Shift the Distribution')
1007
+ ax2.set_title('Step 2: Scale the Distribution')
1008
+
1009
+ plt.tight_layout()
1010
+ plt.gca()
1011
+ return stand_fig
1012
+ return (create_standardization_plot,)
1013
+
1014
+
1015
+ @app.cell(hide_code=True)
1016
+ def _(np, plt, stats):
1017
+ def create_probability_example(example_mu=3, example_sigma=4, example_query=0):
1018
+
1019
+ # Create data range
1020
+ x = np.linspace(example_mu - 4*example_sigma, example_mu + 4*example_sigma, 1000)
1021
+ pdf = stats.norm.pdf(x, example_mu, example_sigma)
1022
+
1023
+ # probability calc
1024
+ prob_value = 1 - stats.norm.cdf(example_query, example_mu, example_sigma)
1025
+ ex_z_score = (example_query - example_mu) / example_sigma
1026
+
1027
+ # Create visualization
1028
+ prob_fig, ax = plt.subplots(figsize=(10, 6))
1029
+
1030
+ # Plot PDF
1031
+ ax.plot(x, pdf, 'royalblue', linewidth=2)
1032
+
1033
+ # area shading representing the probability
1034
+ mask = x >= example_query
1035
+ ax.fill_between(x[mask], pdf[mask], color='darkorange', alpha=0.6)
1036
+
1037
+ # Add vertical line at query point
1038
+ ax.axvline(x=example_query, color='red', linestyle='--', linewidth=1.5)
1039
+
1040
+ # Annotations
1041
+ ax.annotate(f'x = {example_query}', xy=(example_query, 0), xytext=(example_query, -0.005),
1042
+ horizontalalignment='center')
1043
+
1044
+ ax.annotate(f'P(X > {example_query}) = {prob_value:.3f}',
1045
+ xy=(example_query + example_sigma, 0.015),
1046
+ xytext=(example_query + 1.5*example_sigma, 0.02),
1047
+ arrowprops=dict(facecolor='black', width=1, shrink=0.05),
1048
+ bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1))
1049
+
1050
+ # Standard normal calculation annotation
1051
+ ax.annotate(f'= P(Z > {ex_z_score:.3f}) = {prob_value:.3f}',
1052
+ xy=(example_query - example_sigma, 0.01),
1053
+ xytext=(example_query - 2*example_sigma, 0.015),
1054
+ arrowprops=dict(facecolor='black', width=1, shrink=0.05),
1055
+ bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1))
1056
+
1057
+ # some styling
1058
+ ax.set_title(f'Example: P(X > {example_query}) where X ~ N({example_mu}, {example_sigma}²)')
1059
+ ax.set_xlabel('x')
1060
+ ax.set_ylabel('Probability Density')
1061
+ ax.grid(alpha=0.3)
1062
+
1063
+ plt.tight_layout()
1064
+ plt.gca()
1065
+ return prob_fig, prob_value, ex_z_score
1066
+ return (create_probability_example,)
1067
+
1068
+
1069
+ @app.cell(hide_code=True)
1070
+ def _(np, plt, stats):
1071
+ def create_range_probability_example(range_mu=3, range_sigma=4, range_lower=2, range_upper=5):
1072
+
1073
+ x = np.linspace(range_mu - 4*range_sigma, range_mu + 4*range_sigma, 1000)
1074
+ pdf = stats.norm.pdf(x, range_mu, range_sigma)
1075
+
1076
+ # probability
1077
+ range_prob = stats.norm.cdf(range_upper, range_mu, range_sigma) - stats.norm.cdf(range_lower, range_mu, range_sigma)
1078
+ range_z_lower = (range_lower - range_mu) / range_sigma
1079
+ range_z_upper = (range_upper - range_mu) / range_sigma
1080
+
1081
+ # Create visualization
1082
+ range_fig, ax = plt.subplots(figsize=(10, 6))
1083
+
1084
+ # Plot PDF
1085
+ ax.plot(x, pdf, 'royalblue', linewidth=2)
1086
+
1087
+ # Shade the area representing the probability
1088
+ mask = (x >= range_lower) & (x <= range_upper)
1089
+ ax.fill_between(x[mask], pdf[mask], color='darkorange', alpha=0.6)
1090
+
1091
+ # Add vertical lines at query points
1092
+ ax.axvline(x=range_lower, color='red', linestyle='--', linewidth=1.5)
1093
+ ax.axvline(x=range_upper, color='red', linestyle='--', linewidth=1.5)
1094
+
1095
+ # Annotations
1096
+ ax.annotate(f'x = {range_lower}', xy=(range_lower, 0), xytext=(range_lower, -0.005),
1097
+ horizontalalignment='center')
1098
+ ax.annotate(f'x = {range_upper}', xy=(range_upper, 0), xytext=(range_upper, -0.005),
1099
+ horizontalalignment='center')
1100
+
1101
+ ax.annotate(f'P({range_lower} < X < {range_upper}) = {range_prob:.3f}',
1102
+ xy=((range_lower + range_upper)/2, max(pdf[mask])/2),
1103
+ xytext=((range_lower + range_upper)/2, max(pdf[mask])*1.5),
1104
+ arrowprops=dict(facecolor='black', width=1, shrink=0.05),
1105
+ bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1),
1106
+ horizontalalignment='center')
1107
+
1108
+ # Standard normal calculation annotation
1109
+ ax.annotate(f'= P({range_z_lower:.3f} < Z < {range_z_upper:.3f}) = {range_prob:.3f}',
1110
+ xy=((range_lower + range_upper)/2, max(pdf[mask])/3),
1111
+ xytext=(range_mu - 2*range_sigma, max(pdf[mask])/1.5),
1112
+ arrowprops=dict(facecolor='black', width=1, shrink=0.05),
1113
+ bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1))
1114
+
1115
+ ax.set_title(f'Example: P({range_lower} < X < {range_upper}) where X ~ N({range_mu}, {range_sigma}²)')
1116
+ ax.set_xlabel('x')
1117
+ ax.set_ylabel('Probability Density')
1118
+ ax.grid(alpha=0.3)
1119
+
1120
+ plt.tight_layout()
1121
+ plt.gca()
1122
+ return range_fig, range_prob, range_z_lower, range_z_upper
1123
+ return (create_range_probability_example,)
1124
+
1125
+
1126
+ if __name__ == "__main__":
1127
+ app.run()
probability/18_central_limit_theorem.py ADDED
@@ -0,0 +1,943 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "matplotlib==3.10.1",
6
+ # "scipy==1.15.2",
7
+ # "numpy==2.2.4",
8
+ # "plotly==5.18.0",
9
+ # ]
10
+ # ///
11
+
12
+ import marimo
13
+
14
+ __generated_with = "0.11.30"
15
+ app = marimo.App(width="medium", app_title="Central Limit Theorem")
16
+
17
+
18
+ @app.cell(hide_code=True)
19
+ def _(mo):
20
+ mo.md(
21
+ r"""
22
+ # Central Limit Theorem
23
+
24
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part4/clt/), by Stanford professor Chris Piech._
25
+
26
+ The Central Limit Theorem (CLT) is one of the most important concepts in probability theory and statistics. It explains why many real-world distributions tend to be normal, even when the underlying processes are not.
27
+ """
28
+ )
29
+ return
30
+
31
+
32
+ @app.cell(hide_code=True)
33
+ def _(mo):
34
+ mo.md(
35
+ r"""
36
+ ## Central Limit Theorem Statement
37
+
38
+ There are two ways to state the central limit theorem:
39
+
40
+ ### Sum Version
41
+
42
+ Let $X_1, X_2, \dots, X_n$ be independent and identically distributed random variables. The sum of these random variables approaches a normal distribution as $n \rightarrow \infty$:
43
+
44
+ $$\sum_{i=1}^{n}X_i \sim \mathcal{N}(n \cdot \mu, n \cdot \sigma^2)$$
45
+
46
+ Where $\mu = E[X_i]$ and $\sigma^2 = \text{Var}(X_i)$. Since each $X_i$ is identically distributed, they share the same expectation and variance.
47
+
48
+ ### Average Version
49
+
50
+ Let $X_1, X_2, \dots, X_n$ be independent and identically distributed random variables. The average of these random variables approaches a normal distribution as $n \rightarrow \infty$:
51
+
52
+ $$\frac{1}{n}\sum_{i=1}^{n}X_i \sim \mathcal{N}\left(\mu, \frac{\sigma^2}{n}\right)$$
53
+
54
+ Where $\mu = E[X_i]$ and $\sigma^2 = \text{Var}(X_i)$.
55
+
56
+ The CLT is incredible because it applies to almost any distribution (as long as it has a finite mean and variance), regardless of its shape.
57
+ """
58
+ )
59
+ return
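As a quick numerical illustration of the sum version (a sketch added here, not part of the original notebook; the distribution and constants are chosen arbitrarily), summing n i.i.d. Exponential variables and comparing the empirical mean and variance with the CLT's nμ and nσ²:

```python
import numpy as np

rng = np.random.default_rng(0)
n, lam = 50, 2.0                 # 50 Exponential(lambda = 2) variables per sum
mu, var = 1 / lam, 1 / lam**2    # mean and variance of one X_i

sums = rng.exponential(1 / lam, size=(100_000, n)).sum(axis=1)

print(sums.mean(), n * mu)   # both ≈ 25
print(sums.var(), n * var)   # both ≈ 12.5
```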
60
+
61
+
62
+ @app.cell(hide_code=True)
63
+ def _(mo):
64
+ mo.md(
65
+ r"""
66
+ ## Central Limit Theorem Intuition
67
+
68
+ Let's explore what happens when you add random variables together. For example, what if we add 100 different uniform random variables?
69
+
70
+ ```python
71
+ from random import random
72
+
73
+ def add_100_uniforms():
74
+ total = 0
75
+ for i in range(100):
76
+ # returns a sample from uniform(0, 1)
77
+ x_i = random()
78
+ total += x_i
79
+ return total
80
+ ```
81
+
82
+ The value returned by this function will be a random variable. Click the button below to run the function and observe the resulting value of total:
83
+ """
84
+ )
85
+ return
86
+
87
+
88
+ @app.cell(hide_code=True)
89
+ def _(mo):
90
+ run_button = mo.ui.run_button(label="Run add_100_uniforms()")
91
+
92
+ run_button.center()
93
+ return (run_button,)
94
+
95
+
96
+ @app.cell(hide_code=True)
97
+ def _(mo, random, run_button):
98
+ def add_100_uniforms():
99
+ total = 0
100
+ for i in range(100):
101
+ # returns a sample from uniform(0, 1)
102
+ x_i = random.random()
103
+ total += x_i
104
+ return total
105
+
106
+ # Display the result when the button is clicked
107
+ if run_button.value:
108
+ uniform_result = add_100_uniforms()
109
+ display = mo.md(f"**total**: {uniform_result:.5f}")
110
+ else:
111
+ uniform_result = None  # keep the name defined before the button is first clicked
+ display = mo.md("")
112
+
113
+ display
114
+ return add_100_uniforms, display, uniform_result
115
+
116
+
117
+ @app.cell(hide_code=True)
118
+ def _(mo):
119
+ mo.md(r"""What does total look like as a distribution? Let's calculate total many times and visualize the histogram of values it produces.""")
120
+ return
121
+
122
+
123
+ @app.cell(hide_code=True)
124
+ def _(mo):
125
+ # Simulation control
126
+ run_simulation_button = mo.ui.button(
127
+ value=0,
128
+ on_click=lambda value: value + 1,
129
+ label="Run 10,000 more samples",
130
+ kind="warn"
131
+ )
132
+
133
+ run_simulation_button.center()
134
+ return (run_simulation_button,)
135
+
136
+
137
+ @app.cell(hide_code=True)
138
+ def _(add_100_uniforms, go, mo, np, run_simulation_button, stats, time):
139
+ # store the results
140
+ def get_simulation_results():
141
+ if not hasattr(get_simulation_results, "results"):
142
+ get_simulation_results.results = []
143
+ get_simulation_results.last_button_value = -1 # track button clicks
144
+ return get_simulation_results
145
+
146
+ # grab the results
147
+ sim_storage = get_simulation_results()
148
+ simulation_results = sim_storage.results
149
+
150
+ # Check if button was clicked (value changed)
151
+ if run_simulation_button.value != sim_storage.last_button_value:
152
+ # Update the last seen button value
153
+ sim_storage.last_button_value = run_simulation_button.value
154
+
155
+ with mo.status.spinner(title="Running simulation...") as progress_status:
156
+ sim_count = 10000
157
+ new_results = []
158
+ for _ in mo.status.progress_bar(range(sim_count)):
159
+ sim_result = add_100_uniforms()
160
+ new_results.append(sim_result)
161
+ time.sleep(0.0001) # tiny pause
162
+
163
+ simulation_results.extend(new_results)
164
+
165
+ progress_status.update(f"✅ Added {sim_count:,} samples (total: {len(simulation_results):,})")
166
+
167
+ if simulation_results:
168
+ # Numbers
169
+ mean = np.mean(simulation_results)
170
+ std_dev = np.std(simulation_results)
171
+
172
+ theoretical_mean = 100 * 0.5 # = 50
173
+ theoretical_variance = 100 * (1/12) # = 8.33...
174
+ theoretical_std = np.sqrt(theoretical_variance) # ≈ 2.89
175
+
176
+ # should be 10k times the click number (mainly for the y-axis label)
177
+ total_samples = run_simulation_button.value * 10000
178
+
179
+ fig = go.Figure()
180
+
181
+ # histogram of samples
182
+ fig.add_trace(go.Histogram(
183
+ x=simulation_results,
184
+ histnorm='probability density',
185
+ name='Sum Distribution',
186
+ marker_color='royalblue',
187
+ opacity=0.7
188
+ ))
189
+
190
+ x_vals = np.linspace(min(simulation_results), max(simulation_results), 1000)
191
+ y_vals = stats.norm.pdf(x_vals, theoretical_mean, theoretical_std)
192
+
193
+ fig.add_trace(go.Scatter(
194
+ x=x_vals,
195
+ y=y_vals,
196
+ mode='lines',
197
+ name='Normal approximation',
198
+ line=dict(color='red', width=2)
199
+ ))
200
+
201
+ fig.add_vline(
202
+ x=mean,
203
+ line_dash="dash",
204
+ line_width=1.5,
205
+ line_color="green",
206
+ annotation_text=f"Sample Mean: {mean:.2f}",
207
+ annotation_position="top right"
208
+ )
209
+
210
+ # some notes
211
+ fig.add_annotation(
212
+ x=0.02, y=0.95,
213
+ xref="paper", yref="paper",
214
+ text=f"Sum of 100 Uniform(0,1) variables<br>" +
215
+ f"Sample size: {total_samples:,}<br>" +
216
+ f"Sample mean: {mean:.2f} (expected: {theoretical_mean})<br>" +
217
+ f"Sample std: {std_dev:.2f} (expected: {theoretical_std:.2f})<br>" +
218
+ f"According to CLT: Normal({theoretical_mean}, {theoretical_variance:.2f})",
219
+ showarrow=False,
220
+ align="left",
221
+ bgcolor="white",
222
+ opacity=0.8
223
+ )
224
+
225
+ fig.update_layout(
226
+ title=f'Distribution of Sum of 100 Uniforms (Click #{run_simulation_button.value})',
227
+ xaxis_title='Values',
228
+ yaxis_title=f'Probability Density ({total_samples:,} runs)',
229
+ template='plotly_white',
230
+ height=500
231
+ )
232
+
233
+ # show
234
+ histogram = mo.ui.plotly(fig)
235
+ else:
236
+ histogram = mo.md("Click the button to run the simulation!")
237
+
238
+ # display
239
+ histogram
240
+ return (
241
+ fig,
242
+ get_simulation_results,
243
+ histogram,
244
+ mean,
245
+ new_results,
246
+ progress_status,
247
+ sim_count,
248
+ sim_result,
249
+ sim_storage,
250
+ simulation_results,
251
+ std_dev,
252
+ theoretical_mean,
253
+ theoretical_std,
254
+ theoretical_variance,
255
+ total_samples,
256
+ x_vals,
257
+ y_vals,
258
+ )
259
+
260
+
261
+ @app.cell(hide_code=True)
262
+ def _(mo):
263
+ mo.md(
264
+ r"""
265
+ That is interesting! The sum of 100 independent uniforms looks normal. Is that a special property of uniforms? No! It turns out to work for almost any type of distribution (as long as the distribution has finite mean and variance).
266
+
267
+ - Sum of 40 $X_i$ where $X_i \sim \text{Beta}(a = 5, b = 4)$? Normal.
268
+ - Sum of 90 $X_i$ where $X_i \sim \text{Poisson}(\lambda = 4)$? Normal.
269
+ - Sum of 50 dice-rolls? Normal.
270
+ - Average of 10000 $X_i$ where $X_i \sim \text{Exp}(\lambda = 8)$? Normal.
271
+
272
+ For any distribution, the sum or average of a sufficiently large number of independent, identically distributed random variables will be approximately normally distributed.
273
+ """
274
+ )
275
+ return
276
+
277
+
278
+ @app.cell(hide_code=True)
279
+ def _(mo):
280
+ mo.md(
281
+ r"""
282
+ ## Continuity Correction
283
+
284
+ When using the Central Limit Theorem with discrete random variables (like a Binomial or Poisson), we need to apply a continuity correction. This is because we're approximating a discrete distribution with a continuous one (normal).
285
+
286
+ The continuity correction involves adjusting the boundaries in probability calculations by ±0.5 to account for the discrete nature of the original variable.
287
+
288
+ You should use a continuity correction any time your normal is approximating a discrete random variable. The rules for a general continuity correction are the same as the rules for the [binomial-approximation continuity correction](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/14_binomial_distribution.py).
289
+
290
+ In our example above, where we added 100 uniforms, a continuity correction isn't needed because the sum of uniforms is continuous. However, in examples with dice or other discrete distributions, a continuity correction would be necessary.
291
+ """
292
+ )
293
+ return
294
+
295
+
296
+ @app.cell(hide_code=True)
297
+ def _(mo):
298
+ mo.md(
299
+ r"""
300
+ ## Examples
301
+
302
+ Let's work through some practical examples to see how the Central Limit Theorem is applied.
303
+ """
304
+ )
305
+ return
306
+
307
+
308
+ @app.cell(hide_code=True)
309
+ def _(mo):
310
+ mo.md(
311
+ r"""
312
+ ### Example 1: Dice Game
313
+
314
+ You will roll a 6-sided dice 10 times. Let $X$ be the total value of all 10 dice: $X = X_1 + X_2 + \dots + X_{10}$. You win the game if $X \leq 25$ or $X \geq 45$. Use the central limit theorem to calculate the probability that you win.
315
+
316
+ Recall that for a single die roll $X_i$:
317
+
318
+ - $E[X_i] = 3.5$
319
+ - $\text{Var}(X_i) = \frac{35}{12}$
320
+
321
+ **Solution:**
322
+
323
+ Let $Y$ be the approximating normal distribution. By the Central Limit Theorem:
324
+
325
+ $$Y \sim \mathcal{N}(10 \cdot E[X_i], 10 \cdot \text{Var}(X_i))$$
326
+
327
+ Substituting in the known values:
328
+
329
+ $$Y \sim \mathcal{N}\left(10 \cdot 3.5, 10 \cdot \frac{35}{12}\right) = \mathcal{N}(35, 29.2)$$
330
+
331
+ Now we calculate the probability:
332
+
333
+ $$P(X \leq 25 \text{ or } X \geq 45)$$
334
+
335
+ $$= P(X \leq 25) + P(X \geq 45)$$
336
+
337
+ $$\approx P(Y < 25.5) + P(Y > 44.5) \text{ (Continuity Correction)}$$
338
+
339
+ $$\approx P(Y < 25.5) + [1 - P(Y < 44.5)]$$
340
+
341
+ $$\approx \Phi\left(\frac{25.5 - 35}{\sqrt{29.2}}\right) + \left[1 - \Phi\left(\frac{44.5 - 35}{\sqrt{29.2}}\right)\right]$$
342
+
343
+ $$\approx \Phi(-1.76) + [1 - \Phi(1.76)]$$
344
+
345
+ $$\approx 0.039 + (1 - 0.961)$$
346
+
347
+ $$\approx 0.078$$
348
+ So, the probability of winning the game is approximately 7.8%.
349
+ """
350
+ )
351
+ return
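The 7.8% figure can be sanity-checked numerically (a sketch added for illustration, not part of the notebook), both with the CLT approximation used above and with an exact convolution of the dice PMF:

```python
import numpy as np
from scipy import stats

# CLT approximation with continuity correction
mu, var = 10 * 3.5, 10 * 35 / 12
sigma = np.sqrt(var)
p_clt = stats.norm.cdf(25.5, mu, sigma) + (1 - stats.norm.cdf(44.5, mu, sigma))

# Exact answer: convolve the PMF of one fair die with itself ten times
die = np.full(6, 1 / 6)
pmf = np.array([1.0])
for _ in range(10):
    pmf = np.convolve(pmf, die)
support = np.arange(10, 61)  # possible totals of 10 dice
p_exact = pmf[(support <= 25) | (support >= 45)].sum()

print(round(p_clt, 3), round(p_exact, 3))  # both close to the 7.8% computed above
```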
352
+
353
+
354
+ @app.cell(hide_code=True)
355
+ def _(create_dice_game_visualization, fig_to_image, mo):
356
+ # Display visualization
357
+ dice_game_fig = create_dice_game_visualization()
358
+ dice_game_image = mo.image(fig_to_image(dice_game_fig), width="100%")
359
+
360
+ dice_explanation = mo.md(
361
+ r"""
362
+ **Visualization Explanation:**
363
+
364
+ The graph shows the distribution of the sum of 10 dice rolls. The blue bars represent the actual probability mass function (PMF), while the red curve shows the normal approximation using the Central Limit Theorem.
365
+
366
+ The winning regions are shaded in orange:
367
+ - The left region where $X \leq 25$
368
+ - The right region where $X \geq 45$
369
+
370
+ The total probability of these regions is approximately 0.078 or 7.8%.
371
+
372
+ Notice how the normal approximation provides a good fit to the discrete distribution, demonstrating the power of the Central Limit Theorem.
373
+ """
374
+ )
375
+
376
+ mo.vstack([dice_game_image, dice_explanation])
377
+ return dice_explanation, dice_game_fig, dice_game_image
378
+
379
+
380
+ @app.cell(hide_code=True)
381
+ def _(mo):
382
+ mo.md(
383
+ r"""
384
+ ### Example 2: Algorithm Runtime Estimation
385
+
386
+ Say you have a new algorithm and you want to test its running time. You know the variance of the algorithm's run time is $\sigma^2 = 4 \text{ sec}^2$, but you want to estimate the mean run time $t$ in seconds.
387
+
388
+ You can run the algorithm repeatedly (IID trials). How many trials do you have to run so that your estimated runtime is within ±0.5 seconds of $t$ with 95% certainty?
389
+
390
+ Let $X_i$ be the run time of the $i$-th run (for $1 \leq i \leq n$).
391
+
392
+ **Solution:**
393
+
394
+ We need to find $n$ such that:
395
+
396
+ $$0.95 = P\left(-0.5 \leq \frac{\sum_{i=1}^n X_i}{n} - t \leq 0.5\right)$$
397
+
398
+ By the central limit theorem, the sample mean follows a normal distribution.
399
+ We can standardize this to work with the standard normal:
400
+
401
+ $$Z = \frac{\left(\sum_{i=1}^n X_i\right) - n\mu}{\sigma \sqrt{n}}$$
402
+
403
+ $$= \frac{\left(\sum_{i=1}^n X_i\right) - nt}{2 \sqrt{n}}$$
404
+
405
+ Rewriting our probability inequality so that the central term is $Z$:
406
+
407
+ $$0.95 = P\left(-0.5 \leq \frac{\sum_{i=1}^n X_i}{n} - t \leq 0.5\right)$$
408
+
409
+ $$= P\left(\frac{-0.5 \sqrt{n}}{2} \leq Z \leq \frac{0.5 \sqrt{n}}{2}\right)$$
410
+
411
+ And now we find the value of $n$ that makes this equation hold:
412
+
413
+ $$0.95 = \Phi\left(\frac{\sqrt{n}}{4}\right) - \Phi\left(-\frac{\sqrt{n}}{4}\right)$$
414
+
415
+ $$= \Phi\left(\frac{\sqrt{n}}{4}\right) - \left(1 - \Phi\left(\frac{\sqrt{n}}{4}\right)\right)$$
416
+
417
+ $$= 2\Phi\left(\frac{\sqrt{n}}{4}\right) - 1$$
418
+
419
+ Solving for $\Phi\left(\frac{\sqrt{n}}{4}\right)$:
420
+
421
+ $$0.975 = \Phi\left(\frac{\sqrt{n}}{4}\right)$$
422
+
423
+ $$\Phi^{-1}(0.975) = \frac{\sqrt{n}}{4}$$
424
+
425
+ $$1.96 = \frac{\sqrt{n}}{4}$$
426
+
427
+ $$n = 61.4$$
428
+
429
+ Therefore, we need to run the algorithm 62 times to estimate the mean runtime within ±0.5 seconds with 95% confidence.
430
+ """
431
+ )
432
+ return
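The sample-size calculation above can be reproduced directly (a short sketch, not part of the diff, assuming SciPy's `norm.ppf` for Φ⁻¹):

```python
import math
from scipy import stats

sigma = 2.0                # sqrt of the known variance, 4 sec^2
margin = 0.5               # desired half-width of the estimate, in seconds
z = stats.norm.ppf(0.975)  # Phi^{-1}(0.975) ≈ 1.96 for 95% confidence

n = (z * sigma / margin) ** 2
print(round(n, 1), math.ceil(n))  # ≈ 61.5, so run 62 trials
```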
433
+
434
+
435
+ @app.cell(hide_code=True)
436
+ def _(create_algorithm_runtime_visualization, fig_to_image, mo):
437
+ # Display visualization
438
+ runtime_fig = create_algorithm_runtime_visualization()
439
+ runtime_image = mo.image(fig_to_image(runtime_fig), width="100%")
440
+
441
+ runtime_explanation = mo.md(
442
+ r"""
443
+ **Visualization Explanation:**
444
+
445
+ The graph illustrates how the standard error of the mean (SEM) decreases as the number of trials increases. The standard error is calculated as $\frac{\sigma}{\sqrt{n}}$.
446
+
447
+ - When we conduct 62 trials, the standard error is approximately 0.254 seconds.
448
+ - With a 95% confidence level, this gives us a margin of error of about ±0.5 seconds (1.96 × 0.254 ≈ 0.5).
449
+ - The shaded region shows how the confidence interval narrows as the number of trials increases.
450
+
451
+ This demonstrates why 62 trials are sufficient to meet our requirements of estimating the mean runtime within ±0.5 seconds with 95% confidence.
452
+ """
453
+ )
454
+
455
+ mo.vstack([runtime_image, runtime_explanation])
456
+ return runtime_explanation, runtime_fig, runtime_image
457
+
458
+
459
+ @app.cell(hide_code=True)
460
+ def _(mo):
461
+ mo.md(
462
+ r"""
463
+ ## Interactive CLT Explorer
464
+
465
+ Let's explore how the Central Limit Theorem works with different underlying distributions. You can select a distribution type and see how the distribution of the sample mean changes as the sample size increases.
466
+ """
467
+ )
468
+ return
469
+
470
+
471
+ @app.cell(hide_code=True)
472
+ def _(controls):
473
+ controls
474
+ return
475
+
476
+
477
+ @app.cell(hide_code=True)
478
+ def _(
479
+ distribution_type,
480
+ fig_to_image,
481
+ mo,
482
+ np,
483
+ plt,
484
+ run_explorer_button,
485
+ sample_size,
486
+ sim_count_slider,
487
+ stats,
488
+ ):
489
+ # Run simulation when button is clicked
490
+ if run_explorer_button.value:
491
+ # Set distribution parameters based on selection
492
+ if distribution_type.value == "uniform":
493
+ dist_name = "Uniform(0, 1)"
494
+ # For uniform(0,1): mean = 0.5, variance = 1/12
495
+ true_mean = 0.5
496
+ true_var = 1/12
497
+
498
+ # generate samples
499
+ def generate_sample():
500
+ return np.random.uniform(0, 1, sample_size.value)
501
+
502
+ elif distribution_type.value == "exponential":
503
+ rate = 1.0
504
+ dist_name = f"Exponential(λ={rate})"
505
+ # For exponential(λ): mean = 1/λ, variance = 1/λ²
506
+ true_mean = 1/rate
507
+ true_var = 1/(rate**2)
508
+
509
+ def generate_sample():
510
+ return np.random.exponential(1/rate, sample_size.value)
511
+
512
+ elif distribution_type.value == "binomial":
513
+ n_param, p = 10, 0.3
514
+ dist_name = f"Binomial(n={n_param}, p={p})"
515
+ # For binomial(n,p): mean = np, variance = np(1-p)
516
+ true_mean = n_param * p
517
+ true_var = n_param * p * (1-p)
518
+
519
+ def generate_sample():
520
+ return np.random.binomial(n_param, p, sample_size.value)
521
+
522
+ elif distribution_type.value == "poisson":
523
+ rate = 3.0
524
+ dist_name = f"Poisson(λ={rate})"
525
+ # For poisson(λ): mean = λ, variance = λ
526
+ true_mean = rate
527
+ true_var = rate
528
+
529
+ def generate_sample():
530
+ return np.random.poisson(rate, sample_size.value)
531
+
532
+ # Generate the simulation data using a spinner for progress
533
+ with mo.status.spinner(title="Running simulation...") as explorer_progress:
534
+ sample_means = []
535
+ original_samples = []
536
+
537
+ # Run simulations
538
+ for _ in mo.status.progress_bar(range(sim_count_slider.value)):
539
+ sample = generate_sample()
540
+
541
+ # Store the first simulation's individual values for visualizing original distribution
542
+ if len(original_samples) < 1000: # limit to prevent memory issues
543
+ original_samples.extend(sample)
544
+
545
+ # sample mean
546
+ sample_means.append(np.mean(sample))
547
+
548
+ # progress
549
+ explorer_progress.update(f"✅ Completed {sim_count_slider.value:,} simulations")
550
+
551
+ # Create visualization
552
+ explorer_fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
553
+
554
+ # Original distribution histogram
555
+ ax1.hist(original_samples, bins=30, density=True, alpha=0.7, color='royalblue')
556
+ ax1.set_title(f"Original Distribution: {dist_name}")
557
+
558
+ # Theoretical mean line
559
+ ax1.axvline(x=true_mean, color='red', linestyle='--',
560
+ label=f'True Mean = {true_mean:.3f}')
561
+
562
+ ax1.set_xlabel("Value")
563
+ ax1.set_ylabel("Density")
564
+ ax1.legend()
565
+
566
+ # Sample means histogram and normal approximation
567
+ sample_mean_mean = np.mean(sample_means)
568
+ sample_mean_std = np.std(sample_means)
569
+ expected_std = np.sqrt(true_var / sample_size.value) # CLT prediction
570
+
571
+ ax2.hist(sample_means, bins=30, density=True, alpha=0.7, color='forestgreen',
572
+ label=f'Sample Size = {sample_size.value}')
573
+
574
+ # Normal approximation from CLT
575
+ explorer_x = np.linspace(min(sample_means), max(sample_means), 1000)
576
+ explorer_y = stats.norm.pdf(explorer_x, true_mean, expected_std)
577
+ ax2.plot(explorer_x, explorer_y, 'r-', linewidth=2, label='CLT Normal Approximation')
578
+
579
+ # Add mean line
580
+ ax2.axvline(x=true_mean, color='purple', linestyle='--',
581
+ label=f'True Mean = {true_mean:.3f}')
582
+
583
+ ax2.set_title(f"Distribution of Sample Means\n(CLT Prediction: N({true_mean:.3f}, {true_var/sample_size.value:.5f}))")
584
+ ax2.set_xlabel("Sample Mean")
585
+ ax2.set_ylabel("Density")
586
+ ax2.legend()
587
+
588
+ # Add CLT description
589
+ explorer_fig.text(0.5, 0.01,
590
+ f"Central Limit Theorem: As sample size increases, the distribution of sample means approaches\n" +
591
+ f"a normal distribution with mean = {true_mean:.3f} and variance = {true_var:.3f}/{sample_size.value} = {true_var/sample_size.value:.5f}",
592
+ ha='center', fontsize=10, bbox=dict(facecolor='white', alpha=0.8))
593
+
594
+ plt.tight_layout(rect=[0, 0.05, 1, 1])
595
+
596
+ # Display plot
597
+ explorer_image = mo.image(fig_to_image(explorer_fig), width="100%")
598
+ else:
599
+ explorer_image = mo.md("Click the 'Run Simulation' button to see how the Central Limit Theorem works.")
600
+
601
+ explorer_image
602
+ return (
603
+ ax1,
604
+ ax2,
605
+ dist_name,
606
+ expected_std,
607
+ explorer_fig,
608
+ explorer_image,
609
+ explorer_progress,
610
+ explorer_x,
611
+ explorer_y,
612
+ generate_sample,
613
+ n_param,
614
+ original_samples,
615
+ p,
616
+ rate,
617
+ sample,
618
+ sample_mean_mean,
619
+ sample_mean_std,
620
+ sample_means,
621
+ true_mean,
622
+ true_var,
623
+ )
624
+
625
+
626
+ @app.cell(hide_code=True)
627
+ def _(mo):
628
+ mo.md(
629
+ r"""
630
+ ## 🤔 Test Your Understanding
631
+
632
+ /// details | What is the shape of the distribution of the sum of many independent random variables?
633
+ The sum of many independent random variables approaches a normal distribution, regardless of the shape of the original distributions (as long as they have finite mean and variance). This is the essence of the Central Limit Theorem.
634
+ ///
635
+
636
+ /// details | If $X_1, X_2, \dots, X_{100}$ are IID random variables with $E[X_i] = 5$ and $Var(X_i) = 9$, what is the distribution of their sum?
637
+ By the Central Limit Theorem, the sum $S = X_1 + X_2 + \dots + X_{100}$ follows a normal distribution with:
638
+
639
+ - Mean: $E[S] = 100 \cdot E[X_i] = 100 \cdot 5 = 500$
640
+ - Variance: $Var(S) = 100 \cdot Var(X_i) = 100 \cdot 9 = 900$
641
+
642
+ Therefore, $S \sim \mathcal{N}(500, 900)$, or equivalently $S \sim \mathcal{N}(500, 30^2)$.
643
+ ///
644
+
645
+ /// details | When do you need to apply a continuity correction when using the Central Limit Theorem?
646
+ You need to apply a continuity correction when you're using the normal approximation (through CLT) for a discrete random variable.
647
+
648
+ For example, when approximating a binomial or Poisson distribution with a normal distribution, you should adjust boundaries by ±0.5 to account for the discrete nature of the original variable. This makes the approximation more accurate.
649
+ ///
650
+
651
+ /// details | If $X_1, X_2, \dots, X_{n}$ are IID random variables, how does the variance of their sample mean $\bar{X} = \frac{1}{n}\sum_{i=1}^{n}X_i$ change as $n$ increases?
652
+ The variance of the sample mean decreases as the sample size $n$ increases. Specifically:
653
+
654
+ $Var(\bar{X}) = \frac{Var(X_i)}{n}$
655
+
656
+ This means that as we take more samples, the sample mean becomes more concentrated around the true mean of the distribution. This is why larger samples give more precise estimates.
657
+ ///
658
+
659
+ /// details | Why is the Central Limit Theorem so important in statistics?
660
+ The Central Limit Theorem is foundational in statistics because:
661
+
662
+ 1. It allows us to make inferences about population parameters using sample statistics, regardless of the population's distribution.
663
+ 2. It explains why the normal distribution appears so frequently in natural phenomena.
664
+ 3. It enables the construction of confidence intervals and hypothesis tests for means, even when the underlying population distribution is unknown.
665
+ 4. It justifies many statistical methods that assume normality, even when working with non-normal data, provided the sample size is large enough.
666
+
667
+ In essence, the CLT provides the theoretical justification for much of statistical inference.
668
+ ///
669
+ """
670
+ )
671
+ return
672
+
673
+
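The second question above can also be checked by simulation. A small sketch, assuming for concreteness that each $X_i$ is Normal(5, 9); by the CLT the particular choice of distribution matters little for the shape of the sum:

```python
import numpy as np

rng = np.random.default_rng(0)
n_vars, n_sims = 100, 100_000

# Each row is one realization of X_1 + ... + X_100 with E[X_i] = 5, Var(X_i) = 9
sums = rng.normal(loc=5, scale=3, size=(n_sims, n_vars)).sum(axis=1)

print(sums.mean())  # ~500
print(sums.var())   # ~900
```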
674
+ @app.cell(hide_code=True)
675
+ def _(mo):
676
+ mo.md(r"""## Appendix (helper code and functions)""")
677
+ return
678
+
679
+
680
+ @app.cell
681
+ def _():
682
+ import marimo as mo
683
+ return (mo,)
684
+
685
+
686
+ @app.cell(hide_code=True)
687
+ def _():
688
+ from wigglystuff import TangleSlider
689
+ return (TangleSlider,)
690
+
691
+
692
+ @app.cell(hide_code=True)
693
+ def _():
694
+ # Import libraries
695
+ import numpy as np
696
+ import matplotlib.pyplot as plt
697
+ from scipy import stats
698
+ import io
699
+ import base64
700
+ import random
701
+ import time
702
+ import plotly.graph_objects as go
703
+ import plotly.io as pio
704
+ return base64, go, io, np, pio, plt, random, stats, time
705
+
706
+
707
+ @app.cell(hide_code=True)
708
+ def _(base64, io):
709
+ from matplotlib.figure import Figure
710
+
711
+ # Helper function to convert matplotlib figures to images
712
+ def fig_to_image(fig):
713
+ buf = io.BytesIO()
714
+ fig.savefig(buf, format='png', bbox_inches='tight')
715
+ buf.seek(0)
716
+ img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
717
+ return f"data:image/png;base64,{img_str}"
718
+ return Figure, fig_to_image
719
+
720
+
721
+ @app.cell(hide_code=True)
722
+ def _(np, plt, stats):
723
+ def create_dice_game_visualization():
724
+ """Create a visualization for the dice game example."""
725
+ # Parameters
726
+ n_dice = 10
727
+ dice_values = np.arange(1, 7) # 1 to 6
728
+
729
+ # Theoretical values
730
+ single_die_mean = np.mean(dice_values) # 3.5
731
+ single_die_var = np.var(dice_values) # 35/12
732
+
733
+ # Sum distribution parameters
734
+ sum_mean = n_dice * single_die_mean
735
+ sum_var = n_dice * single_die_var
736
+ sum_std = np.sqrt(sum_var)
737
+
738
+ # Possible outcomes for the sum of 10 dice
739
+ min_sum = n_dice * min(dice_values) # 10
740
+ max_sum = n_dice * max(dice_values) # 60
741
+ sum_values = np.arange(min_sum, max_sum + 1)
742
+
743
+ # Create figure
744
+ fig, ax = plt.subplots(figsize=(10, 6))
745
+
746
+ # Calculate PMF through convolution
747
+ # For one die
748
+ single_pmf = np.ones(6) / 6
749
+
750
+ sum_pmf = single_pmf.copy()
751
+ for _ in range(n_dice - 1):
752
+ sum_pmf = np.convolve(sum_pmf, single_pmf)
753
+
754
+ # Plot the PMF
755
+ ax.bar(sum_values, sum_pmf, alpha=0.7, color='royalblue', label='Exact PMF')
756
+
757
+ # Normal approximation
758
+ x = np.linspace(min_sum - 5, max_sum + 5, 1000)
759
+ y = stats.norm.pdf(x, sum_mean, sum_std)
760
+ ax.plot(x, y, 'r-', linewidth=2, label='Normal Approximation')
761
+
762
+ # Win conditions (x ≤ 25 or x ≥ 45)
763
+ win_region_left = sum_values <= 25
764
+ win_region_right = sum_values >= 45
765
+
766
+ # Shade win regions
767
+ ax.bar(sum_values[win_region_left], sum_pmf[win_region_left],
768
+ color='darkorange', alpha=0.7, label='Win Region')
769
+ ax.bar(sum_values[win_region_right], sum_pmf[win_region_right],
770
+ color='darkorange', alpha=0.7)
771
+
772
+ # Calculate win probability
773
+ win_prob = np.sum(sum_pmf[win_region_left]) + np.sum(sum_pmf[win_region_right])
774
+
775
+ # Add vertical lines for critical values
776
+ ax.axvline(x=25.5, color='red', linestyle='--', linewidth=1.5, label='Critical Points')
777
+ ax.axvline(x=44.5, color='red', linestyle='--', linewidth=1.5)
778
+
779
+ # Add mean line
780
+ ax.axvline(x=sum_mean, color='green', linestyle='--', linewidth=1.5,
781
+ label=f'Mean = {sum_mean}')
782
+
783
+ # Text box with relevant information
784
+ textstr = '\n'.join((
785
+ f'Number of dice: {n_dice}',
786
+ f'Sum Mean: {sum_mean}',
787
+ f'Sum Std Dev: {sum_std:.2f}',
788
+ f'Win Probability: {win_prob:.4f}',
789
+ f'CLT Approximation: {0.078:.4f}'
790
+ ))
791
+ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
792
+ ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
793
+ verticalalignment='top', bbox=props)
794
+
795
+ # Formatting
796
+ ax.set_xlabel('Sum of 10 Dice')
797
+ ax.set_ylabel('Probability')
798
+ ax.set_title('Central Limit Theorem: Dice Game Example')
799
+ ax.legend()
800
+ ax.grid(alpha=0.3)
801
+
802
+ plt.tight_layout()
803
+ plt.gca()
804
+ return fig
805
+ return (create_dice_game_visualization,)
806
+
807
+
808
+ @app.cell(hide_code=True)
809
+ def _(np, plt):
810
+ def create_algorithm_runtime_visualization():
811
+ """Create a visualization for the algorithm runtime example."""
812
+ # Parameters
813
+ variance = 4 # σ² = 4 sec²
814
+ std_dev = np.sqrt(variance) # σ = 2 sec
815
+ confidence_level = 0.95
816
+ z_score = 1.96 # for 95% confidence
817
+ target_error = 0.5 # ±0.5 seconds
818
+
819
+ # Calculate n needed for desired precision
820
+ n_required = int(np.ceil((z_score * std_dev / target_error) ** 2)) # ≈ 62
821
+
822
+ n_values = np.arange(1, 100)
823
+
824
+ # standard error
825
+ standard_errors = std_dev / np.sqrt(n_values)
826
+
827
+ # margin of error
828
+ margins_of_error = z_score * standard_errors
829
+
830
+ # Create figure
831
+ fig, ax = plt.subplots(figsize=(10, 6))
832
+
833
+ # standard error vs sample size plot
834
+ ax.plot(n_values, standard_errors, 'b-', linewidth=2, label='Standard Error of Mean')
835
+
836
+ # Plot margin of error vs sample size
837
+ ax.plot(n_values, margins_of_error, 'r--', linewidth=2,
838
+ label=f'{confidence_level*100}% Margin of Error')
839
+
840
+ ax.axvline(x=n_required, color='green', linestyle='-', linewidth=1.5,
841
+ label=f'Required n = {n_required}')
842
+
843
+ ax.axhline(y=target_error, color='purple', linestyle='--', linewidth=1.5,
844
+ label=f'Target Error = ±{target_error} sec')
845
+
846
+ # Shade the region below target error
847
+ ax.fill_between(n_values, 0, target_error, alpha=0.2, color='green')
848
+
849
+ # intersection point
850
+ ax.plot(n_required, target_error, 'ro', markersize=8)
851
+ ax.annotate(f'({n_required}, {target_error} sec)',
852
+ xy=(n_required, target_error),
853
+ xytext=(n_required + 5, target_error + 0.1),
854
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
855
+
856
+ # Text box with appropriate information
857
+ textstr = '\n'.join((
858
+ f'Algorithm Variance: {variance} sec²',
859
+ f'Standard Deviation: {std_dev} sec',
860
+ f'Confidence Level: {confidence_level*100}%',
861
+ f'Z-score: {z_score}',
862
+ f'Target Error: ±{target_error} sec',
863
+ f'Required Sample Size: {n_required}'
864
+ ))
865
+ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
866
+ ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
867
+ verticalalignment='top', bbox=props)
868
+
869
+ # Formatting
870
+ ax.set_xlabel('Sample Size (n)')
871
+ ax.set_ylabel('Error (seconds)')
872
+ ax.set_title('Sample Size Determination for Algorithm Runtime Estimation')
873
+ ax.set_xlim(0, 100)
874
+ ax.set_ylim(0, 2)
875
+ ax.legend()
876
+ ax.grid(alpha=0.3)
877
+
878
+ plt.tight_layout()
879
+ return fig
880
+ return (create_algorithm_runtime_visualization,)
881
+
882
+
883
+ @app.cell(hide_code=True)
884
+ def _(mo):
885
+ mo.md(
886
+ r"""
887
+ ## Summary
888
+
889
+ The Central Limit Theorem is truly one of the most remarkable ideas in all of statistics. It tells us that when we add up many independent random variables, their sum will follow a normal distribution, regardless of what the original distributions looked like. This is why we see normal distributions so often in real life – many natural phenomena are the result of numerous small, independent factors adding up.
890
+
891
+ What makes the CLT so powerful is its universality. Whether we're working with dice rolls, measurement errors, or stock market returns, as long as we have enough independent samples, their average or sum will be approximately normal. For sums, the distribution will be $\mathcal{N}(n\mu, n\sigma^2)$, and for averages, it's $\mathcal{N}(\mu, \frac{\sigma^2}{n})$.
892
+
893
+ The CLT gives us the foundation for confidence intervals, hypothesis testing, and many other statistical tools. Without it, we'd have a much harder time making sense of data when we don't know the underlying population distribution. Just remember that if you're working with discrete distributions, you'll need to apply a continuity correction to get more accurate results.
894
+
895
+ Next time you see a normal distribution in data, think about the Central Limit Theorem – it might be the reason behind that familiar bell curve!
896
+ """
897
+ )
898
+ return
899
+
900
+
901
+ @app.cell(hide_code=True)
902
+ def _(mo):
903
+ # controls for the interactive explorer
904
+ distribution_type = mo.ui.dropdown(
905
+ options=["uniform", "exponential", "binomial", "poisson"],
906
+ value="uniform",
907
+ label="Distribution Type"
908
+ )
909
+
910
+ sample_size = mo.ui.slider(
911
+ start=1,
912
+ stop=100,
913
+ step=1,
914
+ value=30,
915
+ label="Sample Size (n)"
916
+ )
917
+
918
+ sim_count_slider = mo.ui.slider(
919
+ start=100,
920
+ stop=10000,
921
+ step=100,
922
+ value=1000,
923
+ label="Number of Simulations"
924
+ )
925
+
926
+ run_explorer_button = mo.ui.run_button(label="Run Simulation", kind="warn")
927
+
928
+ controls = mo.hstack([
929
+ mo.vstack([distribution_type, sample_size, sim_count_slider]),
930
+ run_explorer_button
931
+ ], justify='space-around')
932
+
933
+ return (
934
+ controls,
935
+ distribution_type,
936
+ run_explorer_button,
937
+ sample_size,
938
+ sim_count_slider,
939
+ )
940
+
941
+
942
+ if __name__ == "__main__":
943
+ app.run()
probability/19_maximum_likelihood_estimation.py ADDED
@@ -0,0 +1,1231 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "matplotlib==3.10.1",
6
+ # "scipy==1.15.2",
7
+ # "numpy==2.2.4",
8
+ # "polars==0.20.2",
9
+ # "plotly==5.18.0",
10
+ # ]
11
+ # ///
12
+
13
+ import marimo
14
+
15
+ __generated_with = "0.12.0"
16
+ app = marimo.App(width="medium", app_title="Maximum Likelihood Estimation")
17
+
18
+
19
+ @app.cell(hide_code=True)
20
+ def _(mo):
21
+ mo.md(
22
+ r"""
23
+ # Maximum Likelihood Estimation
24
+
25
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part5/mle/), by Stanford professor Chris Piech._
26
+
27
+ Maximum Likelihood Estimation (MLE) is a fundamental method in statistics for estimating parameters of a probability distribution. The central idea is elegantly simple: **choose the parameters that make the observed data most likely**.
28
+
29
+ In this notebook, we'll work through MLE, starting with the core concept of likelihood and how it differs from probability. We'll see how to formulate MLE problems mathematically and then solve them for several common distributions. Along the way, interactive visualizations help build intuition for these concepts. You'll see how MLE applies to real-world scenarios like linear regression, and hopefully gain a deeper appreciation for why this technique is so widely used in statistics and machine learning. Think of MLE as detective work: we have some evidence (our data) and we're trying to figure out the most plausible explanation (our parameters) for what we've observed.
30
+ """
31
+ )
32
+ return
33
+
34
+
35
+ @app.cell(hide_code=True)
36
+ def _(mo):
37
+ mo.md(
38
+ r"""
39
+ ## Likelihood: The Core Concept
40
+
41
+ Before diving into MLE, we need to understand what "likelihood" means in a statistical context.
42
+
43
+ ### Data and Parameters
44
+
45
+ Suppose we have collected some data $X_1, X_2, \ldots, X_n$ that are independent and identically distributed (IID). We assume these data points come from a specific type of distribution (like Normal, Bernoulli, etc.) with unknown parameters $\theta$.
46
+
47
+ ### What is Likelihood?
48
+
49
+ Likelihood measures how probable our observed data is, given specific values of the parameters $\theta$.
50
+
51
+ /// note
52
+ **Probability vs. Likelihood**
53
+
54
+ - **Probability**: Given parameters $\theta$, what's the chance of observing data $X$?
55
+ - **Likelihood**: Given observed data $X$, how likely are different parameter values $\theta$?
56
+ ///
57
+
58
+ To simplify notation, we'll use $f(X=x|\Theta=\theta)$ to represent either the PMF or PDF of our data, conditioned on the parameters.
59
+ """
60
+ )
61
+ return
62
+
63
+
64
+ @app.cell(hide_code=True)
65
+ def _(mo):
66
+ mo.md(
67
+ r"""
68
+ ### The Likelihood Function
69
+
70
+ Since we assume our data points are independent, the likelihood of all our data is the product of the likelihoods of each individual data point:
71
+
72
+ $$L(\theta) = \prod_{i=1}^n f(X_i = x_i|\Theta = \theta)$$
73
+
74
+ This function $L(\theta)$ gives us the likelihood of observing our entire dataset for different parameter values $\theta$.
75
+
76
+ /// tip
77
+ **Key Insight**: Different parameter values produce different likelihoods for the same data. Better parameter values will make the observed data more likely.
78
+ ///
79
+ """
80
+ )
81
+ return
82
+
83
+
84
+ @app.cell(hide_code=True)
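To make the likelihood function concrete, here is a small standalone sketch (hypothetical data, not part of the notebook) that evaluates $L(p)$ for a fixed set of Bernoulli observations at two candidate parameter values:

```python
import numpy as np

data = np.array([1, 0, 1, 1, 0, 1, 1, 1])  # 6 successes out of 8 trials

def bernoulli_likelihood(p, xs):
    # L(p) = prod_i p^{x_i} * (1 - p)^{1 - x_i}
    return np.prod(p**xs * (1 - p)**(1 - xs))

print(bernoulli_likelihood(0.50, data))  # ~0.0039
print(bernoulli_likelihood(0.75, data))  # ~0.0111: the data are more likely under p = 0.75
```

A higher likelihood means that parameter value explains the observed data better; MLE simply searches for the value where this function peaks.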
85
+ def _(mo):
86
+ mo.md(
87
+ r"""
88
+ ## Maximum Likelihood Estimation
89
+
90
+ The core idea of MLE is to find the parameter values $\hat{\theta}$ that maximize the likelihood function:
91
+
92
+ $$\hat{\theta} = \underset{\theta}{\operatorname{argmax}} \, L(\theta)$$
93
+
94
+ The notation $\hat{\theta}$ represents our best estimate of the true parameters based on the observed data.
95
+
96
+ ### Working with Log-Likelihood
97
+
98
+ In practice, we usually work with the **log-likelihood** instead of the likelihood directly. Since logarithm is a monotonically increasing function, the maximum of $L(\theta)$ occurs at the same value of $\theta$ as the maximum of $\log L(\theta)$.
99
+
100
+ Taking the logarithm transforms our product into a sum, which is much easier to work with:
101
+
102
+ $$LL(\theta) = \log L(\theta) = \log \prod_{i=1}^n f(X_i=x_i|\Theta = \theta) = \sum_{i=1}^n \log f(X_i = x_i|\Theta = \theta)$$
103
+
104
+ /// warning
105
+ Working with products of many small probabilities can lead to numerical underflow. Taking the logarithm converts these products to sums, which is numerically more stable.
106
+ ///
107
+ """
108
+ )
109
+ return
110
+
111
+
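The numerical-stability point is easy to see directly. A minimal sketch with illustrative values:

```python
import numpy as np

# 1000 IID observations, each contributing a density/mass value of about 1e-3
densities = np.full(1000, 1e-3)

likelihood = np.prod(densities)             # underflows to exactly 0.0 in float64
log_likelihood = np.sum(np.log(densities))  # ~ -6907.76, perfectly representable

print(likelihood, log_likelihood)
```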
112
+ @app.cell(hide_code=True)
113
+ def _(mo):
114
+ mo.md(
115
+ r"""
116
+ ### Finding the Maximum
117
+
118
+ To find the values of $\theta$ that maximize the log-likelihood, we typically:
119
+
120
+ 1. Take the derivative of $LL(\theta)$ with respect to each parameter
121
+ 2. Set each derivative equal to zero
122
+ 3. Solve for the parameters
123
+
124
+ Let's see this approach in action with some common distributions.
125
+ """
126
+ )
127
+ return
128
+
129
+
130
+ @app.cell(hide_code=True)
131
+ def _(mo):
132
+ mo.md(
133
+ r"""
134
+ ## MLE for Bernoulli Distribution
135
+
136
+ Let's start with a simple example: estimating the parameter $p$ of a Bernoulli distribution.
137
+
138
+ ### The Model
139
+
140
+ A Bernoulli distribution has a single parameter $p$ which represents the probability of success (getting a value of 1). Its probability mass function (PMF) can be written as:
141
+
142
+ $$f(x|p) = p^x(1-p)^{1-x}, \quad x \in \{0, 1\}$$
143
+
144
+ This elegant formula works because:
145
+
146
+ - When $x = 1$: $f(1|p) = p^1(1-p)^0 = p$
147
+ - When $x = 0$: $f(0|p) = p^0(1-p)^1 = 1-p$
148
+
149
+ ### Deriving the MLE
150
+
151
+ Given $n$ independent Bernoulli trials $X_1, X_2, \ldots, X_n$, we want to find the value of $p$ that maximizes the likelihood of our observed data.
152
+
153
+ Step 1: Write the likelihood function
154
+ $$L(p) = \prod_{i=1}^n p^{x_i}(1-p)^{1-x_i}$$
155
+
156
+ Step 2: Take the logarithm to get the log-likelihood
157
+ $$\begin{align*}
158
+ LL(p) &= \sum_{i=1}^n \log(p^{x_i}(1-p)^{1-x_i}) \\
159
+ &= \sum_{i=1}^n \left[x_i \log(p) + (1-x_i)\log(1-p)\right] \\
160
+ &= \left(\sum_{i=1}^n x_i\right) \log(p) + \left(n - \sum_{i=1}^n x_i\right) \log(1-p) \\
161
+ &= Y\log(p) + (n-Y)\log(1-p)
162
+ \end{align*}$$
163
+
164
+ where $Y = \sum_{i=1}^n x_i$ is the total number of successes.
165
+
166
+ Step 3: Find the value of $p$ that maximizes $LL(p)$ by setting the derivative to zero
167
+ $$\begin{align*}
168
+ \frac{d\,LL(p)}{dp} &= \frac{Y}{p} - \frac{n-Y}{1-p} = 0 \\
169
+ \frac{Y}{p} &= \frac{n-Y}{1-p} \\
170
+ Y(1-p) &= p(n-Y) \\
171
+ Y - Yp &= pn - pY \\
172
+ Y &= pn \\
173
+ \hat{p} &= \frac{Y}{n} = \frac{\sum_{i=1}^n x_i}{n}
174
+ \end{align*}$$
175
+
176
+ /// tip
177
+ The MLE for the parameter $p$ in a Bernoulli distribution is simply the **sample mean** - the proportion of successes in our data!
178
+ ///
179
+ """
180
+ )
181
+ return
182
+
183
+
184
+ @app.cell(hide_code=True)
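Before the interactive demo below, here is a quick standalone check (a sketch, assuming `scipy` is available as it is elsewhere in this notebook) that the closed-form estimate $\hat{p} = Y/n$ agrees with a direct numerical maximization of the log-likelihood:

```python
import numpy as np
from scipy import optimize

rng = np.random.default_rng(1)
xs = rng.binomial(1, 0.3, size=500)

# Closed-form MLE: the sample mean
p_closed_form = xs.mean()

# Numerical maximization of LL(p) = Y log(p) + (n - Y) log(1 - p)
def neg_ll(p):
    y = xs.sum()
    return -(y * np.log(p) + (len(xs) - y) * np.log(1 - p))

p_numeric = optimize.minimize_scalar(neg_ll, bounds=(1e-6, 1 - 1e-6), method="bounded").x

print(p_closed_form, p_numeric)  # the two agree to several decimal places
```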
185
+ def _(controls):
186
+ controls.center()
187
+ return
188
+
189
+
190
+ @app.cell(hide_code=True)
191
+ def _(generate_button, mo, np, plt, sample_size_slider, true_p_slider):
192
+ # generate bernoulli samples when button is clicked
193
+ bernoulli_button_value = generate_button.value
194
+
195
+ # get parameter values
196
+ bernoulli_true_p = true_p_slider.value
197
+ bernoulli_n = sample_size_slider.value
198
+
199
+ # generate data
200
+ bernoulli_data = np.random.binomial(1, bernoulli_true_p, size=bernoulli_n)
201
+ bernoulli_Y = np.sum(bernoulli_data)
202
+ bernoulli_p_hat = bernoulli_Y / bernoulli_n
203
+
204
+ # create visualization
205
+ bernoulli_fig, (bernoulli_ax1, bernoulli_ax2) = plt.subplots(1, 2, figsize=(12, 5))
206
+
207
+ # plot data histogram
208
+ bernoulli_ax1.hist(bernoulli_data, bins=[-0.5, 0.5, 1.5], rwidth=0.8, color='lightblue')
209
+ bernoulli_ax1.set_xticks([0, 1])
210
+ bernoulli_ax1.set_xticklabels(['Failure (0)', 'Success (1)'])
211
+ bernoulli_ax1.set_title(f'Bernoulli Data: {bernoulli_n} samples')
212
+ bernoulli_ax1.set_ylabel('Count')
213
+ bernoulli_y_counts = [bernoulli_n - bernoulli_Y, bernoulli_Y]
214
+ for bernoulli_idx, bernoulli_count in enumerate(bernoulli_y_counts):
215
+ bernoulli_ax1.text(bernoulli_idx, bernoulli_count/2, f"{bernoulli_count}",
216
+ ha='center', va='center',
217
+ color='white' if bernoulli_idx == 0 else 'black',
218
+ fontweight='bold')
219
+
220
+ # calculate log-likelihood function
221
+ bernoulli_p_values = np.linspace(0.01, 0.99, 100)
222
+ bernoulli_ll_values = np.zeros_like(bernoulli_p_values)
223
+
224
+ for bernoulli_i, bernoulli_p in enumerate(bernoulli_p_values):
225
+ bernoulli_ll_values[bernoulli_i] = bernoulli_Y * np.log(bernoulli_p) + (bernoulli_n - bernoulli_Y) * np.log(1 - bernoulli_p)
226
+
227
+ # plot log-likelihood
228
+ bernoulli_ax2.plot(bernoulli_p_values, bernoulli_ll_values, 'b-', linewidth=2)
229
+ bernoulli_ax2.axvline(x=bernoulli_p_hat, color='r', linestyle='--', label=f'MLE: $\\hat{{p}} = {bernoulli_p_hat:.3f}$')
230
+ bernoulli_ax2.axvline(x=bernoulli_true_p, color='g', linestyle='--', label=f'True: $p = {bernoulli_true_p:.3f}$')
231
+ bernoulli_ax2.set_xlabel('$p$ (probability of success)')
232
+ bernoulli_ax2.set_ylabel('Log-Likelihood')
233
+ bernoulli_ax2.set_title('Log-Likelihood Function')
234
+ bernoulli_ax2.legend()
235
+
236
+ plt.tight_layout()
237
+ plt.gca()
238
+
239
+ # Create markdown to explain the results
240
+ bernoulli_explanation = mo.md(
241
+ f"""
242
+ ### Bernoulli MLE Results
243
+
244
+ **True parameter**: $p = {bernoulli_true_p:.3f}$
245
+ **Sample statistics**: {bernoulli_Y} successes out of {bernoulli_n} trials
246
+ **MLE estimate**: $\\hat{{p}} = \\frac{{{bernoulli_Y}}}{{{bernoulli_n}}} = {bernoulli_p_hat:.3f}$
247
+
248
+ The plot on the right shows the log-likelihood function $LL(p) = Y\\log(p) + (n-Y)\\log(1-p)$.
249
+ The red dashed line marks the maximum likelihood estimate $\\hat{{p}}$, and the green dashed line
250
+ shows the true parameter value.
251
+
252
+ /// note
253
+ Try increasing the sample size to see how the MLE estimate gets closer to the true parameter value!
254
+ ///
255
+ """
256
+ )
257
+
258
+ # Display plot and explanation together
259
+ mo.vstack([
260
+ bernoulli_fig,
261
+ bernoulli_explanation
262
+ ])
263
+ return (
264
+ bernoulli_Y,
265
+ bernoulli_ax1,
266
+ bernoulli_ax2,
267
+ bernoulli_button_value,
268
+ bernoulli_count,
269
+ bernoulli_data,
270
+ bernoulli_explanation,
271
+ bernoulli_fig,
272
+ bernoulli_i,
273
+ bernoulli_idx,
274
+ bernoulli_ll_values,
275
+ bernoulli_n,
276
+ bernoulli_p,
277
+ bernoulli_p_hat,
278
+ bernoulli_p_values,
279
+ bernoulli_true_p,
280
+ bernoulli_y_counts,
281
+ )
282
+
283
+
284
+ @app.cell(hide_code=True)
285
+ def _(mo):
286
+ mo.md(
287
+ r"""
288
+ ## MLE for Normal Distribution
289
+
290
+ Next, let's look at a more complex example: estimating the parameters $\mu$ and $\sigma^2$ of a Normal distribution.
291
+
292
+ ### The Model
293
+
294
+ A Normal (Gaussian) distribution has two parameters:
295
+ - $\mu$: the mean
296
+ - $\sigma^2$: the variance
297
+
298
+ Its probability density function (PDF) is:
299
+
300
+ $$f(x|\mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right)$$
301
+
302
+ ### Deriving the MLE
303
+
304
+ Given $n$ independent samples $X_1, X_2, \ldots, X_n$ from a Normal distribution, we want to find the values of $\mu$ and $\sigma^2$ that maximize the likelihood of our observed data.
305
+
306
+ Step 1: Write the likelihood function
307
+ $$L(\mu, \sigma^2) = \prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right)$$
308
+
309
+ Step 2: Take the logarithm to get the log-likelihood
310
+ $$\begin{align*}
311
+ LL(\mu, \sigma^2) &= \log\prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right) \\
312
+ &= \sum_{i=1}^n \log\left[\frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right)\right] \\
313
+ &= \sum_{i=1}^n \left[-\frac{1}{2}\log(2\pi\sigma^2) - \frac{(x_i - \mu)^2}{2\sigma^2}\right] \\
314
+ &= -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\sum_{i=1}^n (x_i - \mu)^2
315
+ \end{align*}$$
316
+
317
+ Step 3: Find the values of $\mu$ and $\sigma^2$ that maximize $LL(\mu, \sigma^2)$ by setting the partial derivatives to zero.
318
+
319
+ For $\mu$:
320
+ $$\begin{align*}
321
+ \frac{\partial LL(\mu, \sigma^2)}{\partial \mu} &= \frac{1}{\sigma^2}\sum_{i=1}^n (x_i - \mu) = 0 \\
322
+ \sum_{i=1}^n (x_i - \mu) &= 0 \\
323
+ \sum_{i=1}^n x_i &= n\mu \\
324
+ \hat{\mu} &= \frac{1}{n}\sum_{i=1}^n x_i
325
+ \end{align*}$$
326
+
327
+ For $\sigma^2$:
328
+ $$\begin{align*}
329
+ \frac{\partial LL(\mu, \sigma^2)}{\partial \sigma^2} &= -\frac{n}{2\sigma^2} + \frac{1}{2(\sigma^2)^2}\sum_{i=1}^n (x_i - \mu)^2 = 0 \\
330
+ \frac{n}{2\sigma^2} &= \frac{1}{2(\sigma^2)^2}\sum_{i=1}^n (x_i - \mu)^2 \\
331
+ n\sigma^2 &= \sum_{i=1}^n (x_i - \mu)^2 \\
332
+ \hat{\sigma}^2 &= \frac{1}{n}\sum_{i=1}^n (x_i - \hat{\mu})^2
333
+ \end{align*}$$
334
+
335
+ /// tip
336
+ The MLE for a Normal distribution gives us:
337
+
338
+ - $\hat{\mu}$ = sample mean
339
+ - $\hat{\sigma}^2$ = sample variance (using $n$ in the denominator, not $n-1$)
340
+ ///
341
+ """
342
+ )
343
+ return
344
+
345
+
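One practical detail worth checking in code: NumPy's `np.var` already uses the MLE convention (dividing by $n$) unless `ddof=1` is passed. A short sketch:

```python
import numpy as np

rng = np.random.default_rng(2)
xs = rng.normal(loc=5.0, scale=2.0, size=1_000)

mu_hat = xs.mean()                        # MLE of mu: the sample mean
sigma2_hat = np.mean((xs - mu_hat) ** 2)  # MLE of sigma^2: divides by n

print(mu_hat, sigma2_hat)
print(np.isclose(sigma2_hat, xs.var()))        # True: np.var defaults to ddof=0
print(np.isclose(sigma2_hat, xs.var(ddof=1)))  # False: the unbiased estimator divides by n - 1
```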
346
+ @app.cell(hide_code=True)
347
+ def _(normal_controls):
348
+ normal_controls.center()
349
+ return
350
+
351
+
352
+ @app.cell(hide_code=True)
353
+ def _(
354
+ mo,
355
+ normal_generate_button,
356
+ normal_sample_size_slider,
357
+ np,
358
+ plt,
359
+ true_mu_slider,
360
+ true_sigma_slider,
361
+ ):
362
+ # generate normal samples when button is clicked
363
+ normal_button_value = normal_generate_button.value
364
+
365
+ # get parameter values
366
+ normal_true_mu = true_mu_slider.value
367
+ normal_true_sigma = true_sigma_slider.value
368
+ normal_true_var = normal_true_sigma**2
369
+ normal_n = normal_sample_size_slider.value
370
+
371
+ # generate random data
372
+ normal_data = np.random.normal(normal_true_mu, normal_true_sigma, size=normal_n)
373
+
374
+ # calculate mle estimates
375
+ normal_mu_hat = np.mean(normal_data)
376
+ normal_sigma2_hat = np.mean((normal_data - normal_mu_hat)**2) # mle variance using n
377
+ normal_sigma_hat = np.sqrt(normal_sigma2_hat)
378
+
379
+ # create visualization
380
+ normal_fig, (normal_ax1, normal_ax2) = plt.subplots(1, 2, figsize=(12, 5))
381
+
382
+ # plot histogram and density curves
383
+ normal_bins = np.linspace(min(normal_data) - 1, max(normal_data) + 1, 30)
384
+ normal_ax1.hist(normal_data, bins=normal_bins, density=True, alpha=0.6, color='lightblue', label='Data Histogram')
385
+
386
+ # plot range for density curves
387
+ normal_x = np.linspace(min(normal_data) - 2*normal_true_sigma, max(normal_data) + 2*normal_true_sigma, 1000)
388
+
389
+ # plot true and mle densities
390
+ normal_true_pdf = (1/(normal_true_sigma * np.sqrt(2*np.pi))) * np.exp(-0.5 * ((normal_x - normal_true_mu)/normal_true_sigma)**2)
391
+ normal_ax1.plot(normal_x, normal_true_pdf, 'g-', linewidth=2, label=f'True: N({normal_true_mu:.2f}, {normal_true_var:.2f})')
392
+
393
+ normal_mle_pdf = (1/(normal_sigma_hat * np.sqrt(2*np.pi))) * np.exp(-0.5 * ((normal_x - normal_mu_hat)/normal_sigma_hat)**2)
394
+ normal_ax1.plot(normal_x, normal_mle_pdf, 'r--', linewidth=2, label=f'MLE: N({normal_mu_hat:.2f}, {normal_sigma2_hat:.2f})')
395
+
396
+ normal_ax1.set_xlabel('x')
397
+ normal_ax1.set_ylabel('Density')
398
+ normal_ax1.set_title(f'Normal Distribution: {normal_n} samples')
399
+ normal_ax1.legend()
400
+
401
+ # create contour plot of log-likelihood
402
+ normal_mu_range = np.linspace(normal_mu_hat - 2, normal_mu_hat + 2, 100)
403
+ normal_sigma_range = np.linspace(max(0.1, normal_sigma_hat - 1), normal_sigma_hat + 1, 100)
404
+
405
+ normal_mu_grid, normal_sigma_grid = np.meshgrid(normal_mu_range, normal_sigma_range)
406
+ normal_ll_grid = np.zeros_like(normal_mu_grid)
407
+
408
+ # calculate log-likelihood for each grid point
409
+ for normal_i in range(normal_mu_grid.shape[0]):
410
+ for normal_j in range(normal_mu_grid.shape[1]):
411
+ normal_mu = normal_mu_grid[normal_i, normal_j]
412
+ normal_sigma = normal_sigma_grid[normal_i, normal_j]
413
+ normal_ll = -normal_n/2 * np.log(2*np.pi*normal_sigma**2) - np.sum((normal_data - normal_mu)**2)/(2*normal_sigma**2)
414
+ normal_ll_grid[normal_i, normal_j] = normal_ll
415
+
416
+ # plot log-likelihood contour
417
+ normal_contour = normal_ax2.contourf(normal_mu_grid, normal_sigma_grid, normal_ll_grid, levels=50, cmap='viridis')
418
+ normal_ax2.set_xlabel('μ (mean)')
419
+ normal_ax2.set_ylabel('σ (standard deviation)')
420
+ normal_ax2.set_title('Log-Likelihood Contour')
421
+
422
+ # mark mle and true params
423
+ normal_ax2.plot(normal_mu_hat, normal_sigma_hat, 'rx', markersize=10, label='MLE Estimate')
424
+ normal_ax2.plot(normal_true_mu, normal_true_sigma, 'g*', markersize=10, label='True Parameters')
425
+ normal_ax2.legend()
426
+
427
+ plt.colorbar(normal_contour, ax=normal_ax2, label='Log-Likelihood')
428
+ plt.tight_layout()
429
+ plt.gca()
430
+
431
+ # relevant markdown for the results
432
+ normal_explanation = mo.md(
433
+ f"""
434
+ ### Normal MLE Results
435
+
436
+ **True parameters**: $\\mu = {normal_true_mu:.3f}$, $\\sigma^2 = {normal_true_var:.3f}$
437
+ **MLE estimates**: $\\hat{{\\mu}} = {normal_mu_hat:.3f}$, $\\hat{{\\sigma}}^2 = {normal_sigma2_hat:.3f}$
438
+
439
+ The left plot shows the data histogram with the true Normal distribution (green) and the MLE-estimated distribution (red dashed).
440
+
441
+ The right plot shows the log-likelihood function as a contour map in the $(\\mu, \\sigma)$ parameter space. The maximum likelihood estimates are marked with a red X, while the true parameters are marked with a green star.
442
+
443
+ /// note
444
+ Notice how the log-likelihood contour is more stretched along the σ axis than the μ axis. This indicates that we typically estimate the mean with greater precision than the standard deviation.
445
+ ///
446
+
447
+ /// tip
448
+ Increase the sample size to see how the MLE estimates converge to the true parameter values!
449
+ ///
450
+ """
451
+ )
452
+
453
+ # plot and explanation together
454
+ mo.vstack([
455
+ normal_fig,
456
+ normal_explanation
457
+ ])
458
+ return (
459
+ normal_ax1,
460
+ normal_ax2,
461
+ normal_bins,
462
+ normal_button_value,
463
+ normal_contour,
464
+ normal_data,
465
+ normal_explanation,
466
+ normal_fig,
467
+ normal_i,
468
+ normal_j,
469
+ normal_ll,
470
+ normal_ll_grid,
471
+ normal_mle_pdf,
472
+ normal_mu,
473
+ normal_mu_grid,
474
+ normal_mu_hat,
475
+ normal_mu_range,
476
+ normal_n,
477
+ normal_sigma,
478
+ normal_sigma2_hat,
479
+ normal_sigma_grid,
480
+ normal_sigma_hat,
481
+ normal_sigma_range,
482
+ normal_true_mu,
483
+ normal_true_pdf,
484
+ normal_true_sigma,
485
+ normal_true_var,
486
+ normal_x,
487
+ )
488
+
489
+
490
+ @app.cell(hide_code=True)
491
+ def _(mo):
492
+ mo.md(
493
+ r"""
494
+ ## MLE for Linear Regression
495
+
496
+ Now let's look at a more practical example: using MLE to derive linear regression.
497
+
498
+ ### The Model
499
+
500
+ Consider a model where:
501
+ - We have pairs of observations $(X_1, Y_1), (X_2, Y_2), \ldots, (X_n, Y_n)$
502
+ - The relationship between $X$ and $Y$ follows: $Y = \theta X + Z$
503
+ - $Z \sim N(0, \sigma^2)$ is random noise
504
+ - Our goal is to estimate the parameter $\theta$
505
+
506
+ This means that for a given $X_i$, the conditional distribution of $Y_i$ is:
507
+
508
+ $$Y_i | X_i \sim N(\theta X_i, \sigma^2)$$
509
+
510
+ ### Deriving the MLE
511
+
512
+ Step 1: Write the likelihood function for each data point $(X_i, Y_i)$
513
+ $$f(Y_i | X_i, \theta) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(Y_i - \theta X_i)^2}{2\sigma^2}\right)$$
514
+
515
+ Step 2: Write the likelihood for all data
516
+ $$\begin{align*}
517
+ L(\theta) &= \prod_{i=1}^n f(Y_i, X_i | \theta) \\
518
+ &= \prod_{i=1}^n f(Y_i | X_i, \theta) \cdot f(X_i)
519
+ \end{align*}$$
520
+
521
+ Since $f(X_i)$ doesn't depend on $\theta$, we can simplify:
522
+ $$L(\theta) = \prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(Y_i - \theta X_i)^2}{2\sigma^2}\right) \cdot f(X_i)$$
523
+
524
+ Step 3: Take the logarithm to get the log-likelihood
525
+ $$\begin{align*}
526
+ LL(\theta) &= \log \prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(Y_i - \theta X_i)^2}{2\sigma^2}\right) \cdot f(X_i) \\
527
+ &= \sum_{i=1}^n \log\left[\frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(Y_i - \theta X_i)^2}{2\sigma^2}\right)\right] + \sum_{i=1}^n \log f(X_i) \\
528
+ &= -\frac{n}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^n (Y_i - \theta X_i)^2 + \sum_{i=1}^n \log f(X_i)
529
+ \end{align*}$$
530
+
531
+ Step 4: Since we only care about maximizing with respect to $\theta$, we can drop terms that don't contain $\theta$:
532
+ $$\hat{\theta} = \underset{\theta}{\operatorname{argmax}} \left[ -\frac{1}{2\sigma^2} \sum_{i=1}^n (Y_i - \theta X_i)^2 \right]$$
533
+
534
+ This is equivalent to:
535
+ $$\hat{\theta} = \underset{\theta}{\operatorname{argmin}} \sum_{i=1}^n (Y_i - \theta X_i)^2$$
536
+
537
+ Step 5: Find the value of $\theta$ that minimizes the sum of squared errors by setting the derivative to zero:
538
+ $$\begin{align*}
539
+ \frac{d}{d\theta} \sum_{i=1}^n (Y_i - \theta X_i)^2 &= 0 \\
540
+ \sum_{i=1}^n -2X_i(Y_i - \theta X_i) &= 0 \\
541
+ \sum_{i=1}^n \left(X_i Y_i - \theta X_i^2\right) &= 0 \\
542
+ \sum_{i=1}^n X_i Y_i &= \theta \sum_{i=1}^n X_i^2 \\
543
+ \hat{\theta} &= \frac{\sum_{i=1}^n X_i Y_i}{\sum_{i=1}^n X_i^2}
544
+ \end{align*}$$
545
+
546
+ /// tip
547
+ **Key Insight**: MLE for this simple linear model gives us the least squares estimator! This is an important connection between MLE and regression.
548
+ ///
549
+ """
550
+ )
551
+ return
552
+
553
+
554
+ @app.cell(hide_code=True)
555
+ def _(linear_controls):
556
+ linear_controls.center()
557
+ return
558
+
559
+
560
+ @app.cell(hide_code=True)
561
+ def _(
562
+ linear_generate_button,
563
+ linear_sample_size_slider,
564
+ mo,
565
+ noise_sigma_slider,
566
+ np,
567
+ plt,
568
+ true_theta_slider,
569
+ ):
570
+ # linear model data calc when button is clicked
571
+ linear_button_value = linear_generate_button.value
572
+
573
+ # get parameter values
574
+ linear_true_theta = true_theta_slider.value
575
+ linear_noise_sigma = noise_sigma_slider.value
576
+ linear_n = linear_sample_size_slider.value
577
+
578
+ # generate x data (uniformly between -3 and 3)
579
+ linear_X = np.random.uniform(-3, 3, size=linear_n)
580
+
581
+ # generate y data according to the model y = θx + z
582
+ linear_Z = np.random.normal(0, linear_noise_sigma, size=linear_n)
583
+ linear_Y = linear_true_theta * linear_X + linear_Z
584
+
585
+ # calculate mle estimate
586
+ linear_theta_hat = np.sum(linear_X * linear_Y) / np.sum(linear_X**2)
587
+
588
+ # calculate sse for different theta values
589
+ linear_theta_range = np.linspace(linear_true_theta - 1.5, linear_true_theta + 1.5, 100)
590
+ linear_sse_values = np.zeros_like(linear_theta_range)
591
+
592
+ for linear_i, linear_theta in enumerate(linear_theta_range):
593
+ linear_y_pred = linear_theta * linear_X
594
+ linear_sse_values[linear_i] = np.sum((linear_Y - linear_y_pred)**2)
595
+
596
+ # convert sse to log-likelihood (ignoring constant terms)
597
+ linear_ll_values = -linear_sse_values / (2 * linear_noise_sigma**2)
598
+
599
+ # create visualization
600
+ linear_fig, (linear_ax1, linear_ax2) = plt.subplots(1, 2, figsize=(12, 5))
601
+
602
+ # plot scatter plot with regression lines
603
+ linear_ax1.scatter(linear_X, linear_Y, color='blue', alpha=0.6, label='Data points')
604
+
605
+ # plot range for regression lines
606
+ linear_x_line = np.linspace(-3, 3, 100)
607
+
608
+ # plot true and mle regression lines
609
+ linear_ax1.plot(linear_x_line, linear_true_theta * linear_x_line, 'g-', linewidth=2, label=f'True: Y = {linear_true_theta:.2f}X')
610
+ linear_ax1.plot(linear_x_line, linear_theta_hat * linear_x_line, 'r--', linewidth=2, label=f'MLE: Y = {linear_theta_hat:.2f}X')
611
+
612
+ linear_ax1.set_xlabel('X')
613
+ linear_ax1.set_ylabel('Y')
614
+ linear_ax1.set_title(f'Linear Regression: {linear_n} data points')
615
+ linear_ax1.grid(True, alpha=0.3)
616
+ linear_ax1.legend()
617
+
618
+ # plot log-likelihood function
619
+ linear_ax2.plot(linear_theta_range, linear_ll_values, 'b-', linewidth=2)
620
+ linear_ax2.axvline(x=linear_theta_hat, color='r', linestyle='--', label=f'MLE: θ = {linear_theta_hat:.3f}')
621
+ linear_ax2.axvline(x=linear_true_theta, color='g', linestyle='--', label=f'True: θ = {linear_true_theta:.3f}')
622
+ linear_ax2.set_xlabel('θ (slope parameter)')
623
+ linear_ax2.set_ylabel('Log-Likelihood')
624
+ linear_ax2.set_title('Log-Likelihood Function')
625
+ linear_ax2.grid(True, alpha=0.3)
626
+ linear_ax2.legend()
627
+
628
+ plt.tight_layout()
629
+ plt.gca()
630
+
631
+ # relevant markdown to explain results
632
+ linear_explanation = mo.md(
633
+ f"""
634
+ ### Linear Regression MLE Results
635
+
636
+ **True parameter**: $\\theta = {linear_true_theta:.3f}$
637
+ **MLE estimate**: $\\hat{{\\theta}} = {linear_theta_hat:.3f}$
638
+
639
+ The left plot shows the scatter plot of data points with the true regression line (green) and the MLE-estimated regression line (red dashed).
640
+
641
+ The right plot shows the log-likelihood function for different values of $\\theta$. The maximum likelihood estimate is marked with a red dashed line, and the true parameter is marked with a green dashed line.
642
+
643
+ /// note
644
+ The MLE estimate $\\hat{{\\theta}} = \\frac{{\\sum_{{i=1}}^n X_i Y_i}}{{\\sum_{{i=1}}^n X_i^2}}$ minimizes the sum of squared errors between the predicted and actual Y values.
645
+ ///
646
+
647
+ /// tip
648
+ Try increasing the noise level to see how it affects the precision of the estimate!
649
+ ///
650
+ """
651
+ )
652
+
653
+ # show plot and explanation
654
+ mo.vstack([
655
+ linear_fig,
656
+ linear_explanation
657
+ ])
658
+ return (
659
+ linear_X,
660
+ linear_Y,
661
+ linear_Z,
662
+ linear_ax1,
663
+ linear_ax2,
664
+ linear_button_value,
665
+ linear_explanation,
666
+ linear_fig,
667
+ linear_i,
668
+ linear_ll_values,
669
+ linear_n,
670
+ linear_noise_sigma,
671
+ linear_sse_values,
672
+ linear_theta,
673
+ linear_theta_hat,
674
+ linear_theta_range,
675
+ linear_true_theta,
676
+ linear_x_line,
677
+ linear_y_pred,
678
+ )
679
+
680
+
681
+ @app.cell(hide_code=True)
682
+ def _(mo):
683
+ mo.md(
684
+ r"""
685
+ ## Interactive Concept: Density/Mass Functions vs. Likelihood
686
+
687
+ To better understand the distinction between likelihood and density/mass functions, let's create an interactive visualization. This concept is crucial for understanding why MLE works.
688
+ """
689
+ )
690
+ return
691
+
692
+
693
+ @app.cell(hide_code=True)
694
+ def _(concept_controls):
695
+ concept_controls.center()
696
+ return
697
+
698
+
699
+ @app.cell(hide_code=True)
700
+ def _(concept_dist_type, mo, np, perspective_selector, plt, stats):
701
+ # current distribution type
702
+ concept_dist_type_value = concept_dist_type.value
703
+
704
+ # view mode from dropdown
705
+ concept_view_mode = "likelihood" if perspective_selector.value == "Likelihood Perspective" else "probability"
706
+
707
+ # visualization based on distribution type
708
+ concept_fig, concept_ax = plt.subplots(figsize=(10, 6))
709
+
710
+ if concept_dist_type_value == "Normal":
711
+ if concept_view_mode == "probability":
712
+ # density function perspective: fixed params, varying data
713
+ concept_mu = 0 # fixed parameter
714
+ concept_sigma = 1 # fixed parameter
715
+
716
+ # generate x values for the pdf
717
+ concept_x = np.linspace(-4, 4, 1000)
718
+
719
+ # plot pdf
720
+ concept_pdf = stats.norm.pdf(concept_x, concept_mu, concept_sigma)
721
+ concept_ax.plot(concept_x, concept_pdf, 'b-', linewidth=2, label='PDF: N(0, 1)')
722
+
723
+ # highlight specific data values
724
+ concept_data_points = [-2, -1, 0, 1, 2]
725
+ concept_colors = ['#FF9999', '#FFCC99', '#99FF99', '#99CCFF', '#CC99FF']
726
+
727
+ for concept_i, concept_data in enumerate(concept_data_points):
728
+ concept_prob = stats.norm.pdf(concept_data, concept_mu, concept_sigma)
729
+ concept_ax.plot([concept_data, concept_data], [0, concept_prob], concept_colors[concept_i], linewidth=2)
730
+ concept_ax.scatter(concept_data, concept_prob, color=concept_colors[concept_i], s=50,
731
+ label=f'PDF at x={concept_data}: {concept_prob:.3f}')
732
+
733
+ concept_ax.set_xlabel('Data (x)')
734
+ concept_ax.set_ylabel('Probability Density')
735
+ concept_ax.set_title('Density Function Perspective: Fixed Parameters (μ=0, σ=1), Different Data Points')
736
+
737
+ else: # likelihood perspective
738
+ # likelihood perspective: fixed data, varying parameters
739
+ concept_data_point = 1.5 # fixed observed data
740
+
741
+ # different possible parameter values (means)
742
+ concept_mus = [-1, 0, 1, 2, 3]
743
+ concept_sigma = 1
744
+
745
+ # generate x values for multiple pdfs
746
+ concept_x = np.linspace(-4, 6, 1000)
747
+
748
+ concept_colors = ['#FF9999', '#FFCC99', '#99FF99', '#99CCFF', '#CC99FF']
749
+
750
+ for concept_i, concept_mu in enumerate(concept_mus):
751
+ concept_pdf = stats.norm.pdf(concept_x, concept_mu, concept_sigma)
752
+ concept_ax.plot(concept_x, concept_pdf, color=concept_colors[concept_i], linewidth=2, alpha=0.7,
753
+ label=f'N({concept_mu}, 1)')
754
+
755
+ # mark the likelihood of the data point for this param
756
+ concept_likelihood = stats.norm.pdf(concept_data_point, concept_mu, concept_sigma)
757
+ concept_ax.plot([concept_data_point, concept_data_point], [0, concept_likelihood], concept_colors[concept_i], linewidth=2)
758
+ concept_ax.scatter(concept_data_point, concept_likelihood, color=concept_colors[concept_i], s=50,
759
+ label=f'L(μ={concept_mu}|X=1.5) = {concept_likelihood:.3f}')
760
+
761
+ # add vertical line at the observed data point
762
+ concept_ax.axvline(x=concept_data_point, color='black', linestyle='--',
763
+ label=f'Observed data: X=1.5')
764
+
765
+ concept_ax.set_xlabel('Data (x)')
766
+ concept_ax.set_ylabel('Probability Density / Likelihood')
767
+ concept_ax.set_title('Likelihood Perspective: Fixed Data Point (X=1.5), Different Parameter Values')
768
+
769
+ elif concept_dist_type_value == "Bernoulli":
770
+ if concept_view_mode == "probability":
771
+ # probability perspective: fixed parameter, two possible data values
772
+ concept_p = 0.3 # fixed parameter
773
+
774
+ # bar chart for p(x=0) and p(x=1)
775
+ concept_ax.bar([0, 1], [1-concept_p, concept_p], width=0.4, color=['#99CCFF', '#FF9999'],
776
+ alpha=0.7, label=f'PMF: Bernoulli({concept_p})')
777
+
778
+ # text showing probabilities
779
+ concept_ax.text(0, (1-concept_p)/2, f'P(X=0|p={concept_p}) = {1-concept_p:.3f}', ha='center', va='center', fontweight='bold')
780
+ concept_ax.text(1, concept_p/2, f'P(X=1|p={concept_p}) = {concept_p:.3f}', ha='center', va='center', fontweight='bold')
781
+
782
+ concept_ax.set_xlabel('Data (x)')
783
+ concept_ax.set_ylabel('Probability')
784
+ concept_ax.set_xticks([0, 1])
785
+ concept_ax.set_xticklabels(['X=0', 'X=1'])
786
+ concept_ax.set_ylim(0, 1)
787
+ concept_ax.set_title('Probability Perspective: Fixed Parameter (p=0.3), Different Data Values')
788
+
789
+ else: # likelihood perspective
790
+ # likelihood perspective: fixed data, varying parameter
791
+ concept_data_point = 1 # fixed observed data (success)
792
+
793
+ # different possible parameter values
794
+ concept_p_values = np.linspace(0.01, 0.99, 100)
795
+
796
+ # calculate likelihood for each p value
797
+ if concept_data_point == 1:
798
+ # for x=1, likelihood is p
799
+ concept_likelihood = concept_p_values
800
+ concept_ax.plot(concept_p_values, concept_likelihood, 'b-', linewidth=2,
801
+ label=f'L(p|X=1) = p')
802
+
803
+ # highlight specific values
804
+ concept_highlight_ps = [0.2, 0.5, 0.8]
805
+ concept_colors = ['#FF9999', '#99FF99', '#99CCFF']
806
+
807
+ for concept_i, concept_p in enumerate(concept_highlight_ps):
808
+ concept_ax.plot([concept_p, concept_p], [0, concept_p], concept_colors[concept_i], linewidth=2)
809
+ concept_ax.scatter(concept_p, concept_p, color=concept_colors[concept_i], s=50,
810
+ label=f'L(p={concept_p}|X=1) = {concept_p:.3f}')
811
+
812
+ concept_ax.set_title('Likelihood Perspective: Fixed Data Point (X=1), Different Parameter Values')
813
+
814
+ else: # x=0
815
+ # for x = 0, likelihood is (1-p)
816
+ concept_likelihood = 1 - concept_p_values
817
+ concept_ax.plot(concept_p_values, concept_likelihood, 'r-', linewidth=2,
818
+ label=f'L(p|X=0) = (1-p)')
819
+
820
+ # highlight some specific values
821
+ concept_highlight_ps = [0.2, 0.5, 0.8]
822
+ concept_colors = ['#FF9999', '#99FF99', '#99CCFF']
823
+
824
+ for concept_i, concept_p in enumerate(concept_highlight_ps):
825
+ concept_ax.plot([concept_p, concept_p], [0, 1-concept_p], concept_colors[concept_i], linewidth=2)
826
+ concept_ax.scatter(concept_p, 1-concept_p, color=concept_colors[concept_i], s=50,
827
+ label=f'L(p={concept_p}|X=0) = {1-concept_p:.3f}')
828
+
829
+ concept_ax.set_title('Likelihood Perspective: Fixed Data Point (X=0), Different Parameter Values')
830
+
831
+ concept_ax.set_xlabel('Parameter (p)')
832
+ concept_ax.set_ylabel('Likelihood')
833
+ concept_ax.set_xlim(0, 1)
834
+ concept_ax.set_ylim(0, 1)
835
+
836
+ elif concept_dist_type_value == "Poisson":
837
+ if concept_view_mode == "probability":
838
+ # probability perspective: fixed parameter, different data values
839
+ concept_lam = 2.5 # fixed parameter
840
+
841
+ # pmf for different x values plot
842
+ concept_x_values = np.arange(0, 10)
843
+ concept_pmf_values = stats.poisson.pmf(concept_x_values, concept_lam)
844
+
845
+ concept_ax.bar(concept_x_values, concept_pmf_values, width=0.4, color='#99CCFF',
846
+ alpha=0.7, label=f'PMF: Poisson({concept_lam})')
847
+
848
+ # highlight a few specific values
849
+ concept_highlight_xs = [1, 2, 3, 4]
850
+ concept_colors = ['#FF9999', '#99FF99', '#FFCC99', '#CC99FF']
851
+
852
+ for concept_i, concept_x in enumerate(concept_highlight_xs):
853
+ concept_prob = stats.poisson.pmf(concept_x, concept_lam)
854
+ concept_ax.scatter(concept_x, concept_prob, color=concept_colors[concept_i], s=50,
855
+ label=f'P(X={concept_x}|λ={concept_lam}) = {concept_prob:.3f}')
856
+
857
+ concept_ax.set_xlabel('Data (x)')
858
+ concept_ax.set_ylabel('Probability')
859
+ concept_ax.set_xticks(concept_x_values)
860
+ concept_ax.set_title('Probability Perspective: Fixed Parameter (λ=2.5), Different Data Values')
861
+
862
+ else: # likelihood perspective
863
+ # likelihood perspective: fixed data, varying parameter
864
+ concept_data_point = 4 # fixed observed data
865
+
866
+ # different possible param values
867
+ concept_lambda_values = np.linspace(0.1, 8, 100)
868
+
869
+ # calc likelihood for each lambda value
870
+ concept_likelihood = stats.poisson.pmf(concept_data_point, concept_lambda_values)
871
+
872
+ concept_ax.plot(concept_lambda_values, concept_likelihood, 'b-', linewidth=2,
873
+ label=f'L(λ|X={concept_data_point})')
874
+
875
+ # highlight some specific values
876
+ concept_highlight_lambdas = [1, 2, 4, 6]
877
+ concept_colors = ['#FF9999', '#99FF99', '#99CCFF', '#FFCC99']
878
+
879
+ for concept_i, concept_lam in enumerate(concept_highlight_lambdas):
880
+ concept_like_val = stats.poisson.pmf(concept_data_point, concept_lam)
881
+ concept_ax.plot([concept_lam, concept_lam], [0, concept_like_val], concept_colors[concept_i], linewidth=2)
882
+ concept_ax.scatter(concept_lam, concept_like_val, color=concept_colors[concept_i], s=50,
883
+ label=f'L(λ={concept_lam}|X={concept_data_point}) = {concept_like_val:.3f}')
884
+
885
+ concept_ax.set_xlabel('Parameter (λ)')
886
+ concept_ax.set_ylabel('Likelihood')
887
+ concept_ax.set_title(f'Likelihood Perspective: Fixed Data Point (X={concept_data_point}), Different Parameter Values')
888
+
889
+ concept_ax.legend(loc='best', fontsize=9)
890
+ concept_ax.grid(True, alpha=0.3)
891
+ plt.tight_layout()
892
+ plt.gca()
893
+
894
+ # relevant explanation based on view mode
895
+ if concept_view_mode == "probability":
896
+ concept_explanation = mo.md(
897
+ f"""
898
+ ### Density/Mass Function Perspective
899
+
900
+ In the **density/mass function perspective**, the parameters of the distribution are **fixed and known**, and we evaluate the function at **different possible data values**.
901
+
902
+ For the {concept_dist_type_value} distribution, we've fixed the parameter{'s' if concept_dist_type_value == 'Normal' else ''} and shown the {'density' if concept_dist_type_value == 'Normal' else 'probability mass'} function evaluated at different data points.
903
+
904
+ This is the typical perspective when:
905
+
906
+ - We know the true parameters of a distribution
907
+ - We want to evaluate the {'density' if concept_dist_type_value == 'Normal' else 'probability mass'} at different observations
908
+ - We make predictions based on our model
909
+
910
+ **Mathematical notation**: $f(x | \\theta)$
911
+ """
912
+ )
913
+ else: # likelihood perspective
914
+ concept_explanation = mo.md(
915
+ f"""
916
+ ### Likelihood Perspective
917
+
918
+ In the **likelihood perspective**, the observed data is **fixed and known**, and we calculate how likely different parameter values are to have generated that data.
919
+
920
+ For the {concept_dist_type_value} distribution, we've fixed the observed data point{'s' if concept_dist_type_value == 'Normal' else ''} and shown the likelihood of different parameter values.
921
+
922
+ This is the perspective used in MLE:
923
+
924
+ - We have observed data
925
+ - We don't know the true parameters
926
+ - We want to find parameters that best explain our observations
927
+
928
+ **Mathematical notation**: $L(\\theta | X = x)$
929
+
930
+ /// tip
931
+ The value of $\\theta$ that maximizes this likelihood function is the MLE estimate $\\hat{{\\theta}}$!
932
+ ///
933
+ """
934
+ )
935
+
936
+ # Display plot and explanation together
937
+ mo.vstack([
938
+ concept_fig,
939
+ concept_explanation
940
+ ])
941
+ return (
942
+ concept_ax,
943
+ concept_colors,
944
+ concept_data,
945
+ concept_data_point,
946
+ concept_data_points,
947
+ concept_dist_type_value,
948
+ concept_explanation,
949
+ concept_fig,
950
+ concept_highlight_lambdas,
951
+ concept_highlight_ps,
952
+ concept_highlight_xs,
953
+ concept_i,
954
+ concept_lam,
955
+ concept_lambda_values,
956
+ concept_like_val,
957
+ concept_likelihood,
958
+ concept_mu,
959
+ concept_mus,
960
+ concept_p,
961
+ concept_p_values,
962
+ concept_pdf,
963
+ concept_pmf_values,
964
+ concept_prob,
965
+ concept_sigma,
966
+ concept_view_mode,
967
+ concept_x,
968
+ concept_x_values,
969
+ )
970
+
971
+
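The cell above toggles between two readings of the same `stats.poisson.pmf` call. As a quick standalone sketch of that contrast (assuming only scipy, independent of the notebook's UI state):

```python
from scipy import stats

# Probability perspective: parameter fixed (lambda = 2.5), vary the data value
lam_fixed = 2.5
for x in [0, 2, 4, 6]:
    print(f"P(X={x} | lambda={lam_fixed}) = {stats.poisson.pmf(x, lam_fixed):.3f}")

# Likelihood perspective: data fixed (x = 4), vary the parameter
x_obs = 4
for lam in [1.0, 2.0, 4.0, 6.0]:
    print(f"L(lambda={lam} | X={x_obs}) = {stats.poisson.pmf(x_obs, lam):.3f}")
```

Both loops call the same function; only what is held fixed changes, which is exactly the distinction the two dropdown views illustrate.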
972
+ @app.cell(hide_code=True)
973
+ def _(mo):
974
+ mo.md(
975
+ r"""
976
+ ## 🤔 Test Your Understanding
977
+
978
+ Which of the following statements about Maximum Likelihood Estimation are correct? Click each statement to check your answer.
979
+
980
+ /// details | Probability and likelihood have different interpretations: probability measures the chance of data given parameters, while likelihood measures how likely parameters are given data.
981
+ ✅ **Correct!**
982
+
983
+ Probability measures how likely it is to observe particular data when we know the parameters. Likelihood measures how likely particular parameter values are, given observed data.
984
+
985
+ Mathematically, probability is $P(X=x|\theta)$ while likelihood is $L(\theta|X=x)$.
986
+ ///
987
+
988
+ /// details | We use log-likelihood instead of likelihood because it's mathematically simpler and numerically more stable.
989
+ ✅ **Correct!**
990
+
991
+ We work with log-likelihood for several reasons:
992
+ 1. It converts products into sums, which are easier to work with mathematically
993
+ 2. It avoids numerical underflow when multiplying many small probabilities
994
+ 3. The logarithm is a monotonically increasing function, so the maximum of the likelihood occurs at the same parameter values as the maximum of the log-likelihood
995
+ ///
996
+
997
+ /// details | For a Bernoulli distribution, the MLE for parameter p is the sample mean of the observations.
998
+ ✅ **Correct!**
999
+
1000
+ For a Bernoulli distribution with parameter $p$, given $n$ independent samples $X_1, X_2, \ldots, X_n$, the MLE estimator is:
1001
+
1002
+ $$\hat{p} = \frac{\sum_{i=1}^n X_i}{n}$$
1003
+
1004
+ This is simply the sample mean, or the proportion of successes (1s) in the data.
1005
+ ///
1006
+
1007
+ /// details | For a Normal distribution, MLE gives unbiased estimates for both mean and variance parameters.
1008
+ ❌ **Incorrect.**
1009
+
1010
+ While the MLE for the mean ($\hat{\mu} = \frac{1}{n}\sum_{i=1}^n X_i$) is unbiased, the MLE for variance:
1011
+
1012
+ $$\hat{\sigma}^2 = \frac{1}{n}\sum_{i=1}^n (X_i - \hat{\mu})^2$$
1013
+
1014
+ is a biased estimator. It uses $n$ in the denominator rather than the $n-1$ used in the unbiased estimator.
1015
+ ///
1016
+
1017
+ /// details | MLE estimators are always unbiased regardless of the distribution.
1018
+ ❌ **Incorrect.**
1019
+
1020
+ MLE estimators are not always unbiased, though they are often asymptotically unbiased (meaning the bias approaches zero as the sample size increases).
1021
+
1022
+ A notable example is the MLE estimator for the variance of a Normal distribution:
1023
+ $$\hat{\sigma}^2 = \frac{1}{n}\sum_{i=1}^n (X_i - \hat{\mu})^2$$
1024
+
1025
+ This estimator is biased, which is why we often use the unbiased estimator:
1026
+ $$s^2 = \frac{1}{n-1}\sum_{i=1}^n (X_i - \hat{\mu})^2$$
1027
+
1028
+ Despite occasional bias, MLE estimators have many desirable properties, including consistency and asymptotic efficiency.
1029
+ ///
1030
+ """
1031
+ )
1032
+ return
1033
+
1034
+
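As a sanity check on the quiz statements above, here is a small numeric sketch (assuming only numpy and scipy; the exact values depend on the random seed):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(42)

# Bernoulli: the MLE of p is the sample mean (proportion of ones)
bern = rng.binomial(1, 0.3, size=1_000)
print("p_hat =", bern.mean())  # close to the true p = 0.3

# Normal: the MLE of the variance divides by n and is biased downward on average
sample = rng.normal(loc=0.0, scale=2.0, size=20)
print("MLE variance (ddof=0):", sample.var(ddof=0))
print("Unbiased variance (ddof=1):", sample.var(ddof=1))

# Why log-likelihood: a product of many densities underflows, the log-sum does not
big_sample = rng.normal(size=1_000)
dens = stats.norm.pdf(big_sample)
print("Product of densities:", np.prod(dens))       # underflows to 0.0
print("Sum of log-densities:", np.log(dens).sum())  # finite, same maximizer
```

The `ddof=0` / `ddof=1` pair mirrors the $1/n$ versus $1/(n-1)$ estimators discussed above.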
1035
+ @app.cell(hide_code=True)
1036
+ def _(mo):
1037
+ mo.md(
1038
+ r"""
1039
+ ## Summary
1040
+
1041
+ Maximum Likelihood Estimation really is one of those elegant ideas that sits at the core of modern statistics. When you get down to it, MLE is just about finding the most plausible explanation for the data we've observed. It's like being a detective - you have some clues (your data), and you're trying to piece together the most likely story (your parameters) that explains them.
1042
+
1043
+ We've seen how this works with different distributions. For the Bernoulli, it simply gives us the sample proportion. For the Normal, it gives us the sample mean and a slightly biased estimate of variance. And for linear regression, it provides a mathematical justification for the least squares method that everyone learns in basic stats classes.
1044
+
1045
+ What makes MLE so useful in practice is that it tends to give us estimates with good properties. As you collect more data, the estimates generally get closer to the true values (consistency) and do so efficiently. That's why MLE is everywhere in statistics and machine learning - from simple regression models to complex neural networks.
1046
+
1047
+ The most important takeaway? Next time you're fitting a model to data, remember that you're not just following a recipe - you're finding the parameters that make your observed data most likely to have occurred. That's the essence of statistical inference.
1048
+ """
1049
+ )
1050
+ return
1051
+
1052
+
1053
+ @app.cell(hide_code=True)
1054
+ def _(mo):
1055
+ mo.md(
1056
+ r"""
1057
+ ## Further Reading
1058
+
1059
+ If you're curious to dive deeper into this topic, check out "Statistical Inference" by Casella and Berger - it's the classic text that many statisticians learned from. For a more machine learning angle, Bishop's "Pattern Recognition and Machine Learning" shows how MLE connects to more advanced topics like EM algorithms and Bayesian methods.
1060
+
1061
+ Beyond the basics we've covered, you might explore Bayesian estimation (which incorporates prior knowledge), Fisher Information (which tells us how precisely we can estimate parameters), or the EM algorithm (for when we have missing data or latent variables). Each of these builds on the foundation of likelihood that we've established here.
1062
+ """
1063
+ )
1064
+ return
1065
+
1066
+
1067
+ @app.cell(hide_code=True)
1068
+ def _(mo):
1069
+ mo.md(r"""## Appendix (helper functions and imports)""")
1070
+ return
1071
+
1072
+
1073
+ @app.cell
1074
+ def _():
1075
+ import marimo as mo
1076
+ return (mo,)
1077
+
1078
+
1079
+ @app.cell
1080
+ def _():
1081
+ import numpy as np
1082
+ import matplotlib.pyplot as plt
1083
+ from scipy import stats
1084
+ import plotly.graph_objects as go
1085
+ import polars as pl
1086
+ from matplotlib import cm
1087
+
1088
+ # Set a consistent random seed for reproducibility
1089
+ np.random.seed(42)
1090
+
1091
+ # Set a nice style for matplotlib
1092
+ plt.style.use('seaborn-v0_8-darkgrid')
1093
+ return cm, go, np, pl, plt, stats
1094
+
1095
+
1096
+ @app.cell(hide_code=True)
1097
+ def _(mo):
1098
+ # Create interactive elements
1099
+ true_p_slider = mo.ui.slider(
1100
+ start=0.01,
1101
+ stop=0.99,
1102
+ value=0.3,
1103
+ step=0.01,
1104
+ label="True probability (p)"
1105
+ )
1106
+
1107
+ sample_size_slider = mo.ui.slider(
1108
+ start=10,
1109
+ stop=1000,
1110
+ value=100,
1111
+ step=10,
1112
+ label="Sample size (n)"
1113
+ )
1114
+
1115
+ generate_button = mo.ui.button(label="Generate New Sample", kind="success")
1116
+
1117
+ controls = mo.vstack([
1118
+ mo.vstack([true_p_slider, sample_size_slider]),
1119
+ generate_button
1120
+ ], justify="space-between")
1121
+ return controls, generate_button, sample_size_slider, true_p_slider
1122
+
1123
+
1124
+ @app.cell(hide_code=True)
1125
+ def _(mo):
1126
+ # Create interactive elements for Normal distribution
1127
+ true_mu_slider = mo.ui.slider(
1128
+ start=-5,
1129
+ stop=5,
1130
+ value=0,
1131
+ step=0.1,
1132
+ label="True mean (μ)"
1133
+ )
1134
+
1135
+ true_sigma_slider = mo.ui.slider(
1136
+ start=0.5,
1137
+ stop=3,
1138
+ value=1,
1139
+ step=0.1,
1140
+ label="True standard deviation (σ)"
1141
+ )
1142
+
1143
+ normal_sample_size_slider = mo.ui.slider(
1144
+ start=10,
1145
+ stop=500,
1146
+ value=50,
1147
+ step=10,
1148
+ label="Sample size (n)"
1149
+ )
1150
+
1151
+ normal_generate_button = mo.ui.button(label="Generate New Sample", kind="warn")
1152
+
1153
+ normal_controls = mo.hstack([
1154
+ mo.vstack([true_mu_slider, true_sigma_slider, normal_sample_size_slider]),
1155
+ normal_generate_button
1156
+ ], justify="space-between")
1157
+ return (
1158
+ normal_controls,
1159
+ normal_generate_button,
1160
+ normal_sample_size_slider,
1161
+ true_mu_slider,
1162
+ true_sigma_slider,
1163
+ )
1164
+
1165
+
1166
+ @app.cell(hide_code=True)
1167
+ def _(mo):
1168
+ # Create interactive elements for linear regression
1169
+ true_theta_slider = mo.ui.slider(
1170
+ start=-2,
1171
+ stop=2,
1172
+ value=0.5,
1173
+ step=0.1,
1174
+ label="True slope (θ)"
1175
+ )
1176
+
1177
+ noise_sigma_slider = mo.ui.slider(
1178
+ start=0.1,
1179
+ stop=2,
1180
+ value=0.5,
1181
+ step=0.1,
1182
+ label="Noise level (σ)"
1183
+ )
1184
+
1185
+ linear_sample_size_slider = mo.ui.slider(
1186
+ start=10,
1187
+ stop=200,
1188
+ value=50,
1189
+ step=10,
1190
+ label="Sample size (n)"
1191
+ )
1192
+
1193
+ linear_generate_button = mo.ui.button(label="Generate New Sample", kind="warn")
1194
+
1195
+ linear_controls = mo.hstack([
1196
+ mo.vstack([true_theta_slider, noise_sigma_slider, linear_sample_size_slider]),
1197
+ linear_generate_button
1198
+ ], justify="space-between")
1199
+ return (
1200
+ linear_controls,
1201
+ linear_generate_button,
1202
+ linear_sample_size_slider,
1203
+ noise_sigma_slider,
1204
+ true_theta_slider,
1205
+ )
1206
+
1207
+
1208
+ @app.cell(hide_code=True)
1209
+ def _(mo):
1210
+ # Interactive elements for likelihood vs probability demo
1211
+ concept_dist_type = mo.ui.dropdown(
1212
+ options=["Normal", "Bernoulli", "Poisson"],
1213
+ value="Normal",
1214
+ label="Distribution"
1215
+ )
1216
+
1217
+ # Replace buttons with a simple dropdown selector
1218
+ perspective_selector = mo.ui.dropdown(
1219
+ options=["Probability Perspective", "Likelihood Perspective"],
1220
+ value="Probability Perspective",
1221
+ label="View"
1222
+ )
1223
+
1224
+ concept_controls = mo.vstack([
1225
+ mo.hstack([concept_dist_type, perspective_selector])
1226
+ ])
1227
+ return concept_controls, concept_dist_type, perspective_selector
1228
+
1229
+
1230
+ if __name__ == "__main__":
1231
+ app.run()
python/006_dictionaries.py CHANGED
@@ -196,13 +196,13 @@ def _():
196
 
197
  @app.cell
198
  def _(mo, nested_data):
199
- mo.md(f"Alice's age: {nested_data["users"]["alice"]["age"]}")
200
  return
201
 
202
 
203
  @app.cell
204
  def _(mo, nested_data):
205
- mo.md(f"Bob's interests: {nested_data["users"]["bob"]["interests"]}")
206
  return
207
 
208
 
 
196
 
197
  @app.cell
198
  def _(mo, nested_data):
199
+ mo.md(f"Alice's age: {nested_data['users']['alice']['age']}")
200
  return
201
 
202
 
203
  @app.cell
204
  def _(mo, nested_data):
205
+ mo.md(f"Bob's interests: {nested_data['users']['bob']['interests']}")
206
  return
207
 
208
 
scripts/build.py ADDED
@@ -0,0 +1,281 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import subprocess
5
+ import argparse
6
+ import json
7
+ import datetime
8
+ import markdown
9
+ from datetime import date
10
+ from pathlib import Path
11
+ from typing import Dict, List, Any, Optional, Tuple
12
+
13
+ from jinja2 import Environment, FileSystemLoader
14
+
15
+
16
+ def export_html_wasm(notebook_path: str, output_dir: str, as_app: bool = False) -> bool:
17
+ """Export a single marimo notebook to HTML format.
18
+
19
+ Args:
20
+ notebook_path: Path to the notebook to export
21
+ output_dir: Directory to write the output HTML files
22
+ as_app: If True, export as app instead of notebook
23
+
24
+ Returns:
25
+ bool: True if export succeeded, False otherwise
26
+ """
27
+ # Create directory for the output
28
+ os.makedirs(output_dir, exist_ok=True)
29
+
30
+ # Determine the output path (preserving directory structure)
31
+ rel_path = os.path.basename(os.path.dirname(notebook_path))
32
+ if rel_path != os.path.dirname(notebook_path):
33
+ # Create subdirectory if needed
34
+ os.makedirs(os.path.join(output_dir, rel_path), exist_ok=True)
35
+
36
+ # Determine output filename (same as input but with .html extension)
37
+ output_filename = os.path.basename(notebook_path).replace(".py", ".html")
38
+ output_path = os.path.join(output_dir, rel_path, output_filename)
39
+
40
+ # Run marimo export command
41
+ mode = "--mode app" if as_app else "--mode edit"
42
+ cmd = f"marimo export html-wasm {mode} {notebook_path} -o {output_path} --sandbox"
43
+ print(f"Exporting {notebook_path} to {rel_path}/{output_filename} as {'app' if as_app else 'notebook'}")
44
+ print(f"Running command: {cmd}")
45
+
46
+ try:
47
+ result = subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
48
+ print(f"Successfully exported {notebook_path} to {output_path}")
49
+ return True
50
+ except subprocess.CalledProcessError as e:
51
+ print(f"Error exporting {notebook_path}: {e}")
52
+ print(f"Command output: {e.output}")
53
+ return False
54
+
55
+
56
+ def get_course_metadata(course_dir: Path) -> Dict[str, Any]:
57
+ """Extract metadata from a course directory.
58
+
59
+ Reads the README.md file to extract title and description.
60
+
61
+ Args:
62
+ course_dir: Path to the course directory
63
+
64
+ Returns:
65
+ Dict: Dictionary containing course metadata (title, description)
66
+ """
67
+ readme_path = course_dir / "README.md"
68
+ title = course_dir.name.replace("_", " ").title()
69
+ description = ""
70
+ description_html = ""
71
+
72
+ if readme_path.exists():
73
+ with open(readme_path, "r", encoding="utf-8") as f:
74
+ content = f.read()
75
+
76
+ # Try to extract title from first heading
77
+ title_match = content.split("\n")[0]
78
+ if title_match.startswith("# "):
79
+ title = title_match[2:].strip()
80
+
81
+ # Extract description from content after first heading
82
+ desc_content = "\n".join(content.split("\n")[1:]).strip()
83
+ if desc_content:
84
+ # Take first paragraph as description, preserve markdown formatting
85
+ description = desc_content.split("\n\n")[0].strip()
86
+ # Convert markdown to HTML
87
+ description_html = markdown.markdown(description)
88
+
89
+ return {
90
+ "title": title,
91
+ "description": description,
92
+ "description_html": description_html
93
+ }
94
+
95
+
96
+ def organize_notebooks_by_course(all_notebooks: List[str]) -> Dict[str, Dict[str, Any]]:
97
+ """Organize notebooks by course.
98
+
99
+ Args:
100
+ all_notebooks: List of paths to notebooks
101
+
102
+ Returns:
103
+ Dict: A dictionary where keys are course directories and values are
104
+ metadata about the course and its notebooks
105
+ """
106
+ courses = {}
107
+
108
+ for notebook_path in sorted(all_notebooks):
109
+ # Parse the path to determine course
110
+ # The first directory in the path is the course
111
+ path_parts = Path(notebook_path).parts
112
+
113
+ if len(path_parts) < 2:
114
+ print(f"Skipping notebook with invalid path: {notebook_path}")
115
+ continue
116
+
117
+ course_id = path_parts[0]
118
+
119
+ # If this is a new course, initialize it
120
+ if course_id not in courses:
121
+ course_metadata = get_course_metadata(Path(course_id))
122
+
123
+ courses[course_id] = {
124
+ "id": course_id,
125
+ "title": course_metadata["title"],
126
+ "description": course_metadata["description"],
127
+ "description_html": course_metadata["description_html"],
128
+ "notebooks": []
129
+ }
130
+
131
+ # Extract the notebook number and name from the filename
132
+ filename = Path(notebook_path).name
133
+ basename = filename.replace(".py", "")
134
+
135
+ # Extract notebook metadata
136
+ notebook_title = basename.replace("_", " ").title()
137
+
138
+ # Try to extract a sequence number from the start of the filename
139
+ # Match patterns like: 01_xxx, 1_xxx, etc.
140
+ import re
141
+ number_match = re.match(r'^(\d+)(?:[_-]|$)', basename)
142
+ notebook_number = number_match.group(1) if number_match else None
143
+
144
+ # If we found a number, remove it from the title
145
+ if number_match:
146
+ notebook_title = re.sub(r'^\d+\s*[_-]?\s*', '', notebook_title)
147
+
148
+ # Calculate the HTML output path (for linking)
149
+ html_path = f"{course_id}/{filename.replace('.py', '.html')}"
150
+
151
+ # Add the notebook to the course
152
+ courses[course_id]["notebooks"].append({
153
+ "path": notebook_path,
154
+ "html_path": html_path,
155
+ "title": notebook_title,
156
+ "display_name": notebook_title,
157
+ "original_number": notebook_number
158
+ })
159
+
160
+ # Sort notebooks by number if available, otherwise by title
161
+ for course_id, course_data in courses.items():
162
+ # Sort the notebooks list by number and title
163
+ course_data["notebooks"] = sorted(
164
+ course_data["notebooks"],
165
+ key=lambda x: (
166
+ int(x["original_number"]) if x["original_number"] is not None else float('inf'),
167
+ x["title"]
168
+ )
169
+ )
170
+
171
+ return courses
172
+
173
+
174
+ def generate_clean_tailwind_landing_page(courses: Dict[str, Dict[str, Any]], output_dir: str) -> None:
175
+ """Generate a clean tailwindcss landing page with green accents.
176
+
177
+ This generates a modern, minimal landing page for marimo notebooks using tailwindcss.
178
+ The page is designed with clean aesthetics and green color accents using Jinja2 templates.
179
+
180
+ Args:
181
+ courses: Dictionary of courses metadata
182
+ output_dir: Directory to write the output index.html file
183
+ """
184
+ print("Generating clean tailwindcss landing page")
185
+
186
+ index_path = os.path.join(output_dir, "index.html")
187
+ os.makedirs(output_dir, exist_ok=True)
188
+
189
+ # Load Jinja2 template
190
+ current_dir = Path(__file__).parent
191
+ templates_dir = current_dir / "templates"
192
+ env = Environment(loader=FileSystemLoader(templates_dir))
193
+ template = env.get_template('index.html')
194
+
195
+ try:
196
+ with open(index_path, "w", encoding="utf-8") as f:
197
+ # Render the template with the provided data
198
+ rendered_html = template.render(
199
+ courses=courses,
200
+ current_year=datetime.date.today().year
201
+ )
202
+ f.write(rendered_html)
203
+
204
+ print(f"Successfully generated clean tailwindcss landing page at {index_path}")
205
+
206
+ except IOError as e:
207
+ print(f"Error generating clean tailwindcss landing page: {e}")
208
+
209
+
210
+ def main() -> None:
211
+ parser = argparse.ArgumentParser(description="Build marimo notebooks")
212
+ parser.add_argument(
213
+ "--output-dir", default="_site", help="Output directory for built files"
214
+ )
215
+ parser.add_argument(
216
+ "--course-dirs", nargs="+", default=None,
217
+ help="Specific course directories to build (default: all directories with .py files)"
218
+ )
219
+ args = parser.parse_args()
220
+
221
+ # Find all course directories (directories containing .py files)
222
+ all_notebooks: List[str] = []
223
+
224
+ # Directories to exclude from course detection
225
+ excluded_dirs = ["scripts", "env", "__pycache__", ".git", ".github", "assets"]
226
+
227
+ if args.course_dirs:
228
+ course_dirs = args.course_dirs
229
+ else:
230
+ # Automatically detect course directories (any directory with .py files)
231
+ course_dirs = []
232
+ for item in os.listdir("."):
233
+ if (os.path.isdir(item) and
234
+ not item.startswith(".") and
235
+ not item.startswith("_") and
236
+ item not in excluded_dirs):
237
+ # Check if directory contains .py files
238
+ if list(Path(item).glob("*.py")):
239
+ course_dirs.append(item)
240
+
241
+ print(f"Found course directories: {', '.join(course_dirs)}")
242
+
243
+ for directory in course_dirs:
244
+ dir_path = Path(directory)
245
+ if not dir_path.exists():
246
+ print(f"Warning: Directory not found: {dir_path}")
247
+ continue
248
+
249
+ notebooks = [str(path) for path in dir_path.rglob("*.py")
250
+ if not path.name.startswith("_") and "/__pycache__/" not in str(path)]
251
+ all_notebooks.extend(notebooks)
252
+
253
+ if not all_notebooks:
254
+ print("No notebooks found!")
255
+ return
256
+
257
+ # Export notebooks sequentially
258
+ successful_notebooks = []
259
+ for nb in all_notebooks:
260
+ # Determine if notebook should be exported as app or notebook
261
+ # For now, export all as notebooks
262
+ if export_html_wasm(nb, args.output_dir, as_app=False):
263
+ successful_notebooks.append(nb)
264
+
265
+ # Organize notebooks by course (only include successfully exported notebooks)
266
+ courses = organize_notebooks_by_course(successful_notebooks)
267
+
268
+ # Generate landing page using Tailwind CSS
269
+ generate_clean_tailwind_landing_page(courses, args.output_dir)
270
+
271
+ # Save course data as JSON for potential use by other tools
272
+ courses_json_path = os.path.join(args.output_dir, "courses.json")
273
+ with open(courses_json_path, "w", encoding="utf-8") as f:
274
+ json.dump(courses, f, indent=2)
275
+
276
+ print(f"Build complete! Site generated in {args.output_dir}")
277
+ print(f"Successfully exported {len(successful_notebooks)} out of {len(all_notebooks)} notebooks")
278
+
279
+
280
+ if __name__ == "__main__":
281
+ main()
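For reference, the number/title handling in `organize_notebooks_by_course` maps notebook filenames as follows; this is a standalone sketch of the same two regexes, and the filenames are only examples:

```python
import re

for basename in ["006_dictionaries", "01_intro", "advanced_topics"]:
    number_match = re.match(r'^(\d+)(?:[_-]|$)', basename)
    number = number_match.group(1) if number_match else None
    title = basename.replace("_", " ").title()
    if number_match:
        title = re.sub(r'^\d+\s*[_-]?\s*', '', title)
    print(f"{basename!r}: number={number!r}, title={title!r}")

# '006_dictionaries': number='006', title='Dictionaries'
# '01_intro': number='01', title='Intro'
# 'advanced_topics': number=None, title='Advanced Topics'
```

Notebooks with a leading number sort by that number; the rest fall back to alphabetical order by title.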
scripts/preview.py ADDED
@@ -0,0 +1,76 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import subprocess
5
+ import argparse
6
+ import webbrowser
7
+ import time
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ def main():
12
+ parser = argparse.ArgumentParser(description="Build and preview marimo notebooks site")
13
+ parser.add_argument(
14
+ "--port", default=8000, type=int, help="Port to run the server on"
15
+ )
16
+ parser.add_argument(
17
+ "--no-build", action="store_true", help="Skip building the site (just serve existing files)"
18
+ )
19
+ parser.add_argument(
20
+ "--output-dir", default="_site", help="Output directory for built files"
21
+ )
22
+ args = parser.parse_args()
23
+
24
+ # Store the current directory
25
+ original_dir = os.getcwd()
26
+
27
+ try:
28
+ # Build the site if not skipped
29
+ if not args.no_build:
30
+ print("Building site...")
31
+ build_script = Path("scripts/build.py")
32
+ if not build_script.exists():
33
+ print(f"Error: Build script not found at {build_script}")
34
+ return 1
35
+
36
+ result = subprocess.run(
37
+ [sys.executable, str(build_script), "--output-dir", args.output_dir],
38
+ check=False
39
+ )
40
+ if result.returncode != 0:
41
+ print("Warning: Build process completed with errors.")
42
+
43
+ # Check if the output directory exists
44
+ output_dir = Path(args.output_dir)
45
+ if not output_dir.exists():
46
+ print(f"Error: Output directory '{args.output_dir}' does not exist.")
47
+ return 1
48
+
49
+ # Change to the output directory
50
+ os.chdir(args.output_dir)
51
+
52
+ # Open the browser
53
+ url = f"http://localhost:{args.port}"
54
+ print(f"Opening {url} in your browser...")
55
+ webbrowser.open(url)
56
+
57
+ # Start the server
58
+ print(f"Starting server on port {args.port}...")
59
+ print("Press Ctrl+C to stop the server")
60
+
61
+ # Use the appropriate Python executable
62
+ subprocess.run([sys.executable, "-m", "http.server", str(args.port)])
63
+
64
+ return 0
65
+ except KeyboardInterrupt:
66
+ print("\nServer stopped.")
67
+ return 0
68
+ except Exception as e:
69
+ print(f"Error: {e}")
70
+ return 1
71
+ finally:
72
+ # Always return to the original directory
73
+ os.chdir(original_dir)
74
+
75
+ if __name__ == "__main__":
76
+ sys.exit(main())
scripts/templates/index.html ADDED
@@ -0,0 +1,174 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Marimo Learn - Interactive Python Notebooks</title>
7
+ <meta name="description" content="Learn Python, data science, and machine learning with interactive marimo notebooks">
8
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/tailwind.min.css" rel="stylesheet">
9
+ <style>
10
+ :root {
11
+ --primary-green: #10B981;
12
+ --dark-green: #047857;
13
+ --light-green: #D1FAE5;
14
+ }
15
+ .bg-primary { background-color: var(--primary-green); }
16
+ .text-primary { color: var(--primary-green); }
17
+ .border-primary { border-color: var(--primary-green); }
18
+ .bg-light { background-color: var(--light-green); }
+ .hover\:bg-dark-green:hover { background-color: var(--dark-green); }
+ .hover\:text-dark-green:hover { color: var(--dark-green); }
19
+ .hover-grow { transition: transform 0.2s ease; }
20
+ .hover-grow:hover { transform: scale(1.02); }
21
+ .card-shadow { box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05), 0 1px 3px rgba(0, 0, 0, 0.1); }
22
+ </style>
23
+ </head>
24
+ <body class="bg-gray-50 text-gray-800 font-sans">
25
+ <!-- Hero Section -->
26
+ <header class="bg-white">
27
+ <div class="container mx-auto px-4 py-12 max-w-6xl">
28
+ <div class="flex flex-col md:flex-row items-center justify-between">
29
+ <div class="md:w-1/2 mb-8 md:mb-0 md:pr-12">
30
+ <h1 class="text-4xl md:text-5xl font-bold mb-4">Interactive Python Learning with <span class="text-primary">marimo</span></h1>
31
+ <p class="text-lg text-gray-600 mb-6">Explore our collection of interactive notebooks for Python, data science, and machine learning.</p>
32
+ <div class="flex flex-wrap gap-4">
33
+ <a href="#courses" class="bg-primary hover:bg-dark-green text-white font-medium py-2 px-6 rounded-md transition duration-300">Explore Courses</a>
34
+ <a href="https://github.com/marimo-team/learn" target="_blank" class="bg-white border border-gray-300 hover:border-primary text-gray-700 font-medium py-2 px-6 rounded-md transition duration-300">View on GitHub</a>
35
+ </div>
36
+ </div>
37
+ <div class="md:w-1/2">
38
+ <div class="bg-light p-1 rounded-lg">
39
+ <img src="https://github.com/marimo-team/learn/blob/main/assets/marimo-learn.png?raw=true" alt="Marimo Logo" class="w-64 h-64 mx-auto object-contain">
40
+ </div>
41
+ </div>
42
+ </div>
43
+ </div>
44
+ </header>
45
+
46
+ <!-- Features Section -->
47
+ <section class="py-16 bg-gray-50">
48
+ <div class="container mx-auto px-4 max-w-6xl">
49
+ <h2 class="text-3xl font-bold text-center mb-12">Why Learn with <span class="text-primary">Marimo</span>?</h2>
50
+ <div class="grid md:grid-cols-3 gap-8">
51
+ <div class="bg-white p-6 rounded-lg card-shadow">
52
+ <div class="w-12 h-12 bg-light rounded-full flex items-center justify-center mb-4">
53
+ <svg xmlns="http://www.w3.org/2000/svg" class="h-6 w-6 text-primary" fill="none" viewBox="0 0 24 24" stroke="currentColor">
54
+ <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 10V3L4 14h7v7l9-11h-7z" />
55
+ </svg>
56
+ </div>
57
+ <h3 class="text-xl font-semibold mb-2">Interactive Learning</h3>
58
+ <p class="text-gray-600">Learn by doing with interactive notebooks that run directly in your browser.</p>
59
+ </div>
60
+ <div class="bg-white p-6 rounded-lg card-shadow">
61
+ <div class="w-12 h-12 bg-light rounded-full flex items-center justify-center mb-4">
62
+ <svg xmlns="http://www.w3.org/2000/svg" class="h-6 w-6 text-primary" fill="none" viewBox="0 0 24 24" stroke="currentColor">
63
+ <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19.428 15.428a2 2 0 00-1.022-.547l-2.387-.477a6 6 0 00-3.86.517l-.318.158a6 6 0 01-3.86.517L6.05 15.21a2 2 0 00-1.806.547M8 4h8l-1 1v5.172a2 2 0 00.586 1.414l5 5c1.26 1.26.367 3.414-1.415 3.414H4.828c-1.782 0-2.674-2.154-1.414-3.414l5-5A2 2 0 009 10.172V5L8 4z" />
64
+ </svg>
65
+ </div>
66
+ <h3 class="text-xl font-semibold mb-2">Practical Examples</h3>
67
+ <p class="text-gray-600">Real-world examples and applications to reinforce your understanding.</p>
68
+ </div>
69
+ <div class="bg-white p-6 rounded-lg card-shadow">
70
+ <div class="w-12 h-12 bg-light rounded-full flex items-center justify-center mb-4">
71
+ <svg xmlns="http://www.w3.org/2000/svg" class="h-6 w-6 text-primary" fill="none" viewBox="0 0 24 24" stroke="currentColor">
72
+ <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 6.253v13m0-13C10.832 5.477 9.246 5 7.5 5S4.168 5.477 3 6.253v13C4.168 18.477 5.754 18 7.5 18s3.332.477 4.5 1.253m0-13C13.168 5.477 14.754 5 16.5 5c1.747 0 3.332.477 4.5 1.253v13C19.832 18.477 18.247 18 16.5 18c-1.746 0-3.332.477-4.5 1.253" />
73
+ </svg>
74
+ </div>
75
+ <h3 class="text-xl font-semibold mb-2">Comprehensive Curriculum</h3>
76
+ <p class="text-gray-600">From Python basics to advanced machine learning concepts.</p>
77
+ </div>
78
+ </div>
79
+ </div>
80
+ </section>
81
+
82
+ <!-- Courses Section -->
83
+ <section id="courses" class="py-16 bg-white">
84
+ <div class="container mx-auto px-4 max-w-6xl">
85
+ <h2 class="text-3xl font-bold text-center mb-12">Explore Our <span class="text-primary">Courses</span></h2>
86
+ <div class="grid md:grid-cols-2 lg:grid-cols-3 gap-8">
87
+ {% for course_id, course in courses.items() %}
88
+ {% set notebooks = course.get('notebooks', []) %}
89
+ {% set notebook_count = notebooks|length %}
90
+
91
+ {% if notebook_count > 0 %}
92
+ {% set title = course.get('title', course_id|replace('_', ' ')|title) %}
93
+
94
+ <div class="bg-white border border-gray-200 rounded-lg overflow-hidden hover-grow card-shadow">
95
+ <div class="h-2 bg-primary"></div>
96
+ <div class="p-6">
97
+ <h3 class="text-xl font-semibold mb-2">{{ title }}</h3>
98
+ <p class="text-gray-600 mb-4">
99
+ {% if course.get('description_html') %}
100
+ {{ course.get('description_html')|safe }}
101
+ {% endif %}
102
+ </p>
103
+ <div class="mb-4">
104
+ <span class="text-sm text-gray-500 block mb-2">{{ notebook_count }} notebooks:</span>
105
+ <ol class="space-y-1 list-decimal pl-5">
106
+ {% for notebook in notebooks %}
107
+ {% set notebook_title = notebook.get('title', notebook.get('path', '').split('/')[-1].replace('.py', '').replace('_', ' ').title()) %}
108
+ <li>
109
+ <a href="{{ notebook.get('html_path', '#') }}" class="text-primary hover:text-dark-green text-sm flex items-center">
110
+ {{ notebook_title }}
111
+ </a>
112
+ </li>
113
+ {% endfor %}
114
+ </ol>
115
+ </div>
116
+ </div>
117
+ </div>
118
+ {% endif %}
119
+ {% endfor %}
120
+ </div>
121
+ </div>
122
+ </section>
123
+
124
+ <!-- Contribute Section -->
125
+ <section class="py-16 bg-light">
126
+ <div class="container mx-auto px-4 max-w-6xl text-center">
127
+ <h2 class="text-3xl font-bold mb-6">Want to Contribute?</h2>
128
+ <p class="text-lg text-gray-700 mb-8 max-w-2xl mx-auto">Help us improve these learning materials by contributing to the GitHub repository. We welcome new content, bug fixes, and improvements!</p>
129
+ <a href="https://github.com/marimo-team/learn" target="_blank" class="bg-primary hover:bg-dark-green text-white font-medium py-3 px-8 rounded-md transition duration-300 inline-flex items-center">
130
+ <svg class="w-5 h-5 mr-2" fill="currentColor" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg">
131
+ <path fill-rule="evenodd" d="M12 0C5.37 0 0 5.37 0 12c0 5.31 3.435 9.795 8.205 11.385.6.105.825-.255.825-.57 0-.285-.015-1.23-.015-2.235-3.015.555-3.795-.735-4.035-1.41-.135-.345-.72-1.41-1.23-1.695-.42-.225-1.02-.78-.015-.795.945-.015 1.62.87 1.845 1.23 1.08 1.815 2.805 1.305 3.495.99.105-.78.42-1.305.765-1.605-2.67-.3-5.46-1.335-5.46-5.925 0-1.305.465-2.385 1.23-3.225-.12-.3-.54-1.53.12-3.18 0 0 1.005-.315 3.3 1.23.96-.27 1.98-.405 3-.405s2.04.135 3 .405c2.295-1.56 3.3-1.23 3.3-1.23.66 1.65.24 2.88.12 3.18.765.84 1.23 1.905 1.23 3.225 0 4.605-2.805 5.625-5.475 5.925.435.375.81 1.095.81 2.22 0 1.605-.015 2.895-.015 3.3 0 .315.225.69.825.57A12.02 12.02 0 0024 12c0-6.63-5.37-12-12-12z" clip-rule="evenodd"></path>
132
+ </svg>
133
+ Contribute on GitHub
134
+ </a>
135
+ </div>
136
+ </section>
137
+
138
+ <!-- Footer -->
139
+ <footer class="bg-gray-800 text-white py-8">
140
+ <div class="container mx-auto px-4 max-w-6xl">
141
+ <div class="flex flex-col md:flex-row justify-between items-center">
142
+ <div class="mb-4 md:mb-0">
143
+ <p>&copy; {{ current_year }} marimo. All rights reserved.</p>
144
+ </div>
145
+ <div class="flex space-x-4">
146
+ <a href="https://github.com/marimo-team/learn" target="_blank" class="text-gray-300 hover:text-white transition duration-300">
147
+ <svg class="w-6 h-6" fill="currentColor" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg">
148
+ <path fill-rule="evenodd" d="M12 0C5.37 0 0 5.37 0 12c0 5.31 3.435 9.795 8.205 11.385.6.105.825-.255.825-.57 0-.285-.015-1.23-.015-2.235-3.015.555-3.795-.735-4.035-1.41-.135-.345-.72-1.41-1.23-1.695-.42-.225-1.02-.78-.015-.795.945-.015 1.62.87 1.845 1.23 1.08 1.815 2.805 1.305 3.495.99.105-.78.42-1.305.765-1.605-2.67-.3-5.46-1.335-5.46-5.925 0-1.305.465-2.385 1.23-3.225-.12-.3-.54-1.53.12-3.18 0 0 1.005-.315 3.3 1.23.96-.27 1.98-.405 3-.405s2.04.135 3 .405c2.295-1.56 3.3-1.23 3.3-1.23.66 1.65.24 2.88.12 3.18.765.84 1.23 1.905 1.23 3.225 0 4.605-2.805 5.625-5.475 5.925.435.375.81 1.095.81 2.22 0 1.605-.015 2.895-.015 3.3 0 .315.225.69.825.57A12.02 12.02 0 0024 12c0-6.63-5.37-12-12-12z" clip-rule="evenodd"></path>
149
+ </svg>
150
+ </a>
151
+ <a href="https://marimo.io" target="_blank" class="text-gray-300 hover:text-white transition duration-300">
152
+ <svg xmlns="http://www.w3.org/2000/svg" class="h-6 w-6" fill="none" viewBox="0 0 24 24" stroke="currentColor">
153
+ <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M21 12a9 9 0 01-9 9m9-9a9 9 0 00-9-9m9 9H3m9 9a9 9 0 01-9-9m9 9c1.657 0 3-4.03 3-9s-1.343-9-3-9m0 18c-1.657 0-3-4.03-3-9s1.343-9 3-9m-9 9a9 9 0 019-9" />
154
+ </svg>
155
+ </a>
156
+ </div>
157
+ </div>
158
+ </div>
159
+ </footer>
160
+
161
+ <!-- Scripts -->
162
+ <script>
163
+ // Smooth scrolling for anchor links
164
+ document.querySelectorAll('a[href^="#"]').forEach(anchor => {
165
+ anchor.addEventListener('click', function (e) {
166
+ e.preventDefault();
167
+ document.querySelector(this.getAttribute('href')).scrollIntoView({
168
+ behavior: 'smooth'
169
+ });
170
+ });
171
+ });
172
+ </script>
173
+ </body>
174
+ </html>
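The template consumes the `courses` dictionary assembled by `scripts/build.py` together with `current_year`. A minimal sketch of rendering it outside the build pipeline, where the course data is an illustrative placeholder shaped like the output of `organize_notebooks_by_course`:

```python
import datetime
from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("scripts/templates"))
template = env.get_template("index.html")

# Placeholder data shaped like the structure built in scripts/build.py
courses = {
    "python": {
        "id": "python",
        "title": "Python",
        "description": "Python fundamentals.",
        "description_html": "<p>Python fundamentals.</p>",
        "notebooks": [
            {
                "path": "python/006_dictionaries.py",
                "html_path": "python/006_dictionaries.html",
                "title": "Dictionaries",
                "display_name": "Dictionaries",
                "original_number": "006",
            }
        ],
    }
}

html = template.render(courses=courses, current_year=datetime.date.today().year)
print(html[:200])
```

Courses with an empty `notebooks` list are skipped by the `{% if notebook_count > 0 %}` guard, so only successfully exported notebooks appear on the landing page.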