Srihari Thyagarajan committed d0f22f6 (unverified) · 2 parents: ccf4a01 43992cd

Merge branch 'main' into fp/applicatives
.github/workflows/typos.yaml CHANGED
@@ -13,3 +13,5 @@ jobs:
13        uses: styfle/[email protected]
14      - uses: actions/checkout@v4
15      - uses: crate-ci/[email protected]
16 +      with:
17 +        config: .typos.toml
.typos.toml ADDED
@@ -0,0 +1,22 @@
1
+ [default]
2
+ extend-ignore-re = [
3
+ # LaTeX math expressions
4
+ "\\\\\\[.*?\\\\\\]",
5
+ "\\\\\\(.*?\\\\\\)",
6
+ "\\$\\$.*?\\$\\$",
7
+ "\\$.*?\\$",
8
+ # LaTeX commands
9
+ "\\\\[a-zA-Z]+\\{.*?\\}",
10
+ "\\\\[a-zA-Z]+",
11
+ # LaTeX subscripts and superscripts
12
+ "_\\{.*?\\}",
13
+ "\\^\\{.*?\\}"
14
+ ]
15
+
16
+ # Words to explicitly accept
17
+ [default.extend-words]
18
+ pn = "pn"
19
+
20
+ # You can also exclude specific files or directories if needed
21
+ # [files]
22
+ # extend-exclude = ["*.tex", "docs/*.md"]
functional_programming/README.md CHANGED
@@ -38,7 +38,7 @@ uvx marimo edit <URL>
38   For example, run the `Functor` tutorial with
39
40   ```bash
41 - uvx marimo edit https://github.com/marimo-team/learn/blob/main/Functional_programming/05_functors.py
41 + uvx marimo edit https://github.com/marimo-team/learn/blob/main/functional_programming/05_functors.py
42   ```
43
44   You can also open notebooks in our online playground by appending `marimo.app/`
@@ -50,10 +50,9 @@ to a notebook's URL:
50   Check [here](https://github.com/marimo-team/learn/issues/51) for current series
51   structure.
52
53 - | Notebook | Title | Description | Key Concepts | Prerequisites |
54 - |----------|-------|-------------|--------------|---------------|
55 - | [05. Functors](https://github.com/marimo-team/learn/blob/main/Functional_programming/05_functors.py) | Category and Functors | Learn why `len` is a _Functor_ from `list concatenation` to `integer addition`, how to _lift_ an ordinary function into a _computation context_, and how to write an _adapter_ between two categories. | Categories, Functors, Function lifting, Context mapping | Basic Python, Functions |
56 - | [06. Applicatives](https://github.com/marimo-team/learn/blob/main/Functional_programming/06_applicatives.py) | Applicative programming with effects | Learn how to apply functions within a context, combining multiple effects in a pure way. Learn about the `pure` and `apply` operations that make applicatives powerful for handling multiple computations. | Applicative Functors, Pure, Apply, Effectful programming | Functors |
53 + | Notebook | Description | References |
54 + | --- | --- | --- |
55 + | [05. Category and Functors](https://github.com/marimo-team/learn/blob/main/functional_programming/05_functors.py) | Learn why `len` is a _Functor_ from `list concatenation` to `integer addition`, how to _lift_ an ordinary function into a _computation context_, and how to write an _adapter_ between two categories. | - [The Trivial Monad](http://blog.sigfpe.com/2007/04/trivial-monad.html) <br> - [Haskellwiki. Category Theory](https://en.wikibooks.org/wiki/Haskell/Category_theory) <br> - [Haskellforall. The Category Design Pattern](https://www.haskellforall.com/2012/08/the-category-design-pattern.html) <br> - [Haskellforall. The Functor Design Pattern](https://www.haskellforall.com/2012/09/the-functor-design-pattern.html) <br> - [Haskellwiki. Functor](https://wiki.haskell.org/index.php?title=Functor) <br> - [Haskellwiki. Typeclassopedia#Functor](https://wiki.haskell.org/index.php?title=Typeclassopedia#Functor) <br> - [Haskellwiki. Typeclassopedia#Category](https://wiki.haskell.org/index.php?title=Typeclassopedia#Category) |
56
57   **Authors.**
58
polars/08_working_with_columns.py ADDED
@@ -0,0 +1,605 @@
1
+ # /// script
2
+ # requires-python = ">=3.11"
3
+ # dependencies = [
4
+ # "polars==1.18.0",
5
+ # "marimo",
6
+ # ]
7
+ # ///
8
+
9
+ import marimo
10
+
11
+ __generated_with = "0.12.0"
12
+ app = marimo.App(width="medium")
13
+
14
+
15
+ @app.cell(hide_code=True)
16
+ def _(mo):
17
+ mo.md(
18
+ r"""
19
+ # Working with Columns
20
+
21
+ Author: [Deb Debnath](https://github.com/debajyotid2)
22
+
23
+ **Note**: The following tutorial has been adapted from the Polars [documentation](https://docs.pola.rs/user-guide/expressions/expression-expansion).
24
+ """
25
+ )
26
+ return
27
+
28
+
29
+ @app.cell(hide_code=True)
30
+ def _(mo):
31
+ mo.md(
32
+ r"""
33
+ ## Expressions
34
+
35
+ Data transformations are sometimes complicated, or involve massive computations that are time-consuming. One workaround is to prototype the transformation on a small version of the dataset with the same schema, but Polars offers a better way: expressions.
36
+
37
+ A Polars expression is a lazy representation of a data transformation. "Lazy" means that the transformation is not eagerly (immediately) executed.
38
+
39
+ Expressions are modular and flexible: they can be composed to build more complex expressions. For example, to calculate speed from distance and time, you can write an expression such as:
40
+ """
41
+ )
42
+ return
43
+
44
+
45
+ @app.cell
46
+ def _(pl):
47
+ speed_expr = pl.col("distance") / (pl.col("time"))
48
+ speed_expr
49
+ return (speed_expr,)
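Since expressions are lazy, nothing is computed when `speed_expr` is defined above. A minimal sketch of how the expression actually runs once it is handed to a context (the two-row dataframe and the `speed` alias are assumed for illustration):

```python
import polars as pl

# Defining the expression computes nothing yet.
speed_expr = pl.col("distance") / pl.col("time")

# The computation only happens inside a context such as `select` or `with_columns`.
df = pl.DataFrame({"distance": [100.0, 250.0], "time": [9.58, 20.0]})
print(df.select(speed_expr.alias("speed")))
```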
50
+
51
+
52
+ @app.cell(hide_code=True)
53
+ def _(mo):
54
+ mo.md(
55
+ r"""
56
+ ## Expression expansion
57
+
58
+ Expression expansion lets you write a single expression that expands to multiple expressions. Rather than defining each expression separately, you avoid redundancy while adhering to clean code principles (Don't Repeat Yourself - [DRY](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself)). Because expressions are reusable, they help you write concise code.
59
+ """
60
+ )
61
+ return
62
+
63
+
64
+ @app.cell(hide_code=True)
65
+ def _(mo):
66
+ mo.md("""For the examples in this notebook, we will use a sliver of the *AI4I 2020 Predictive Maintenance Dataset*. This dataset comprises of measurements taken from sensors in industrial machinery undergoing preventive maintenance checks - basically being tested for failure conditions.""")
67
+ return
68
+
69
+
70
+ @app.cell
71
+ def _(StringIO, pl):
72
+ data_csv = """
73
+ Product ID,Type,Air temperature,Process temperature,Rotational speed,Tool wear,Machine failure,TWF,HDF,PWF,OSF,RNF
74
+ L51172,L,302.3,311.3,1614,129,0,0,1,0,0,0
75
+ M22586,M,300.8,311.9,1761,113,1,0,0,0,1,0
76
+ L51639,L,302.6,310.4,1743,191,0,1,0,0,0,1
77
+ L50250,L,300,309.1,1631,110,0,0,0,1,0,0
78
+ M20109,M,303.4,312.9,1422,63,1,0,0,0,0,0
79
+ """
80
+
81
+ data = pl.read_csv(StringIO(data_csv))
82
+ data
83
+ return data, data_csv
84
+
85
+
86
+ @app.cell(hide_code=True)
87
+ def _(mo):
88
+ mo.md(
89
+ r"""
90
+ ## Function `col`
91
+
92
+ The function `col` is used to refer to one column of a dataframe. It is one of the fundamental building blocks of expressions in Polars. `col` is also really handy in expression expansion.
93
+ """
94
+ )
95
+ return
96
+
97
+
98
+ @app.cell(hide_code=True)
99
+ def _(mo):
100
+ mo.md(
101
+ r"""
102
+ ### Explicit expansion by column name
103
+
104
+ The simplest form of expression expansion happens when you provide multiple column names to the function `col`.
105
+
106
+ Say you wish to convert all temperature values from Kelvin (K) to degrees Fahrenheit (°F). One way to do this would be to define an individual expression for each column, as follows:
107
+ """
108
+ )
109
+ return
110
+
111
+
112
+ @app.cell
113
+ def _(data, pl):
114
+ exprs = [
115
+ ((pl.col("Air temperature") - 273.15) * 1.8 + 32).round(2),
116
+ ((pl.col("Process temperature") - 273.15) * 1.8 + 32).round(2)
117
+ ]
118
+
119
+ result = data.with_columns(exprs)
120
+ result
121
+ return exprs, result
122
+
123
+
124
+ @app.cell(hide_code=True)
125
+ def _(mo):
126
+ mo.md(r"""Expression expansion can reduce this verbosity when you list the column names you want the expression to expand to inside the `col` function. The result is the same as before.""")
127
+ return
128
+
129
+
130
+ @app.cell
131
+ def _(data, pl, result):
132
+ result_2 = data.with_columns(
133
+ (
134
+ (pl.col(
135
+ "Air temperature",
136
+ "Process temperature"
137
+ )
138
+ - 273.15) * 1.8 + 32
139
+ ).round(2)
140
+ )
141
+ result_2.equals(result)
142
+ return (result_2,)
143
+
144
+
145
+ @app.cell(hide_code=True)
146
+ def _(mo):
147
+ mo.md(r"""In this case, the expression that does the temperature conversion is expanded to a list of two expressions. The expansion of the expression is predictable and intuitive.""")
148
+ return
149
+
150
+
151
+ @app.cell(hide_code=True)
152
+ def _(mo):
153
+ mo.md(
154
+ r"""
155
+ ### Expansion by data type
156
+
157
+ Can we do better than explicitly writing the name of every column we want transformed? Yes.
158
+
159
+ If you provide data types instead of column names, the expression is expanded to all columns that match one of the data types provided.
160
+
161
+ The example below performs the exact same computation as before:
162
+ """
163
+ )
164
+ return
165
+
166
+
167
+ @app.cell
168
+ def _(data, pl, result):
169
+ result_3 = data.with_columns(((pl.col(pl.Float64) - 273.15) * 1.8 + 32).round(2))
170
+ result_3.equals(result)
171
+ return (result_3,)
172
+
173
+
174
+ @app.cell(hide_code=True)
175
+ def _(mo):
176
+ mo.md(
177
+ r"""
178
+ However, you should be careful to ensure that the transformation is applied only to the columns you want. To ensure this, it is important to know the schema of the data beforehand.
179
+
180
+ `col` accepts multiple data types in case the columns you need have more than one data type.
181
+ """
182
+ )
183
+ return
184
+
185
+
186
+ @app.cell
187
+ def _(data, pl, result):
188
+ result_4 = data.with_columns(
189
+ (
190
+ (pl.col(
191
+ pl.Float32,
192
+ pl.Float64,
193
+ )
194
+ - 273.15) * 1.8 + 32
195
+ ).round(2)
196
+ )
197
+ result.equals(result_4)
198
+ return (result_4,)
199
+
200
+
201
+ @app.cell(hide_code=True)
202
+ def _(mo):
203
+ mo.md(
204
+ r"""
205
+ ### Expansion by pattern matching
206
+
207
+ `col` also accepts regular expressions for selecting columns by pattern matching. To be recognized as regular expressions, the patterns must start with ^ and end with $.
208
+ """
209
+ )
210
+ return
211
+
212
+
213
+ @app.cell
214
+ def _(data, pl):
215
+ data.select(pl.col("^.*temperature$"))
216
+ return
217
+
218
+
219
+ @app.cell(hide_code=True)
220
+ def _(mo):
221
+ mo.md(r"""Regular expressions can be combined with exact column names.""")
222
+ return
223
+
224
+
225
+ @app.cell
226
+ def _(data, pl):
227
+ data.select(pl.col("^.*temperature$", "Tool wear"))
228
+ return
229
+
230
+
231
+ @app.cell(hide_code=True)
232
+ def _(mo):
233
+ mo.md(r"""**Note**: You _cannot_ mix strings (exact names, regular expressions) and data types in a `col` function.""")
234
+ return
235
+
236
+
237
+ @app.cell
238
+ def _(data, pl):
239
+ try:
240
+ data.select(pl.col("Air temperature", pl.Float64))
241
+ except TypeError as err:
242
+ print("TypeError:", err)
243
+ return
244
+
245
+
246
+ @app.cell(hide_code=True)
247
+ def _(mo):
248
+ mo.md(
249
+ r"""
250
+ ## Selecting all columns
251
+
252
+ To select all columns, you can use the `all` function.
253
+ """
254
+ )
255
+ return
256
+
257
+
258
+ @app.cell
259
+ def _(data, pl):
260
+ result_6 = data.select(pl.all())
261
+ result_6.equals(data)
262
+ return (result_6,)
263
+
264
+
265
+ @app.cell(hide_code=True)
266
+ def _(mo):
267
+ mo.md(
268
+ r"""
269
+ ## Excluding columns
270
+
271
+ There are scenarios where we might want to exclude specific columns from the ones selected by an expression built with, e.g., the `col` or `all` functions. For this purpose, we use the function `exclude`, which accepts exactly the same types of arguments as `col`:
272
+ """
273
+ )
274
+ return
275
+
276
+
277
+ @app.cell
278
+ def _(data, pl):
279
+ data.select(pl.all().exclude("^.*F$"))
280
+ return
281
+
282
+
283
+ @app.cell(hide_code=True)
284
+ def _(mo):
285
+ mo.md(r"""`exclude` can also be used after the function `col`:""")
286
+ return
287
+
288
+
289
+ @app.cell
290
+ def _(data, pl):
291
+ data.select(pl.col(pl.Int64).exclude("^.*F$"))
292
+ return
293
+
294
+
295
+ @app.cell(hide_code=True)
296
+ def _(mo):
297
+ mo.md(
298
+ r"""
299
+ ## Column renaming
300
+
301
+ When an expression transforms a column, the original data in that column gets overwritten with the transformed data. This might not be the intended outcome - often you would rather store the transformed data in a new column. In fact, applying multiple transformations to the same column at the same time without renaming the results leads to errors.
302
+ """
303
+ )
304
+ return
305
+
306
+
307
+ @app.cell
308
+ def _(data, pl):
309
+ from polars.exceptions import DuplicateError
310
+
311
+ try:
312
+ data.select(
313
+ (pl.col("Air temperature") - 273.15) * 1.8 + 32, # This would be named "Air temperature"...
314
+ pl.col("Air temperature") - 273.15, # And so would this.
315
+ )
316
+ except DuplicateError as err:
317
+ print("DuplicateError:", err)
318
+ return (DuplicateError,)
319
+
320
+
321
+ @app.cell(hide_code=True)
322
+ def _(mo):
323
+ mo.md(
324
+ r"""
325
+ ### Renaming a single column with `alias`
326
+
327
+ The function `alias` lets you rename a single column:
328
+ """
329
+ )
330
+ return
331
+
332
+
333
+ @app.cell
334
+ def _(data, pl):
335
+ data.select(
336
+ ((pl.col("Air temperature") - 273.15) * 1.8 + 32).round(2).alias("Air temperature [F]"),
337
+ (pl.col("Air temperature") - 273.15).round(2).alias("Air temperature [C]")
338
+ )
339
+ return
340
+
341
+
342
+ @app.cell(hide_code=True)
343
+ def _(mo):
344
+ mo.md(
345
+ r"""
346
+ ### Prefixing and suffixing column names
347
+
348
+ As `alias` renames a single column at a time, it cannot be used during expression expansion. If it is sufficient to add a static prefix or a static suffix to the existing names, you can use the functions `name.prefix` and `name.suffix` with `col`:
349
+ """
350
+ )
351
+ return
352
+
353
+
354
+ @app.cell
355
+ def _(data, pl):
356
+ data.select(
357
+ ((pl.col("Air temperature") - 273.15) * 1.8 + 32).round(2).name.prefix("deg F "),
358
+ (pl.col("Process temperature") - 273.15).round(2).name.suffix(" C"),
359
+ )
360
+ return
361
+
362
+
363
+ @app.cell(hide_code=True)
364
+ def _(mo):
365
+ mo.md(
366
+ r"""
367
+ ### Dynamic name replacement
368
+
369
+ If a static prefix/suffix is not enough, use `name.map`. `name.map` requires a function that transforms each existing column name into the desired name. The transformation should produce unique names to avoid a `DuplicateError`.
370
+ """
371
+ )
372
+ return
373
+
374
+
375
+ @app.cell
376
+ def _(data, pl):
377
+ # There is also `.name.to_lowercase`, so this usage of `.map` is moot.
378
+ data.select(pl.col("^.*F$").name.map(str.lower))
379
+ return
380
+
381
+
382
+ @app.cell(hide_code=True)
383
+ def _(mo):
384
+ mo.md(
385
+ r"""
386
+ ## Programmatically generating expressions
387
+
388
+ For this example, we will first create two additional columns holding the rolling mean of the two temperature columns. Such transformations are sometimes used to create additional features for machine learning models or data analysis.
389
+ """
390
+ )
391
+ return
392
+
393
+
394
+ @app.cell
395
+ def _(data, pl):
396
+ ext_temp_data = data.with_columns(
397
+ pl.col("^.*temperature$").rolling_mean(window_size=2).round(2).name.prefix("Rolling mean ")
398
+ ).select(pl.col("^.*temperature*$"))
399
+ ext_temp_data
400
+ return (ext_temp_data,)
401
+
402
+
403
+ @app.cell(hide_code=True)
404
+ def _(mo):
405
+ mo.md(r"""Now, suppose we want to calculate the difference between the rolling mean and actual temperatures. We cannot use expression expansion here as we want differences between specific columns.""")
406
+ return
407
+
408
+
409
+ @app.cell(hide_code=True)
410
+ def _(mo):
411
+ mo.md(r"""At first, you may think about using a `for` loop:""")
412
+ return
413
+
414
+
415
+ @app.cell
416
+ def _(ext_temp_data, pl):
417
+ _result = ext_temp_data
418
+ for col_name in ["Air", "Process"]:
419
+ _result = _result.with_columns(
420
+ (abs(pl.col(f"Rolling mean {col_name} temperature") - pl.col(f"{col_name} temperature")))
421
+ .round(2).alias(f"Delta {col_name} temperature")
422
+ )
423
+ _result
424
+ return (col_name,)
425
+
426
+
427
+ @app.cell(hide_code=True)
428
+ def _(mo):
429
+ mo.md(r"""Using a `for` loop is functional, but not scalable, as each expression needs to be defined in an iteration and executed serially. Instead we can use a generator in Python to programmatically create all expressions at once. In conjunction with the `with_columns` context, we can take advantage of parallel execution of computations and query optimization from Polars.""")
430
+ return
431
+
432
+
433
+ @app.cell
434
+ def _(ext_temp_data, pl):
435
+ def delta_expressions(colnames: list[str]) -> pl.Expr:
436
+ for col_name in colnames:
437
+ yield (abs(pl.col(f"Rolling mean {col_name} temperature") - pl.col(f"{col_name} temperature"))
438
+ .round(2).alias(f"Delta {col_name} temperature"))
439
+
440
+
441
+ ext_temp_data.with_columns(delta_expressions(["Air", "Process"]))
442
+ return (delta_expressions,)
443
+
444
+
445
+ @app.cell(hide_code=True)
446
+ def _(mo):
447
+ mo.md(
448
+ r"""
449
+ ## More flexible column selections
450
+
451
+ For more flexible column selections, you can use column selectors from `selectors`. Column selectors allow for more expressiveness in the way you specify selections. For example, column selectors can perform the familiar set operations of union, intersection, difference, etc. We can use the union operation with the functions `string` and `ends_with` to select all string columns plus the columns whose names end with `"F"`:
452
+ """
453
+ )
454
+ return
455
+
456
+
457
+ @app.cell
458
+ def _(data):
459
+ import polars.selectors as cs
460
+
461
+ data.select(cs.string() | cs.ends_with("F"))
462
+ return (cs,)
463
+
464
+
465
+ @app.cell(hide_code=True)
466
+ def _(mo):
467
+ mo.md(r"""Likewise, you can pick columns based on the category of the type of data, offering more flexibility than the `col` function. As an example, `cs.numeric` selects numeric data types (including `pl.Float32`, `pl.Float64`, `pl.Int32`, etc.) or `cs.temporal` for all dates, times and similar data types.""")
468
+ return
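Neither `cs.numeric` nor `cs.temporal` is demonstrated on data in this notebook, so here is a minimal sketch (the small dataframe is assumed for illustration; `cs.temporal()` comes back empty because there are no date or time columns):

```python
import polars as pl
import polars.selectors as cs

df = pl.DataFrame({
    "Product ID": ["L51172", "M22586"],
    "Air temperature": [302.3, 300.8],
    "Tool wear": [129, 113],
})

print(df.select(cs.numeric()))   # the float and integer columns
print(df.select(cs.temporal()))  # empty selection: no date/time columns here
```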
469
+
470
+
471
+ @app.cell(hide_code=True)
472
+ def _(mo):
473
+ mo.md(
474
+ r"""
475
+ ### Combining selectors with set operations
476
+
477
+ Multiple selectors can be combined using set operations and the usual Python operators:
478
+
479
+
480
+ | Operator | Operation |
481
+ |:--------:|:--------------------:|
482
+ | `A | B` | Union |
483
+ | `A & B` | Intersection |
484
+ | `A - B` | Difference |
485
+ | `A ^ B` | Symmetric difference |
486
+ | `~A` | Complement |
487
+
488
+ For example, to select all failure indicator variables excluding the failure variables due to wear, we can perform a set difference between the column selectors.
489
+ """
490
+ )
491
+ return
492
+
493
+
494
+ @app.cell
495
+ def _(cs, data):
496
+ data.select(cs.contains("F") - cs.contains("W"))
497
+ return
498
+
499
+
500
+ @app.cell(hide_code=True)
501
+ def _(mo):
502
+ mo.md(
503
+ r"""
504
+ ### Resolving operator ambiguity
505
+
506
+ Expression functions can be chained on top of selectors:
507
+ """
508
+ )
509
+ return
510
+
511
+
512
+ @app.cell
513
+ def _(cs, data, pl):
514
+ ext_failure_data = data.select(cs.contains("F")).cast(pl.Boolean)
515
+ ext_failure_data
516
+ return (ext_failure_data,)
517
+
518
+
519
+ @app.cell(hide_code=True)
520
+ def _(mo):
521
+ mo.md(
522
+ r"""
523
+ However, the operators that perform set operations on column selectors are also defined on expressions, with different meanings. For example, the operator `~` on a selector represents the set operation "complement", while on an expression it represents Boolean negation.
524
+
525
+ For instance, suppose you want to negate the Boolean values in the wear-failure columns - the ones whose names end with "WF". At first you might reach for the `~` operator on the selector `cs.ends_with("WF")`. But because of the operator ambiguity, `~` takes the set complement of the selector, so the columns that are *not* of interest end up selected instead, and no values are negated.
526
+ """
527
+ )
528
+ return
529
+
530
+
531
+ @app.cell
532
+ def _(cs, ext_failure_data):
533
+ ext_failure_data.select((~cs.ends_with("WF")).name.prefix("No"))
534
+ return
535
+
536
+
537
+ @app.cell(hide_code=True)
538
+ def _(mo):
539
+ mo.md(r"""To resolve the operator ambiguity, we use `as_expr`:""")
540
+ return
541
+
542
+
543
+ @app.cell
544
+ def _(cs, ext_failure_data):
545
+ ext_failure_data.select((~cs.ends_with("WF").as_expr()).name.prefix("No"))
546
+ return
547
+
548
+
549
+ @app.cell(hide_code=True)
550
+ def _(mo):
551
+ mo.md(
552
+ r"""
553
+ ### Debugging selectors
554
+
555
+ The function `cs.is_selector` helps check whether a complex chain of selectors and operators ultimately results in a selector. For example, we can confirm that the `as_expr` version from the last example is no longer a selector:
556
+ """
557
+ )
558
+ return
559
+
560
+
561
+ @app.cell
562
+ def _(cs):
563
+ cs.is_selector(~cs.ends_with("WF").as_expr())
564
+ return
565
+
566
+
567
+ @app.cell(hide_code=True)
568
+ def _(mo):
569
+ mo.md(r"""Additionally we can use `expand_selector` to see what columns a selector expands into. Note that for this function we need to provide additional context in the form of the dataframe.""")
570
+ return
571
+
572
+
573
+ @app.cell
574
+ def _(cs, ext_failure_data):
575
+ cs.expand_selector(
576
+ ext_failure_data,
577
+ cs.ends_with("WF"),
578
+ )
579
+ return
580
+
581
+
582
+ @app.cell(hide_code=True)
583
+ def _(mo):
584
+ mo.md(
585
+ r"""
586
+ ### References
587
+
588
+ 1. AI4I 2020 Predictive Maintenance Dataset [Dataset]. (2020). UCI Machine Learning Repository. ([link](https://doi.org/10.24432/C5HS5C)).
589
+ 2. Polars documentation ([link](https://docs.pola.rs/user-guide/expressions/expression-expansion/#more-flexible-column-selections))
590
+ """
591
+ )
592
+ return
593
+
594
+
595
+ @app.cell(hide_code=True)
596
+ def _():
597
+ import csv
598
+ import marimo as mo
599
+ import polars as pl
600
+ from io import StringIO
601
+ return StringIO, csv, mo, pl
602
+
603
+
604
+ if __name__ == "__main__":
605
+ app.run()
polars/09_data_types.py ADDED
@@ -0,0 +1,296 @@
1
+ # /// script
2
+ # requires-python = ">=3.11"
3
+ # dependencies = [
4
+ # "polars==1.18.0",
5
+ # "marimo",
6
+ # ]
7
+ # ///
8
+
9
+ import marimo
10
+
11
+ __generated_with = "0.12.0"
12
+ app = marimo.App(width="medium")
13
+
14
+
15
+ @app.cell(hide_code=True)
16
+ def _(mo):
17
+ mo.md(
18
+ r"""
19
+ # Data Types
20
+
21
+ Author: [Deb Debnath](https://github.com/debajyotid2)
22
+
23
+ **Note**: The following tutorial has been adapted from the Polars [documentation](https://docs.pola.rs/user-guide/concepts/data-types-and-structures/).
24
+ """
25
+ )
26
+ return
27
+
28
+
29
+ @app.cell(hide_code=True)
30
+ def _(mo):
31
+ mo.md(
32
+ r"""
33
+ Polars supports a variety of data types that fall broadly under the following categories:
34
+
35
+ - Numeric data types: integers and floating point numbers.
36
+ - Nested data types: lists, structs, and arrays.
37
+ - Temporal: dates, datetimes, times, and time deltas.
38
+ - Miscellaneous: strings, binary data, Booleans, categoricals, enums, and objects.
39
+
40
+ All types support missing values, represented by `null`, which is distinct from the `NaN` used in floating-point data types. The numeric data types in Polars loosely follow the type system of the Rust language, since the core of Polars is implemented in Rust.
41
+
42
+ [Here](https://docs.pola.rs/api/python/stable/reference/datatypes.html) is a full list of all data types Polars supports.
43
+ """
44
+ )
45
+ return
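A minimal sketch of the `null` versus `NaN` distinction mentioned above (the sample series is assumed for illustration):

```python
import polars as pl

s = pl.Series("values", [1.0, None, float("nan")])
print(s.is_null())  # flags only the missing (null) entry
print(s.is_nan())   # flags the NaN entry; NaN is a float value, not a missing value
```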
46
+
47
+
48
+ @app.cell(hide_code=True)
49
+ def _(mo):
50
+ mo.md(
51
+ r"""
52
+ ## Series
53
+
54
+ A series is a 1-dimensional data structure that can hold only one data type.
55
+ """
56
+ )
57
+ return
58
+
59
+
60
+ @app.cell
61
+ def _(pl):
62
+ s = pl.Series("emojis", ["😀", "🤣", "🥶", "💀", "🤖"])
63
+ s
64
+ return (s,)
65
+
66
+
67
+ @app.cell(hide_code=True)
68
+ def _(mo):
69
+ mo.md(r"""Unless specified, Polars infers the datatype from the supplied values.""")
70
+ return
71
+
72
+
73
+ @app.cell
74
+ def _(pl):
75
+ s1 = pl.Series("friends", ["Евгений", "अभिषेक", "秀良", "Federico", "Bob"])
76
+ s2 = pl.Series("uints", [0x00, 0x01, 0x10, 0x11], dtype=pl.UInt8)
77
+ s1.dtype, s2.dtype
78
+ return s1, s2
79
+
80
+
81
+ @app.cell(hide_code=True)
82
+ def _(mo):
83
+ mo.md(
84
+ r"""
85
+ ## Dataframe
86
+
87
+ A dataframe is a 2-dimensional data structure that contains uniquely named series and can hold multiple data types. It is the structure you will most commonly manipulate with Polars' functionality.
88
+
89
+ The snippet below shows how to create a dataframe from a dictionary of lists:
90
+ """
91
+ )
92
+ return
93
+
94
+
95
+ @app.cell
96
+ def _(pl):
97
+ data = pl.DataFrame(
98
+ {
99
+ "Product ID": ["L51172", "M22586", "L51639", "L50250", "M20109"],
100
+ "Type": ["L", "M", "L", "L", "M"],
101
+ "Air temperature": [302.3, 300.8, 302.6, 300, 303.4], # (K)
102
+ "Machine Failure": [False, True, False, False, True]
103
+ }
104
+ )
105
+ data
106
+ return (data,)
107
+
108
+
109
+ @app.cell(hide_code=True)
110
+ def _(mo):
111
+ mo.md(
112
+ r"""
113
+ ### Inspecting a dataframe
114
+
115
+ Polars has various functions to explore the data in a dataframe. We will use the dataframe `data` defined above in our examples. As the cells execute, marimo also renders a view of the resulting dataframe alongside each example.
116
+
117
+ /// note
118
+ We can also use `marimo`'s built-in data-inspection elements such as [`mo.ui.dataframe`](https://docs.marimo.io/api/inputs/dataframe/#marimo.ui.dataframe) & [`mo.ui.data_explorer`](https://docs.marimo.io/api/inputs/data_explorer/). For more, check out our Polars tutorials at [`marimo learn`](https://marimo-team.github.io/learn/)!
+ ///
119
+ """
120
+ )
121
+ return
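A minimal sketch of the marimo elements mentioned in the note, assuming it runs in a marimo cell with the `data` dataframe defined above in scope:

```python
mo.vstack([
    mo.ui.dataframe(data),      # interactive table with point-and-click transformations
    mo.ui.data_explorer(data),  # drag-and-drop charting and exploration UI
])
```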
122
+
123
+
124
+ @app.cell(hide_code=True)
125
+ def _(mo):
126
+ mo.md(
127
+ """
128
+ #### Head
129
+
130
+ The function `head` shows the first rows of a dataframe. Unless specified, it shows the first 5 rows.
131
+ """
132
+ )
133
+ return
134
+
135
+
136
+ @app.cell
137
+ def _(data):
138
+ data.head(3)
139
+ return
140
+
141
+
142
+ @app.cell(hide_code=True)
143
+ def _(mo):
144
+ mo.md(
145
+ r"""
146
+ #### Glimpse
147
+
148
+ The function `glimpse` is an alternative to `head` for previewing a dataframe: each line of its output corresponds to a single column, showing its dtype and first few values. That way, it makes inspecting wider dataframes easier.
149
+ """
150
+ )
151
+ return
152
+
153
+
154
+ @app.cell
155
+ def _(data):
156
+ print(data.glimpse(return_as_string=True))
157
+ return
158
+
159
+
160
+ @app.cell(hide_code=True)
161
+ def _(mo):
162
+ mo.md(
163
+ r"""
164
+ #### Tail
165
+
166
+ The `tail` function, just like its name suggests, shows the last rows of a dataframe. Unless the number of rows is specified, it will show the last 5 rows.
167
+ """
168
+ )
169
+ return
170
+
171
+
172
+ @app.cell
173
+ def _(data):
174
+ data.tail(3)
175
+ return
176
+
177
+
178
+ @app.cell(hide_code=True)
179
+ def _(mo):
180
+ mo.md(
181
+ r"""
182
+ #### Sample
183
+
184
+ `sample` can be used to show a specified number of randomly selected rows from the dataframe. Unless the number of rows is specified, it will show a single row. `sample` does not preserve order of the rows.
185
+ """
186
+ )
187
+ return
188
+
189
+
190
+ @app.cell
191
+ def _(data):
192
+ import random
193
+
194
+ random.seed(42) # For reproducibility.
195
+
196
+ data.sample(3)
197
+ return (random,)
198
+
199
+
200
+ @app.cell(hide_code=True)
201
+ def _(mo):
202
+ mo.md(
203
+ r"""
204
+ #### Describe
205
+
206
+ The function `describe` describes the summary statistics for all columns of a dataframe.
207
+ """
208
+ )
209
+ return
210
+
211
+
212
+ @app.cell
213
+ def _(data):
214
+ data.describe()
215
+ return
216
+
217
+
218
+ @app.cell(hide_code=True)
219
+ def _(mo):
220
+ mo.md(
221
+ r"""
222
+ ## Schema
223
+
224
+ A schema is a mapping showing the datatype corresponding to every column of a dataframe. The schema of a dataframe can be viewed using the attribute `schema`.
225
+ """
226
+ )
227
+ return
228
+
229
+
230
+ @app.cell
231
+ def _(data):
232
+ data.schema
233
+ return
234
+
235
+
236
+ @app.cell(hide_code=True)
237
+ def _(mo):
238
+ mo.md(r"""Since a schema is a mapping, it can be specified in the form of a Python dictionary. Then this dictionary can be used to specify the schema of a dataframe on definition. If not specified or the entry is `None`, Polars infers the datatype from the contents of the column. Note that if the schema is not specified, it will be inferred automatically by default.""")
239
+ return
240
+
241
+
242
+ @app.cell
243
+ def _(pl):
244
+ pl.DataFrame(
245
+ {
246
+ "Product ID": ["L51172", "M22586", "L51639", "L50250", "M20109"],
247
+ "Type": ["L", "M", "L", "L", "M"],
248
+ "Air temperature": [302.3, 300.8, 302.6, 300, 303.4], # (K)
249
+ "Machine Failure": [False, True, False, False, True]
250
+ },
251
+ schema={"Product ID": pl.String, "Type": pl.String, "Air temperature": None, "Machine Failure": None},
252
+ )
253
+ return
254
+
255
+
256
+ @app.cell(hide_code=True)
257
+ def _(mo):
258
+ mo.md(r"""Sometimes the automatically inferred schema is enough for some columns, but we might wish to override the inference of only some columns. We can specify the schema for those columns using `schema_overrides`.""")
259
+ return
260
+
261
+
262
+ @app.cell
263
+ def _(pl):
264
+ pl.DataFrame(
265
+ {
266
+ "Product ID": ["L51172", "M22586", "L51639", "L50250", "M20109"],
267
+ "Type": ["L", "M", "L", "L", "M"],
268
+ "Air temperature": [302.3, 300.8, 302.6, 300, 303.4], # (K)
269
+ "Machine Failure": [False, True, False, False, True]
270
+ },
271
+ schema_overrides={"Air temperature": pl.Float32},
272
+ )
273
+ return
274
+
275
+
276
+ @app.cell(hide_code=True)
277
+ def _(mo):
278
+ mo.md(
279
+ r"""
280
+ ### References
281
+
282
+ 1. Polars documentation ([link](https://docs.pola.rs/api/python/stable/reference/datatypes.html))
283
+ """
284
+ )
285
+ return
286
+
287
+
288
+ @app.cell(hide_code=True)
289
+ def _():
290
+ import marimo as mo
291
+ import polars as pl
292
+ return mo, pl
293
+
294
+
295
+ if __name__ == "__main__":
296
+ app.run()
probability/20_naive_bayes.py ADDED
@@ -0,0 +1,833 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "matplotlib==3.10.1",
6
+ # "scipy==1.15.2",
7
+ # "numpy==2.2.4",
8
+ # "polars==1.26.0",
9
+ # "plotly==5.18.0",
10
+ # "scikit-learn==1.6.1",
11
+ # ]
12
+ # ///
13
+
14
+ import marimo
15
+
16
+ __generated_with = "0.12.0"
17
+ app = marimo.App(width="medium", app_title="Naive Bayes Classification")
18
+
19
+
20
+ @app.cell(hide_code=True)
21
+ def _(mo):
22
+ mo.md(
23
+ r"""
24
+ # Naive Bayes Classification
25
+
26
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part5/naive_bayes/), by Stanford professor Chris Piech._
27
+
28
+ Naive Bayes is one of those classic machine learning algorithms that seems almost too simple to work, yet it's surprisingly effective for many classification tasks. I've always found it fascinating how this algorithm applies Bayes' theorem with a strong (but knowingly incorrect) "naive" assumption that all features are independent of each other.
29
+
30
+ In this notebook, we'll dive into why this supposedly "wrong" assumption still leads to good results. We'll walk through the training process, learn how to make predictions, and see some interactive visualizations that helped me understand the concept better when I was first learning it. We'll also explore why Naive Bayes excels particularly in text classification problems like spam filtering.
31
+
32
+ If you're new to Naive Bayes, I highly recommend checking out [this excellent explanation by Mahesh Huddar](https://youtu.be/XzSlEA4ck2I?si=AASeh_KP68BAbzy5), which provides a step-by-step walkthrough with a helpful example (which we take a dive into, down below).
33
+ """
34
+ )
35
+ return
36
+
37
+
38
+ @app.cell(hide_code=True)
39
+ def _(mo):
40
+ mo.md(
41
+ r"""
42
+ ## Why "Naive"?
43
+
44
+ So why is it called "naive"? It's because the algorithm makes an assumption — it assumes all features are completely independent of each other when given the class label.
45
+
46
+ The math way of saying this is:
47
+
48
+ $$P(X_1, X_2, \ldots, X_n | Y) = P(X_1 | Y) \times P(X_2 | Y) \times \ldots \times P(X_n | Y) = \prod_{i=1}^{n} P(X_i | Y)$$
49
+
50
+ This independence assumption is almost always wrong in real data. Think about text classification — if you see the word "cloudy" in a weather report, you're much more likely to also see "rain" than you would be to see "sunshine". These words clearly depend on each other! Or in medical diagnosis, symptoms often occur together as part of syndromes.
51
+
52
+ But here's the cool part — even though we know this assumption is _technically_ wrong, the algorithm still works remarkably well in practice. By making this simplifying assumption, we:
53
+
54
+ - Make the math way easier to compute
55
+ - Need way less training data to get decent results
56
+ - Can handle thousands of features without blowing up computationally
57
+ """
58
+ )
59
+ return
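A minimal sketch of what this factorization looks like in code, in the spirit of a spam filter (the per-word probabilities are made up for illustration):

```python
# Naive factorization: P(x1, x2, x3 | Y) ~ P(x1|Y) * P(x2|Y) * P(x3|Y)
p_word_given_spam = {"free": 0.30, "winner": 0.20, "meeting": 0.05}  # assumed values

def naive_likelihood(words_present: set[str], per_word_probs: dict[str, float]) -> float:
    """Multiply per-feature probabilities as if the words were independent."""
    prob = 1.0
    for word, p in per_word_probs.items():
        prob *= p if word in words_present else (1 - p)
    return prob

print(naive_likelihood({"free", "winner"}, p_word_given_spam))  # 0.30 * 0.20 * 0.95 = 0.057
```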
60
+
61
+
62
+ @app.cell(hide_code=True)
63
+ def _(mo):
64
+ mo.md(
65
+ r"""
66
+ ## The Math Behind Naive Bayes
67
+
68
+ At its core, Naive Bayes is just an application of Bayes' theorem from our earlier probability notebooks. Let's break it down:
69
+
70
+ We have some features $\mathbf{X} = [X_1, X_2, \ldots, X_m]$ (like words in an email or symptoms of a disease) and we want to predict a class label $Y$ (like "spam/not spam" or "has disease/doesn't have disease").
71
+
72
+ What we're really trying to find is:
73
+
74
+ $$P(Y|\mathbf{X})$$
75
+
76
+ In other words, "what's the probability of a certain class given the features we observed?" Once we have these probabilities, we simply pick the class with the highest probability:
77
+
78
+ $$\hat{y} = \underset{y}{\operatorname{argmax}} \text{ } P(Y=y|\mathbf{X}=\mathbf{x})$$
79
+
80
+ Applying Bayes' theorem (from our earlier probability work), we get:
81
+
82
+ $$P(Y=y|\mathbf{X}=\mathbf{x}) = \frac{P(Y=y) \times P(\mathbf{X}=\mathbf{x}|Y=y)}{P(\mathbf{X}=\mathbf{x})}$$
83
+
84
+ Since we're comparing different possible classes for the same input features, the denominator $P(\mathbf{X}=\mathbf{x})$ is the same for all classes. So we can drop it and just compare:
85
+
86
+ $$\hat{y} = \underset{y}{\operatorname{argmax}} \text{ } P(Y=y) \times P(\mathbf{X}=\mathbf{x}|Y=y)$$
87
+
88
+ Here's where the "naive" part comes in. Calculating $P(\mathbf{X}=\mathbf{x}|Y=y)$ directly would be a computational nightmare - we'd need counts for every possible combination of feature values. Instead, we make that simplifying "naive" assumption that features are independent of each other:
89
+
90
+ $$P(\mathbf{X}=\mathbf{x}|Y=y) = \prod_{i=1}^{m} P(X_i=x_i|Y=y)$$
91
+
92
+ Which gives us our final formula:
93
+
94
+ $$\hat{y} = \underset{y}{\operatorname{argmax}} \text{ } P(Y=y) \times \prod_{i=1}^{m} P(X_i=x_i|Y=y)$$
95
+
96
+ In actual implementations, we usually use logarithms to avoid the numerical problems that come with multiplying many small probabilities (they can _underflow_ to zero):
97
+
98
+ $$\hat{y} = \underset{y}{\operatorname{argmax}} \text{ } \log P(Y=y) + \sum_{i=1}^{m} \log P(X_i=x_i|Y=y)$$
99
+
100
+ That's it! The really cool thing is that despite this massive simplification, the algorithm often gives surprisingly good results.
101
+ """
102
+ )
103
+ return
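A minimal sketch of why the log form matters in practice; the probability values are made up, and only the underflow behaviour is the point:

```python
import math

probs = [0.01] * 200  # 200 features, each with a small conditional probability

product = 1.0
for p in probs:
    product *= p
print(product)  # 0.0 -- the product underflows to zero in floating point

log_score = sum(math.log(p) for p in probs)
print(log_score)  # about -921.0 -- the same information, still usable for argmax
```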
104
+
105
+
106
+ @app.cell(hide_code=True)
107
+ def _(mo):
108
+ mo.md(r"""## Example Problem""")
109
+ return
110
+
111
+
112
+ @app.cell(hide_code=True)
113
+ def _(mo):
114
+ mo.md(r"""Let's apply Naive Bayes principles to this data (Tennis Training Dataset):""")
115
+ return
116
+
117
+
118
+ @app.cell(hide_code=True)
119
+ def _(mo):
120
+ mo.md(
121
+ r"""
122
+ ## A Simple Example: Play Tennis
123
+
124
+ Let's understand Naive Bayes with a classic example: predicting whether someone will play tennis based on weather conditions. This is the same example used in Mahesh Huddar's excellent video.
125
+
126
+ Our dataset has these features:
127
+ - **Outlook**: Sunny, Overcast, Rainy
128
+ - **Temperature**: Hot, Mild, Cool
129
+ - **Humidity**: High, Normal
130
+ - **Wind**: Strong, Weak
131
+
132
+ And the target variable:
133
+ - **Play Tennis**: Yes, No
134
+
135
+ ### Example Dataset
136
+ """
137
+ )
138
+
139
+ # Create a dataset matching the image (in dict format for proper table rendering)
140
+ example_data = [
141
+ {"Day": "D1", "Outlook": "Sunny", "Temperature": "Hot", "Humidity": "High", "Wind": "Weak", "Play Tennis": "No"},
142
+ {"Day": "D2", "Outlook": "Sunny", "Temperature": "Hot", "Humidity": "High", "Wind": "Strong", "Play Tennis": "No"},
143
+ {"Day": "D3", "Outlook": "Overcast", "Temperature": "Hot", "Humidity": "High", "Wind": "Weak", "Play Tennis": "Yes"},
144
+ {"Day": "D4", "Outlook": "Rain", "Temperature": "Mild", "Humidity": "High", "Wind": "Weak", "Play Tennis": "Yes"},
145
+ {"Day": "D5", "Outlook": "Rain", "Temperature": "Cool", "Humidity": "Normal", "Wind": "Weak", "Play Tennis": "Yes"},
146
+ {"Day": "D6", "Outlook": "Rain", "Temperature": "Cool", "Humidity": "Normal", "Wind": "Strong", "Play Tennis": "No"},
147
+ {"Day": "D7", "Outlook": "Overcast", "Temperature": "Cool", "Humidity": "Normal", "Wind": "Strong", "Play Tennis": "Yes"},
148
+ {"Day": "D8", "Outlook": "Sunny", "Temperature": "Mild", "Humidity": "High", "Wind": "Weak", "Play Tennis": "No"},
149
+ {"Day": "D9", "Outlook": "Sunny", "Temperature": "Cool", "Humidity": "Normal", "Wind": "Weak", "Play Tennis": "Yes"},
150
+ {"Day": "D10", "Outlook": "Rain", "Temperature": "Mild", "Humidity": "Normal", "Wind": "Weak", "Play Tennis": "Yes"},
151
+ {"Day": "D11", "Outlook": "Sunny", "Temperature": "Mild", "Humidity": "Normal", "Wind": "Strong", "Play Tennis": "Yes"},
152
+ {"Day": "D12", "Outlook": "Overcast", "Temperature": "Mild", "Humidity": "High", "Wind": "Strong", "Play Tennis": "Yes"},
153
+ {"Day": "D13", "Outlook": "Overcast", "Temperature": "Hot", "Humidity": "Normal", "Wind": "Weak", "Play Tennis": "Yes"},
154
+ {"Day": "D14", "Outlook": "Rain", "Temperature": "Mild", "Humidity": "High", "Wind": "Strong", "Play Tennis": "No"}
155
+ ]
156
+
157
+ # Display the tennis dataset using a table
158
+ example_table = mo.ui.table(
159
+ data=example_data,
160
+ selection=None
161
+ )
162
+
163
+ mo.vstack([
164
+ mo.md("#### Tennis Training Dataset"),
165
+ example_table
166
+ ])
167
+ return example_data, example_table
168
+
169
+
170
+ @app.cell(hide_code=True)
171
+ def _(mo):
172
+ mo.md(
173
+ r"""
174
+ Let's predict whether someone will play tennis given these weather conditions:
175
+
176
+ - Outlook: Sunny
177
+ - Temperature: Cool
178
+ - Humidity: High
179
+ - Wind: Strong
180
+
181
+ Let's walk through the calculations step by step:
182
+
183
+ #### Step 1: Calculate Prior Probabilities
184
+
185
+ First, we calculate $P(Y=\text{Yes})$ and $P(Y=\text{No})$:
186
+
187
+ - $P(Y=\text{Yes}) = \frac{9}{14} = 0.64$
188
+ - $P(Y=\text{No}) = \frac{5}{14} = 0.36$
189
+
190
+ #### Step 2: Calculate Conditional Probabilities
191
+
192
+ Next, we calculate the conditional probabilities for each feature value given each class:
193
+ """
194
+ )
195
+ return
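For readers who prefer code to hand counting, here is a minimal sketch of how Steps 1 and 2 could be computed from the `example_data` list defined above (the helper function and its name are illustrative, not part of the original notebook):

```python
from collections import Counter, defaultdict

def priors_and_conditionals(rows, target="Play Tennis"):
    """Estimate P(Y) and P(X_i = x | Y) by counting, as in Steps 1 and 2."""
    class_counts = Counter(row[target] for row in rows)
    priors = {c: n / len(rows) for c, n in class_counts.items()}

    conditionals = defaultdict(float)  # keyed by (feature, value, class)
    features = [k for k in rows[0] if k not in (target, "Day")]
    for feature in features:
        for c, n_c in class_counts.items():
            value_counts = Counter(r[feature] for r in rows if r[target] == c)
            for value, n in value_counts.items():
                conditionals[(feature, value, c)] = n / n_c
    return priors, conditionals

# priors, cond = priors_and_conditionals(example_data)
# priors -> {"No": 5/14, "Yes": 9/14}
# cond[("Outlook", "Sunny", "Yes")] -> 2/9
```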
196
+
197
+
198
+ @app.cell(hide_code=True)
199
+ def _(humidity_data, mo, outlook_data, summary_table, temp_data, wind_data):
200
+ # Display tables with appropriate styling
201
+ mo.vstack([
202
+ mo.md("#### Class Distribution"),
203
+ summary_table,
204
+ mo.md("#### Conditional Probabilities"),
205
+ mo.hstack([
206
+ mo.vstack([
207
+ mo.md("**Outlook**"),
208
+ mo.ui.table(
209
+ data=outlook_data,
210
+ selection=None
211
+ )
212
+ ]),
213
+ mo.vstack([
214
+ mo.md("**Temperature**"),
215
+ mo.ui.table(
216
+ data=temp_data,
217
+ selection=None
218
+ )
219
+ ])
220
+ ]),
221
+ mo.hstack([
222
+ mo.vstack([
223
+ mo.md("**Humidity**"),
224
+ mo.ui.table(
225
+ data=humidity_data,
226
+ selection=None
227
+ )
228
+ ]),
229
+ mo.vstack([
230
+ mo.md("**Wind**"),
231
+ mo.ui.table(
232
+ data=wind_data,
233
+ selection=None
234
+ )
235
+ ])
236
+ ])
237
+ ])
238
+ return
239
+
240
+
241
+ @app.cell
242
+ def _():
243
+ # DIY
244
+ return
245
+
246
+
247
+ @app.cell(hide_code=True)
248
+ def _(mo, solution_accordion):
249
+ # Display the accordion
250
+ mo.accordion(solution_accordion)
251
+ return
252
+
253
+
254
+ @app.cell(hide_code=True)
255
+ def _(mo):
256
+ mo.md(
257
+ r"""
258
+ ### Try a Different Example
259
+
260
+ What if the conditions were different? Let's say:
261
+
262
+ - Outlook: Overcast
263
+ - Temperature: Hot
264
+ - Humidity: Normal
265
+ - Wind: Weak
266
+
267
+ Try working through this example on your own. If you get stuck, you can use the tables above and apply the same method we used in the solution.
268
+ """
269
+ )
270
+ return
271
+
272
+
273
+ @app.cell
274
+ def _():
275
+ # DIY
276
+ return
277
+
278
+
279
+ @app.cell(hide_code=True)
280
+ def _(mo):
281
+ mo.md(
282
+ r"""
283
+ ## Interactive Naive Bayes
284
+
285
+ Let's explore Naive Bayes with an interactive visualization. This will help build intuition about how the algorithm makes predictions and how the naive independence assumption affects results.
286
+ """
287
+ )
288
+ return
289
+
290
+
291
+ @app.cell(hide_code=True)
292
+ def gaussian_viz(
293
+ Ellipse,
294
+ GaussianNB,
295
+ ListedColormap,
296
+ class_sep_slider,
297
+ controls,
298
+ make_classification,
299
+ mo,
300
+ n_samples_slider,
301
+ noise_slider,
302
+ np,
303
+ pl,
304
+ plt,
305
+ regenerate_button,
306
+ train_test_split,
307
+ ):
308
+ # get values from sliders
309
+ class_sep = class_sep_slider.value
310
+ noise_val = noise_slider.value
311
+ n_samples = int(n_samples_slider.value)
312
+
313
+ # check if regenerate button was clicked
314
+ regenerate_state = regenerate_button.value
315
+
316
+ # make a dataset with current settings
317
+ X, y = make_classification(
318
+ n_samples=n_samples,
319
+ n_features=2,
320
+ n_redundant=0,
321
+ n_informative=2,
322
+ n_clusters_per_class=1,
323
+ class_sep=class_sep * (1 - noise_val), # use noise to reduce separation
324
+ random_state=42 if not regenerate_state else np.random.randint(1000)
325
+ )
326
+
327
+ # put data in a dataframe
328
+ viz_df = pl.DataFrame({
329
+ "Feature1": X[:, 0],
330
+ "Feature2": X[:, 1],
331
+ "Class": y
332
+ })
333
+
334
+ # split into train/test
335
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
336
+
337
+ # create naive bayes classifier
338
+ gnb = GaussianNB()
339
+ gnb.fit(X_train, y_train)
340
+
341
+ # setup grid for boundary visualization
342
+ h = 0.1 # step size
343
+ x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
344
+ y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
345
+ xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
346
+
347
+ # predict on grid points
348
+ grid_points = np.c_[xx.ravel(), yy.ravel()]
349
+ Z = gnb.predict(grid_points).reshape(xx.shape)
350
+
351
+ # calculate class stats
352
+ class0_mean = np.mean(X_train[y_train == 0], axis=0)
353
+ class1_mean = np.mean(X_train[y_train == 1], axis=0)
354
+ class0_var = np.var(X_train[y_train == 0], axis=0)
355
+ class1_var = np.var(X_train[y_train == 1], axis=0)
356
+
357
+ # format for display
358
+ class_stats = [
359
+ {"Class": "Class 0", "Feature1_Mean": f"{class0_mean[0]:.4f}", "Feature1_Variance": f"{class0_var[0]:.4f}",
360
+ "Feature2_Mean": f"{class0_mean[1]:.4f}", "Feature2_Variance": f"{class0_var[1]:.4f}"},
361
+ {"Class": "Class 1", "Feature1_Mean": f"{class1_mean[0]:.4f}", "Feature1_Variance": f"{class1_var[0]:.4f}",
362
+ "Feature2_Mean": f"{class1_mean[1]:.4f}", "Feature2_Variance": f"{class1_var[1]:.4f}"}
363
+ ]
364
+
365
+ # setup plot with two panels
366
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
367
+
368
+ # colors for our plots
369
+ cmap_light = ListedColormap(['#FFAAAA', '#AAAAFF']) # bg colors
370
+ cmap_bold = ListedColormap(['#FF0000', '#0000FF']) # point colors
371
+
372
+ # left: decision boundary
373
+ ax1.contourf(xx, yy, Z, alpha=0.3, cmap=cmap_light)
374
+ scatter1 = ax1.scatter(X_train[:, 0], X_train[:, 1], c=y_train,
375
+ cmap=cmap_bold, edgecolor='k', s=50, alpha=0.8)
376
+ scatter2 = ax1.scatter(X_test[:, 0], X_test[:, 1], c=y_test,
377
+ cmap=cmap_bold, edgecolor='k', s=25, alpha=0.5)
378
+
379
+ ax1.set_xlabel('Feature 1')
380
+ ax1.set_ylabel('Feature 2')
381
+ ax1.set_title('Gaussian Naive Bayes Decision Boundary')
382
+ ax1.legend([scatter1.legend_elements()[0][0], scatter2.legend_elements()[0][0]],
383
+ ['Training Data', 'Test Data'], loc='upper right')
384
+
385
+ # right: distribution visualization
386
+ class0_data = viz_df.filter(pl.col("Class") == 0)
387
+ class1_data = viz_df.filter(pl.col("Class") == 1)
388
+
389
+ ax2.scatter(class0_data["Feature1"], class0_data["Feature2"],
390
+ color='red', edgecolor='k', s=50, alpha=0.8, label='Class 0')
391
+ ax2.scatter(class1_data["Feature1"], class1_data["Feature2"],
392
+ color='blue', edgecolor='k', s=50, alpha=0.8, label='Class 1')
393
+
394
+ # draw ellipses function
395
+ def plot_ellipse(ax, mean, cov, color):
396
+ vals, vecs = np.linalg.eigh(cov)
397
+ order = vals.argsort()[::-1]
398
+ vals = vals[order]
399
+ vecs = vecs[:, order]
400
+
401
+ theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
402
+ width, height = 2 * np.sqrt(5.991 * vals)
403
+
404
+ ellip = Ellipse(xy=mean, width=width, height=height, angle=theta,
405
+ edgecolor=color, fc='None', lw=2, alpha=0.7)
406
+ ax.add_patch(ellip)
407
+
408
+ # add ellipses for each class accordingly
409
+ class0_cov = np.diag(np.var(X_train[y_train == 0], axis=0))
410
+ class1_cov = np.diag(np.var(X_train[y_train == 1], axis=0))
411
+
412
+ plot_ellipse(ax2, class0_mean, class0_cov, 'red')
413
+ plot_ellipse(ax2, class1_mean, class1_cov, 'blue')
414
+
415
+ ax2.set_xlabel('Feature 1')
416
+ ax2.set_ylabel('Feature 2')
417
+ ax2.set_title('Class-Conditional Distributions (Gaussian)')
418
+ ax2.legend(loc='upper right')
419
+
420
+ plt.tight_layout()
421
+ plt.gca()
422
+
423
+ # show interactive plot
424
+ mpl_fig = mo.mpl.interactive(fig)
425
+
426
+ # show parameters info
427
+ mo.md(
428
+ r"""
429
+ ### gaussian parameters by class
430
+
431
+ each feature follows a normal distribution per class. here are the parameters:
432
+ """
433
+ )
434
+
435
+ # make stats table
436
+ stats_table = mo.ui.table(
437
+ data=class_stats,
438
+ selection="single"
439
+ )
440
+
441
+ mo.md(
442
+ r"""
443
+ ### how it works
444
+
445
+ 1. calculate mean & variance for each feature per class
446
+ 2. use gaussian pdf to get probabilities for new points
447
+ 3. apply bayes' theorem to pick most likely class
448
+
449
+ ellipses show the distributions, decision boundary is where probabilities equal
450
+ """
451
+ )
452
+
453
+ # stack everything together
454
+ mo.vstack([
455
+ controls.center(),
456
+ mpl_fig,
457
+ stats_table
458
+ ])
459
+ return (
460
+ X,
461
+ X_test,
462
+ X_train,
463
+ Z,
464
+ ax1,
465
+ ax2,
466
+ class0_cov,
467
+ class0_data,
468
+ class0_mean,
469
+ class0_var,
470
+ class1_cov,
471
+ class1_data,
472
+ class1_mean,
473
+ class1_var,
474
+ class_sep,
475
+ class_stats,
476
+ cmap_bold,
477
+ cmap_light,
478
+ fig,
479
+ gnb,
480
+ grid_points,
481
+ h,
482
+ mpl_fig,
483
+ n_samples,
484
+ noise_val,
485
+ plot_ellipse,
486
+ regenerate_state,
487
+ scatter1,
488
+ scatter2,
489
+ stats_table,
490
+ viz_df,
491
+ x_max,
492
+ x_min,
493
+ xx,
494
+ y,
495
+ y_max,
496
+ y_min,
497
+ y_test,
498
+ y_train,
499
+ yy,
500
+ )
501
+
502
+
503
+ @app.cell(hide_code=True)
504
+ def _(mo):
505
+ mo.md(
506
+ r"""
507
+ ### what's going on in this demo?
508
+
509
+ Playing with the sliders changes how our data looks and how the classifier behaves. Class separation controls how far apart the two classes are — higher values make them easier to tell apart. The noise slider adds randomness by reducing that separation, making boundaries fuzzier and classification harder. More samples just gives you more data points to work with.
510
+
511
+ The left graph shows the decision boundary — that curved line where the classifier switches from predicting one class to another. Red and blue regions show where naive bayes would classify new points. The right graph shows the actual distribution of both classes, with those ellipses representing the gaussian distributions naive bayes is using internally.
512
+
513
+ Try cranking up the noise and watch how the boundary gets messier. Increase the separation and see how confident the classifier becomes. This is basically what's happening inside Naive Bayes: it looks at each feature's distribution per class and makes the best guess based on probabilities. The table below shows the actual parameters (means and variances) the model calculates.
514
+ """
515
+ )
516
+ return
517
+
518
+
519
+ @app.cell(hide_code=True)
520
+ def _(mo):
521
+ mo.md(
522
+ r"""
523
+ ## Types of Naive Bayes Classifiers
524
+
525
+ ### Multinomial Naive Bayes
526
+ Ideal for text classification where features represent word counts or frequencies.
527
+
528
+ Mathematical form:
529
+
530
+ \[P(x_i|y) = \frac{\text{count}(x_i, y) + \alpha}{\sum_{i=1}^{|V|} \text{count}(x_i, y) + \alpha|V|}\]
531
+
532
+ where:
533
+
534
+ - \(\alpha\) is the smoothing parameter
535
+ - \(|V|\) is the size of the vocabulary
536
+ - \(\text{count}(x_i, y)\) is the count of feature \(i\) in class \(y\)
537
+
538
+ ### Bernoulli Naive Bayes
539
+ Best for binary features (0/1) — either a word appears or it doesn't.
540
+
541
+ Mathematical form:
542
+
543
+ \[P(x_i|y) = p_{iy}^{x_i}(1-p_{iy})^{(1-x_i)}\]
544
+
545
+ where:
546
+
547
+ - \(p_{iy}\) is the probability of feature \(i\) occurring in class \(y\)
548
+ - \(x_i\) is 1 if the feature is present, 0 otherwise
549
+
550
+ ### Gaussian Naive Bayes
551
+ Designed for continuous features, assuming they follow a normal distribution.
552
+
553
+ Mathematical form:
554
+
555
+ \[P(x_i|y) = \frac{1}{\sqrt{2\pi\sigma_y^2}} \exp\left(-\frac{(x_i - \mu_y)^2}{2\sigma_y^2}\right)\]
556
+
557
+ where:
558
+
559
+ - \(\mu_y\) is the mean of feature values for class \(y\)
560
+ - \(\sigma_y^2\) is the variance of feature values for class \(y\)
561
+
562
+ ### Complement Naive Bayes
563
+ Particularly effective for imbalanced datasets.
564
+
565
+ Mathematical form:
566
+
567
+ \[P(x_i|y) = \frac{\text{count}(x_i, \bar{y}) + \alpha}{\sum_{i=1}^{|V|} \text{count}(x_i, \bar{y}) + \alpha|V|}\]
568
+
569
+ where:
570
+
571
+ - \(\bar{y}\) represents all classes except \(y\)
572
+ - Other parameters are similar to Multinomial Naive Bayes
573
+ """
574
+ )
575
+ return
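A minimal scikit-learn sketch of three of these variants side by side (the tiny word-count matrix and labels are invented for illustration):

```python
import numpy as np
from sklearn.naive_bayes import BernoulliNB, GaussianNB, MultinomialNB

# Toy word-count matrix: 4 documents x 3 vocabulary words.
X_counts = np.array([[2, 1, 0],
                     [3, 0, 0],
                     [0, 2, 3],
                     [0, 1, 4]])
y = np.array([0, 0, 1, 1])  # made-up class labels

print(MultinomialNB(alpha=1.0).fit(X_counts, y).predict([[1, 0, 2]]))          # counts
print(BernoulliNB().fit((X_counts > 0).astype(int), y).predict([[1, 0, 1]]))   # presence/absence
print(GaussianNB().fit(X_counts.astype(float), y).predict([[1.0, 0.5, 2.0]]))  # continuous values
```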
576
+
577
+
578
+ @app.cell(hide_code=True)
579
+ def _(mo):
580
+ mo.md(
581
+ r"""
582
+ ## 🤔 Test Your Understanding
583
+
584
+ Test your understanding of Naive Bayes with these statements:
585
+
586
+ /// details | Multiplying small probabilities in Naive Bayes can lead to numerical underflow.
587
+ ✅ **Correct!** Multiplying many small probabilities can indeed lead to numerical underflow.
588
+
589
+ That's why in practice, we often use log probabilities and add them instead of multiplying the original probabilities. This prevents numerical underflow and improves computational stability.
590
+ ///
591
+
592
+ /// details | Laplace smoothing is unnecessary if your training data covers all possible feature values.
593
+ ❌ **Incorrect.** Laplace smoothing is still beneficial even with complete feature coverage.
594
+
595
+ While Laplace smoothing is crucial for handling unseen feature values, it also helps with small sample sizes by preventing overfitting to the training data. Even with complete feature coverage, some combinations might have very few examples, leading to unreliable probability estimates.
596
+ ///
597
+
598
+ /// details | Naive Bayes performs poorly on high-dimensional data compared to other classifiers.
599
+ ❌ **Incorrect.** Naive Bayes actually excels with high-dimensional data.
600
+
601
+ Due to its simplicity and the independence assumption, Naive Bayes scales very well to high-dimensional data. It's particularly effective for text classification where each word is a dimension and there can be thousands of dimensions. Other classifiers might overfit in such high-dimensional spaces.
602
+ ///
603
+
604
+ /// details | For text classification, Multinomial Naive Bayes typically outperforms Gaussian Naive Bayes.
605
+ ✅ **Correct!** Multinomial NB is better suited for text classification than Gaussian NB.
606
+
607
+ Text data typically involves discrete counts (word frequencies) which align better with a multinomial distribution. Gaussian Naive Bayes assumes features follow a normal distribution, which doesn't match the distribution of word frequencies in text documents.
608
+ ///
609
+ """
610
+ )
611
+ return
612
+
613
+
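+ @app.cell
+ def _(np):
+     # Quick numeric illustration of the underflow point above: multiplying many small
+     # probabilities collapses to 0.0 in floating point, while summing their logs stays finite.
+     _small_probs = np.full(1000, 1e-5)
+ 
+     print("direct product:", np.prod(_small_probs))  # underflows to 0.0
+     print("sum of logs:", np.sum(np.log(_small_probs)))  # roughly -11513, still usable
+     return
+ 
+ 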
614
+ @app.cell(hide_code=True)
615
+ def _(mo):
616
+ mo.md(
617
+ r"""
618
+ ## Summary
619
+
620
+ Throughout this notebook, we've explored Naive Bayes classification. What makes this algorithm particularly interesting is its elegant simplicity combined with surprising effectiveness. Despite making what seems like an overly simplistic assumption — that features are independent given the class — it consistently delivers reasonable performance across a wide range of applications.
621
+
622
+ The algorithm's power lies in its probabilistic foundation, built upon Bayes' theorem. During training, it simply learns probability distributions: the likelihood of seeing each class (prior probabilities) and the probability of feature values within each class (conditional probabilities). When making predictions, it combines these probabilities using the naive independence assumption, which dramatically simplifies the computation while still maintaining remarkable predictive power.
623
+
624
+ We've seen how different variants of Naive Bayes adapt to various types of data. Multinomial Naive Bayes excels at text classification by modeling word frequencies, Bernoulli Naive Bayes handles binary features elegantly, and Gaussian Naive Bayes tackles continuous data through normal distributions. Each variant maintains the core simplicity of the algorithm while adapting its probability calculations to match the data's characteristics.
625
+
626
+ Perhaps most importantly, we've learned that sometimes the most straightforward approaches can be the most practical. Naive Bayes demonstrates that a simple model, well-understood and properly applied, can often outperform more complex alternatives, especially in domains like text classification or when working with limited computational resources or training data.
627
+ """
628
+ )
629
+ return
630
+
631
+
632
+ @app.cell(hide_code=True)
633
+ def _(mo):
634
+ mo.md(r"""appendix (helper code)""")
635
+ return
636
+
637
+
638
+ @app.cell
639
+ def _():
640
+ import marimo as mo
641
+ return (mo,)
642
+
643
+
644
+ @app.cell
645
+ def init_imports():
646
+ # imports for our notebook
647
+ import numpy as np
648
+ import matplotlib.pyplot as plt
649
+ import polars as pl
650
+ from scipy import stats
651
+ from sklearn.naive_bayes import GaussianNB
652
+ from sklearn.datasets import make_classification
653
+ from sklearn.model_selection import train_test_split
654
+ from matplotlib.colors import ListedColormap
655
+ from matplotlib.patches import Ellipse
656
+
657
+ # for consistent results
658
+ np.random.seed(42)
659
+
660
+ # nicer plots
661
+ plt.style.use('seaborn-v0_8-darkgrid')
662
+ return (
663
+ Ellipse,
664
+ GaussianNB,
665
+ ListedColormap,
666
+ make_classification,
667
+ np,
668
+ pl,
669
+ plt,
670
+ stats,
671
+ train_test_split,
672
+ )
673
+
674
+
675
+ @app.cell(hide_code=True)
676
+ def _(example_data, mo):
677
+ # occurrences count in example data
678
+ yes_count = sum(1 for row in example_data if row["Play Tennis"] == "Yes")
679
+ no_count = sum(1 for row in example_data if row["Play Tennis"] == "No")
680
+ total = len(example_data)
681
+
682
+ # summary table with dict format
683
+ summary_data = [
684
+ {"Class": "Yes", "Count": f"{yes_count}", "Probability": f"{yes_count/total:.2f}"},
685
+ {"Class": "No", "Count": f"{no_count}", "Probability": f"{no_count/total:.2f}"},
686
+ {"Class": "Total", "Count": f"{total}", "Probability": "1.00"}
687
+ ]
688
+
689
+ summary_table = mo.ui.table(
690
+ data=summary_data,
691
+ selection=None
692
+ )
693
+
694
+ # tables for conditional probabilities matching the image (in dict format)
695
+ outlook_data = [
696
+ {"Outlook": "Sunny", "Y": "2/9", "N": "3/5"},
697
+ {"Outlook": "Overcast", "Y": "4/9", "N": "0"},
698
+ {"Outlook": "Rain", "Y": "3/9", "N": "2/5"}
699
+ ]
700
+
701
+ temp_data = [
702
+ {"Temperature": "Hot", "Y": "2/9", "N": "2/5"},
703
+ {"Temperature": "Mild", "Y": "4/9", "N": "2/5"},
704
+ {"Temperature": "Cool", "Y": "3/9", "N": "1/5"}
705
+ ]
706
+
707
+ humidity_data = [
708
+ {"Humidity": "High", "Y": "3/9", "N": "4/5"},
709
+ {"Humidity": "Normal", "Y": "6/9", "N": "1/5"}
710
+ ]
711
+
712
+ wind_data = [
713
+ {"Wind": "Strong", "Y": "3/9", "N": "3/5"},
714
+ {"Wind": "Weak", "Y": "6/9", "N": "2/5"}
715
+ ]
716
+ return (
717
+ humidity_data,
718
+ no_count,
719
+ outlook_data,
720
+ summary_data,
721
+ summary_table,
722
+ temp_data,
723
+ total,
724
+ wind_data,
725
+ yes_count,
726
+ )
727
+
728
+
729
+ @app.cell(hide_code=True)
730
+ def _(mo):
731
+ # accordion with solution (step-by-step)
732
+ solution_accordion = {
733
+ "step-by-step solution (click to expand)": mo.md(r"""
734
+ #### step 1: gather probabilities
735
+
736
+ from our tables:
737
+
738
+ **prior probabilities:**
739
+
740
+ - $P(Yes) = 9/14 = 0.64$
741
+ - $P(No) = 5/14 = 0.36$
742
+
743
+ **conditional probabilities:**
744
+
745
+ - $P(Outlook=Sunny|Yes) = 2/9$
746
+ - $P(Outlook=Sunny|No) = 3/5$
747
+ - $P(Temperature=Cool|Yes) = 3/9$
748
+ - $P(Temperature=Cool|No) = 1/5$
749
+ - $P(Humidity=High|Yes) = 3/9$
750
+ - $P(Humidity=High|No) = 4/5$
751
+ - $P(Wind=Strong|Yes) = 3/9$
752
+ - $P(Wind=Strong|No) = 3/5$
753
+
754
+ #### step 2: calculate for yes
755
+
756
+ $P(Yes) \times P(Sunny|Yes) \times P(Cool|Yes) \times P(High|Yes) \times P(Strong|Yes)$
757
+
758
+ $= \frac{9}{14} \times \frac{2}{9} \times \frac{3}{9} \times \frac{3}{9} \times \frac{3}{9}$
759
+
760
+ $= \frac{9}{14} \times \frac{2 \times 3 \times 3 \times 3}{9^4}$
761
+
762
+ $= \frac{9}{14} \times \frac{54}{6561}$
763
+
764
+ $= \frac{9 \times 54}{14 \times 6561}$
765
+
766
+ $= \frac{486}{91854}$
767
+
768
+ $= 0.0053$
769
+
770
+ #### step 3: calculate for no
771
+
772
+ $P(No) \times P(Sunny|No) \times P(Cool|No) \times P(High|No) \times P(Strong|No)$
773
+
774
+ $= \frac{5}{14} \times \frac{3}{5} \times \frac{1}{5} \times \frac{4}{5} \times \frac{3}{5}$
775
+
776
+ $= \frac{5}{14} \times \frac{3 \times 1 \times 4 \times 3}{5^4}$
777
+
778
+ $= \frac{5}{14} \times \frac{36}{625}$
779
+
780
+ $= \frac{5 \times 36}{14 \times 625}$
781
+
782
+ $= \frac{180}{8750}$
783
+
784
+ $= 0.0206$
785
+
786
+ #### step 4: normalize
787
+
788
+ sum of probabilities: $0.0053 + 0.0206 = 0.0259$
789
+
790
+ normalizing:
791
+
792
+ - $P(Yes|evidence) = \frac{0.0053}{0.0259} = 0.205$ (20.5%)
793
+ - $P(No|evidence) = \frac{0.0206}{0.0259} = 0.795$ (79.5%)
794
+
795
+ #### step 5: predict
796
+
797
+ since $P(No|evidence) > P(Yes|evidence)$, prediction: **No**
798
+
799
+ person would **not play tennis** under these conditions.
800
+ """)
801
+ }
802
+ return (solution_accordion,)
803
+
804
+
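+ @app.cell
+ def _():
+     # Cross-check of the step-by-step solution above using exact fractions.
+     from fractions import Fraction as _F
+ 
+     _score_yes = _F(9, 14) * _F(2, 9) * _F(3, 9) * _F(3, 9) * _F(3, 9)
+     _score_no = _F(5, 14) * _F(3, 5) * _F(1, 5) * _F(4, 5) * _F(3, 5)
+     _total_score = _score_yes + _score_no
+ 
+     print("P(Yes | evidence) =", round(float(_score_yes / _total_score), 3))  # ~0.205
+     print("P(No  | evidence) =", round(float(_score_no / _total_score), 3))   # ~0.795
+     return
+ 
+ 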
805
+ @app.cell(hide_code=True)
806
+ def create_gaussian_controls(mo):
807
+ # sliders for controlling viz parameters
808
+ class_sep_slider = mo.ui.slider(1.0, 3.0, value=1.5, label="Class Separation")
809
+ noise_slider = mo.ui.slider(0.1, 0.5, step=0.1, value=0.1, label="Noise (reduces class separation)")
810
+ n_samples_slider = mo.ui.slider(50, 200, value=100, step=10, label="Number of Samples")
811
+
812
+ # Create a run button to regenerate data
813
+ regenerate_button = mo.ui.run_button(label="Regenerate Data", kind="success")
814
+
815
+ # stack controls vertically
816
+ controls = mo.vstack([
817
+ mo.md("### visualization controls"),
818
+ class_sep_slider,
819
+ noise_slider,
820
+ n_samples_slider,
821
+ regenerate_button
822
+ ])
823
+ return (
824
+ class_sep_slider,
825
+ controls,
826
+ n_samples_slider,
827
+ noise_slider,
828
+ regenerate_button,
829
+ )
830
+
831
+
832
+ if __name__ == "__main__":
833
+ app.run()
probability/21_logistic_regression.py ADDED
@@ -0,0 +1,530 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "matplotlib==3.10.1",
6
+ # "numpy==2.2.4",
7
+ # "drawdata==0.3.7",
8
+ # "scikit-learn==1.6.1",
9
+ # "polars==1.26.0",
10
+ # ]
11
+ # ///
12
+
13
+ import marimo
14
+
15
+ __generated_with = "0.12.5"
16
+ app = marimo.App(width="medium", app_title="Logistic Regression")
17
+
18
+
19
+ @app.cell(hide_code=True)
20
+ def _(mo):
21
+ mo.md(
22
+ r"""
23
+ # Logistic Regression
24
+
25
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part5/log_regression/), by Stanford professor Chris Piech._
26
+
27
+ Logistic regression learns a function approximating $P(y|x)$, and can be used to make a classifier. It makes the central assumption that $P(y|x)$ can be approximated as a sigmoid function applied to a linear combination of input features. It is particularly important to learn because logistic regression is the basic building block of artificial neural networks.
28
+ """
29
+ )
30
+ return
31
+
32
+
33
+ @app.cell(hide_code=True)
34
+ def _(mo):
35
+ mo.md(
36
+ r"""
37
+ ## The Binary Classification Problem
38
+
39
+ Imagine situations where we would like to know:
40
+
41
+ - The eligibility of getting a bank loan given the value of credit score ($x_{credit\_score}$) and monthly income ($x_{income}$)
42
+ - Identifying a tumor as benign or malignant given its size ($x_{tumor\_size}$)
43
+ - Classifying an email as promotional given the number of occurrences for some keywords like {'win', 'gift', 'discount'} ($x_{n\_win}$, $x_{n\_gift}$, $x_{n\_discount}$)
44
+ - Finding a monetary transaction as fraudulent given the time of occurrence ($x_{time\_stamp}$) and amount ($x_{amount}$)
45
+
46
+ These problems occur frequently in real life and can be tackled with machine learning. All such problems come under the umbrella of what is known as Classification. In each scenario, only one of the two possible outcomes can occur, hence these are specifically known as Binary Classification problems.
47
+
48
+ ### How Does A Machine Perform Classification?
49
+
50
+ During inference, the goal is to have the ML model predict the class label for a given set of feature values.
51
+
52
+ Specifically, a binary classification model estimates two probabilities $p_0$ & $p_1$ for 'class-0' and 'class-1' respectively where $p_0 + p_1 = 1$.
53
+
54
+ The predicted label depends on $\max(p_0, p_1)$ i.e., it's the one which is most probable based on the given features.
55
+
56
+ In logistic regression, $p_1$ (i.e., success probability) is compared with a predefined threshold $p$ to predict the class label like below:
57
+
58
+ $$\text{predicted class} =
59
+ \begin{cases}
60
+ 1, & \text{if } p_1 \geq p \\
61
+ 0, & \text{otherwise}
62
+ \end{cases}$$
63
+
64
+ To keep the notation simple and consistent, we will denote the success probability as $p$, and failure probability as $(1-p)$ instead of $p_1$ and $p_0$ respectively.
65
+ """
66
+ )
67
+ return
68
+
69
+
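+ @app.cell
+ def _(np):
+     # A tiny sketch of the thresholding rule above: given estimated success probabilities
+     # p_1 for a few inputs, predict class 1 whenever p_1 >= p. The probabilities and the
+     # threshold here are made-up illustrative numbers.
+     _p1 = np.array([0.92, 0.41, 0.55, 0.08])  # hypothetical P(class 1) per example
+     _p_threshold = 0.5
+ 
+     print("predicted classes:", (_p1 >= _p_threshold).astype(int))  # [1 0 1 0]
+     return
+ 
+ 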
70
+ @app.cell(hide_code=True)
71
+ def _(mo):
72
+ mo.md(
73
+ r"""
74
+ ## Why NOT Linear Regression?
75
+
76
+ Can't we just use linear regression to address classification? The answer is NO! The key issue is that probabilities must lie between 0 and 1, while linear regression can output any real number.
77
+
78
+ If we tried using linear regression directly:
79
+ $$p = \beta_0 + \beta_1 \cdot x_{feature}$$
80
+
81
+ This creates a problem: the right side can produce any value in $\mathbb{R}$ (all real numbers), but a probability $p$ must be confined to the range $(0,1)$.
82
+
83
+ Can we convert $(\beta_0 + \beta_1 \cdot x_{tumor\_size})$ into something in $(0,1)$ that can serve as an estimate of a probability? The answer is YES!
84
+
85
+ We need a converter (a function), say $g()$, that will connect $p \in (0,1)$ to $(\beta_0 + \beta_1 \cdot x_{tumor\_size}) \in \mathbb{R}$.
86
+
87
+ The solution is to use a "link function" that maps from any real number to a valid probability range. This is where the sigmoid function comes in.
88
+ """
89
+ )
90
+ return
91
+
92
+
93
+ @app.cell(hide_code=True)
94
+ def _(mo, np, plt):
95
+ # plot the sigmoid to illustrate the statements above
96
+ _fig, ax = plt.subplots(figsize=(10, 6))
97
+
98
+ # x values
99
+ x = np.linspace(-10, 10, 1000)
100
+
101
+ # sigmoid formula
102
+ def sigmoid(z):
103
+ return 1 / (1 + np.exp(-z))
104
+
105
+ y = sigmoid(x)
106
+
107
+ # plot
108
+ ax.plot(x, y, 'b-', linewidth=2)
109
+
110
+ ax.axhline(y=0, color='k', linestyle='-', alpha=0.3)
111
+ ax.axhline(y=1, color='k', linestyle='-', alpha=0.3)
112
+ ax.axhline(y=0.5, color='r', linestyle='--', alpha=0.5)
113
+
114
+ # vertical line at x=0
115
+ ax.axvline(x=0, color='k', linestyle='-', alpha=0.3)
116
+
117
+ # annotations
118
+ ax.text(1, 0.85, r'$\sigma(z) = \frac{1}{1 + e^{-z}}$', fontsize=14)
119
+ ax.text(-9, 0.1, 'As z → -∞, σ(z) → 0', fontsize=12)
120
+ ax.text(3, 0.9, 'As z → ∞, σ(z) → 1', fontsize=12)
121
+ ax.text(0.5, 0.4, 'σ(0) = 0.5', fontsize=12)
122
+
123
+ # labels and title
124
+ ax.set_xlabel('z', fontsize=14)
125
+ ax.set_ylabel('σ(z)', fontsize=14)
126
+ ax.set_title('Sigmoid Function', fontsize=16)
127
+
128
+ # axis limits set
129
+ ax.set_xlim(-10, 10)
130
+ ax.set_ylim(-0.1, 1.1)
131
+
132
+ # grid
133
+ ax.grid(True, alpha=0.3)
134
+
135
+ mo.mpl.interactive(_fig)
136
+ return ax, sigmoid, x, y
137
+
138
+
139
+ @app.cell(hide_code=True)
140
+ def _(mo):
141
+ mo.md(
142
+ r"""
143
+ **Figure**: The sigmoid function maps any real number to a value between 0 and 1, making it perfect for representing probabilities.
144
+
145
+ /// note
146
+ For more information about the sigmoid function, head over to [this detailed notebook](http://marimo.app/https://github.com/marimo-team/deepml-notebooks/blob/main/problems/problem-22/notebook.py) for more insights.
147
+ ///
148
+ """
149
+ )
150
+ return
151
+
152
+
153
+ @app.cell(hide_code=True)
154
+ def _(mo):
155
+ mo.md(
156
+ r"""
157
+ ## The Core Concept (math)
158
+
159
+ Logistic regression models the probability of class 1 using the sigmoid function:
160
+
161
+ $$P(Y=1|X=x) = \sigma(z) \text{ where } z = \theta_0 + \sum_{i=1}^m \theta_i x_i$$
162
+
163
+ The sigmoid function $\sigma(z)$ transforms any real number into a probability between 0 and 1:
164
+
165
+ $$\sigma(z) = \frac{1}{1+ e^{-z}}$$
166
+
167
+ This can be written more compactly using vector notation:
168
+
169
+ $$P(Y=1|\mathbf{X}=\mathbf{x}) =\sigma(\mathbf{\theta}^T\mathbf{x}) \quad \text{ where we always set $x_0$ to be 1}$$
170
+
171
+ $$P(Y=0|\mathbf{X}=\mathbf{x}) = 1-\sigma(\mathbf{\theta}^T\mathbf{x}) \quad \text{ by the law of total probability}$$
172
+
173
+ Where $\theta$ represents the model parameters that need to be learned from data, and $x$ is the feature vector (with $x_0=1$ to account for the intercept term).
174
+
175
+ > **Note:** For the detailed mathematical derivation of how these parameters are learned through Maximum Likelihood Estimation (MLE) and Gradient Descent (GD), please refer to [Chris Piech's original material](https://chrispiech.github.io/probabilityForComputerScientists/en/part5/log_regression/). The mathematical details are elegant but beyond the scope of this notebook topic (which is confined to Logistic Regression).
176
+ """
177
+ )
178
+ return
179
+
180
+
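+ @app.cell
+ def _(np, sigmoid):
+     # A small numeric sketch of the model above. The parameter vector theta and the feature
+     # vector x below are made-up numbers, chosen only to show the mechanics: prepend x_0 = 1
+     # for the intercept, take the dot product, then squash with the sigmoid.
+     _theta = np.array([-1.0, 0.8, 0.5])  # [theta_0, theta_1, theta_2], hypothetical
+     _x = np.array([1.0, 2.0, -0.5])  # [x_0 = 1, x_1, x_2], hypothetical
+ 
+     _z = _theta @ _x
+     _p_class1 = sigmoid(_z)
+     print(f"z = {_z:.2f}, P(Y=1|x) = {_p_class1:.3f}, P(Y=0|x) = {1 - _p_class1:.3f}")
+     return
+ 
+ 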
181
+ @app.cell(hide_code=True)
182
+ def _(mo):
183
+ mo.md(
184
+ r"""
185
+ ### Linear Decision Boundary
186
+
187
+ A key characteristic of logistic regression is that it creates a linear decision boundary. When the model predicts, it's effectively dividing the feature space with a straight line (in 2D) or a hyperplane (in higher dimensions); in two dimensions the boundary really is a straight line of the form $y = mx + c$.
188
+
189
+ Recall the prediction rule: we predict class 1 when the estimated success probability $\sigma(\theta_0 + \theta_1 \cdot x_{tumor\_size})$ is at least the chosen threshold $p$. Because the sigmoid is monotonic, this is equivalent to comparing the linear term with the log-odds of the threshold:
190
+ $$\text{predicted class} =
191
+ \begin{cases}
192
+ 1, & \text{if } \theta_0 + \theta_1 \cdot x_{tumor\_size} \geq \log\frac{p}{1-p} \\
193
+ 0, & \text{otherwise}
194
+ \end{cases}$$
195
+
196
+ For a two-feature model, the decision boundary where $P(Y=1|X=x) = 0.5$ occurs at:
197
+ $$\theta_0 + \theta_1 x_1 + \theta_2 x_2 = 0$$
198
+
199
+ A simple logistic regression predicts the class label by identifying the regions on either side of a straight line (or hyperplane in general), hence it's a _linear_ classifier.
200
+
201
+ This linear nature makes logistic regression effective for linearly separable classes but limited when dealing with more complex patterns.
202
+ """
203
+ )
204
+ return
205
+
206
+
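+ @app.cell
+ def _(np):
+     # A short sketch of the boundary equation above: for hypothetical, made-up parameters,
+     # the P = 0.5 boundary theta_0 + theta_1*x_1 + theta_2*x_2 = 0 rearranges to x_2 as a
+     # straight-line function of x_1.
+     _theta0, _theta1, _theta2 = -1.0, 2.0, 3.0  # illustrative values only
+ 
+     _x1_grid = np.linspace(-3, 3, 5)
+     _x2_boundary = -(_theta0 + _theta1 * _x1_grid) / _theta2  # where the prediction flips
+     print(np.column_stack([_x1_grid, np.round(_x2_boundary, 2)]))
+     return
+ 
+ 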
207
+ @app.cell(hide_code=True)
208
+ def _(mo):
209
+ mo.md("""### Visual: Linear Separability and Classification""")
210
+ return
211
+
212
+
213
+ @app.cell(hide_code=True)
214
+ def _(mo, np, plt):
215
+ # show relevant comparison to the above concepts/statements
216
+
217
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
218
+
219
+ # Linear separable data
220
+ np.random.seed(42)
221
+ X1 = np.random.randn(100, 2) - 2
222
+ X2 = np.random.randn(100, 2) + 2
223
+
224
+ ax1.scatter(X1[:, 0], X1[:, 1], color='blue', alpha=0.5)
225
+ ax1.scatter(X2[:, 0], X2[:, 1], color='red', alpha=0.5)
226
+
227
+ # Decision boundary (line)
228
+ ax1.plot([-5, 5], [5, -5], 'k--', linewidth=2)
229
+ ax1.set_xlim(-5, 5)
230
+ ax1.set_ylim(-5, 5)
231
+ ax1.set_title('Linearly Separable Classes')
232
+
233
+ # non-linear separable data
234
+ radius = 2
235
+ theta = np.linspace(0, 2*np.pi, 100)
236
+
237
+ # Outer circle points (class 1)
238
+ outer_x = 3 * np.cos(theta)
239
+ outer_y = 3 * np.sin(theta)
240
+ # Inner circle points (class 2)
241
+ inner_x = 1.5 * np.cos(theta) + np.random.randn(100) * 0.2
242
+ inner_y = 1.5 * np.sin(theta) + np.random.randn(100) * 0.2
243
+
244
+ ax2.scatter(outer_x, outer_y, color='blue', alpha=0.5)
245
+ ax2.scatter(inner_x, inner_y, color='red', alpha=0.5)
246
+
247
+ # Attempt to draw a linear boundary (which won't work well) proving the point
248
+ ax2.plot([-5, 5], [2, 2], 'k--', linewidth=2)
249
+
250
+ ax2.set_xlim(-5, 5)
251
+ ax2.set_ylim(-5, 5)
252
+ ax2.set_title('Non-Linearly Separable Classes')
253
+
254
+ fig.tight_layout()
255
+ mo.mpl.interactive(fig)
256
+ return (
257
+ X1,
258
+ X2,
259
+ ax1,
260
+ ax2,
261
+ fig,
262
+ inner_x,
263
+ inner_y,
264
+ outer_x,
265
+ outer_y,
266
+ radius,
267
+ theta,
268
+ )
269
+
270
+
271
+ @app.cell(hide_code=True)
272
+ def _(mo):
273
+ mo.md(r"""**Figure**: On the left, the classes are linearly separable as the boundary is a straight line. However, they are not linearly separable on the right, where no straight line can properly separate the two classes.""")
274
+ return
275
+
276
+
277
+ @app.cell(hide_code=True)
278
+ def _(mo):
279
+ mo.md(
280
+ r"""
281
+ Logistic regression is typically trained using MLE - finding the parameters $\theta$ that make our observed data most probable.
282
+
283
+ The optimization process generally uses GD (or its variants) to iteratively improve the parameters. The gradient has a surprisingly elegant form:
284
+
285
+ $$\frac{\partial LL(\theta)}{\partial \theta_j} = \sum_{i=1}^n \left[
286
+ y^{(i)} - \sigma(\theta^T x^{(i)})
287
+ \right] x_j^{(i)}$$
288
+
289
+ This shows that the update to each parameter depends on the prediction error (actual - predicted) multiplied by the feature value.
290
+
291
+ For those interested in the complete mathematical derivation, including log likelihood calculation and the detailed steps of GD (and relevant pseudocode followed for training), please see the [original lecture notes](https://chrispiech.github.io/probabilityForComputerScientists/en/part5/log_regression/).
292
+ """
293
+ )
294
+ return
295
+
296
+
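+ @app.cell
+ def _(np, sigmoid):
+     # A bare-bones sketch of gradient ascent on the log likelihood, following the gradient
+     # formula above. The synthetic data, learning rate, and iteration count are illustrative
+     # choices, not tuned values; scikit-learn's LogisticRegression (used in the demo below)
+     # handles this far more robustly.
+     _rng = np.random.default_rng(0)
+     _n = 200
+     _true_theta = np.array([0.5, 2.0, -1.5])  # [intercept, w1, w2], made up for the demo
+     _X_aug = np.hstack([np.ones((_n, 1)), _rng.normal(size=(_n, 2))])  # prepend x_0 = 1
+     _y_sim = (_rng.random(_n) < sigmoid(_X_aug @ _true_theta)).astype(float)
+ 
+     _theta_hat = np.zeros(3)
+     for _ in range(5000):
+         _error = _y_sim - sigmoid(_X_aug @ _theta_hat)  # (actual - predicted)
+         _theta_hat += 0.1 * (_X_aug.T @ _error) / _n  # step along the average gradient
+ 
+     print("true theta:     ", _true_theta)
+     print("estimated theta:", np.round(_theta_hat, 2))
+     return
+ 
+ 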
297
+ @app.cell(hide_code=True)
298
+ def _(controls, mo, widget):
299
+ # create the layout
300
+ mo.vstack([
301
+ mo.md("## Interactive drawing demo\nDraw points of two different classes and see how logistic regression separates them. _The interactive demo was adapted and improvised from [Vincent Warmerdam's](https://github.com/koaning) code [here](https://github.com/probabl-ai/youtube-appendix/blob/main/04-drawing-data/notebook.ipynb)_."),
302
+ controls,
303
+ widget
304
+ ])
305
+ return
306
+
307
+
308
+ @app.cell(hide_code=True)
309
+ def _(LogisticRegression, mo, np, plt, run_button, widget):
310
+ warning_msg = mo.md(""" /// warning
311
+ Need more data, please draw points of at least two different colors in the scatter widget
312
+ """)
313
+
314
+ # mo.stop if button isn't clicked yet
315
+ mo.stop(
316
+ not run_button.value,
317
+ mo.md(""" /// tip
318
+ click 'Run Logistic Regression' to see the model
319
+ """)
320
+ )
321
+
322
+ # get data from widget (can also use as_pandas)
323
+ df = widget.data_as_polars
324
+
325
+ # display appropriate warning
326
+ mo.stop(
327
+ df.is_empty() or df['color'].n_unique() < 2,
328
+ warning_msg
329
+ )
330
+
331
+ # extract features and labels
332
+ X = df[['x', 'y']].to_numpy()
333
+ y_colors = df['color'].to_numpy()
334
+
335
+ # fit logistic regression model
336
+ model = LogisticRegression()
337
+ model.fit(X, y_colors)
338
+
339
+ # create grid for the viz
340
+ x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
341
+ y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
342
+ xx, yy = np.meshgrid(
343
+ np.linspace(x_min, x_max, 100),
344
+ np.linspace(y_min, y_max, 100)
345
+ )
346
+
347
+ # get probability predictions
348
+ Z = model.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1]
349
+ Z = Z.reshape(xx.shape)
350
+
351
+ # create figure
352
+ _fig, ax_fig = plt.subplots(figsize=(12, 8))
353
+
354
+ # plot decision boundary (probability contours)
355
+ contour = ax_fig.contourf(
356
+ xx, yy, Z,
357
+ levels=np.linspace(0, 1, 11),
358
+ alpha=0.7,
359
+ cmap="RdBu_r"
360
+ )
361
+
362
+ # plot decision boundary line (probability = 0.5)
363
+ ax_fig.contour(
364
+ xx, yy, Z,
365
+ levels=[0.5],
366
+ colors='k',
367
+ linewidths=2
368
+ )
369
+
370
+ # plot the data points (use same colors as in the widget)
371
+ ax_fig.scatter(X[:, 0], X[:, 1], c=y_colors, edgecolor='k', s=80)
372
+
373
+ # colorbar
374
+ plt.colorbar(contour, ax=ax_fig)
375
+
376
+ # labels and title
377
+ ax_fig.set_xlabel('x')
378
+ ax_fig.set_ylabel('y')
379
+ ax_fig.set_title('Logistic Regression')
380
+
381
+ # model params
382
+ coef = model.coef_[0]
383
+ intercept = model.intercept_[0]
384
+ equation = f"log(p/(1-p)) = {intercept:.2f} + {coef[0]:.3f}x₁ + {coef[1]:.3f}x₂"
385
+
386
+ # relevant info in regards to regression
387
+ model_info = mo.md(f"""
388
+ ### Logistic regression model
389
+
390
+ **Equation**: {equation}
391
+
392
+ **Decision boundary**: probability = 0.5
393
+
394
+ **Accuracy**: {model.score(X, y_colors):.2f}
395
+ """)
396
+
397
+ # show results vertically stacked
398
+ mo.vstack([
399
+ mo.mpl.interactive(_fig),
400
+ model_info
401
+ ])
402
+ return (
403
+ X,
404
+ Z,
405
+ ax_fig,
406
+ coef,
407
+ contour,
408
+ df,
409
+ equation,
410
+ intercept,
411
+ model,
412
+ model_info,
413
+ warning_msg,
414
+ x_max,
415
+ x_min,
416
+ xx,
417
+ y_colors,
418
+ y_max,
419
+ y_min,
420
+ yy,
421
+ )
422
+
423
+
424
+ @app.cell(hide_code=True)
425
+ def _(mo):
426
+ mo.md(
427
+ r"""
428
+ ## 🤔 Key Takeaways
429
+
430
+ Click on the statements below that you think are correct to verify your understanding:
431
+
432
+ /// details | Logistic regression tries to find parameters (θ) that minimize the error between predicted and actual values using ordinary least squares.
433
+ ❌ **Incorrect.** Logistic regression uses maximum likelihood estimation (MLE), not ordinary least squares. It finds parameters that maximize the probability of observing the training data, which is different from minimizing squared errors as in linear regression.
434
+ ///
435
+
436
+ /// details | The sigmoid function maps any real number to a value between 0 and 1, which allows logistic regression to output probabilities.
437
+ ✅ **Correct!** The sigmoid function σ(z) = 1/(1+e^(-z)) takes any real number as input and outputs a value between 0 and 1. This is perfect for representing probabilities and is a key component of logistic regression.
438
+ ///
439
+
440
+ /// details | The decision boundary in logistic regression is always a straight line, regardless of the data's complexity.
441
+ ✅ **Correct!** Standard logistic regression produces a linear decision boundary (a straight line in 2D or a hyperplane in higher dimensions). This is why it works well for linearly separable data but struggles with more complex patterns, like concentric circles (as you might've noticed from the interactive demo).
442
+ ///
443
+
444
+ /// details | The logistic regression model params are typically initialized to random values and refined through gradient descent.
445
+ ✅ **Correct!** Parameters are often initialized to zeros or small random values, then updated iteratively using gradient descent (or ascent for maximizing likelihood) until convergence.
446
+ ///
447
+
448
+ /// details | Logistic regression can naturally handle multi-class classification problems without any modifications.
449
+ ❌ **Incorrect.** Standard logistic regression is inherently a binary classifier. To handle multi-class classification, techniques like one-vs-rest or softmax regression are typically used.
450
+ ///
451
+ """
452
+ )
453
+ return
454
+
455
+
456
+ @app.cell(hide_code=True)
457
+ def _(mo):
458
+ mo.md(
459
+ r"""
460
+ ## Summary
461
+
462
+ So we've just explored logistic regression. Despite its name (seriously though, why not call it "logistic classification"?), it's actually quite elegant in how it transforms a simple linear model into a powerful decision _boundary_ maker.
463
+
464
+ The training process boils down to finding the values of θ that maximize the likelihood of seeing our training data. What's super cool is that even though the math looks _scary_ at first, the gradient has this surprisingly simple form: just the error (y - predicted) multiplied by the feature values.
465
+
466
+ Two key insights to remember:
467
+
468
+ - Logistic regression creates a _linear_ decision boundary, so it works great for linearly separable classes but struggles with more _complex_ patterns
469
+ - It directly gives you probabilities, not just classifications, which is incredibly useful when you need confidence measures
470
+ """
471
+ )
472
+ return
473
+
474
+
475
+ @app.cell(hide_code=True)
476
+ def _(mo):
477
+ mo.md(
478
+ r"""
479
+ Additional resources referred to:
480
+
481
+ - [Logistic Regression Tutorial by _Koushik Khan_](https://koushikkhan.github.io/resources/pdf/tutorials/logistic_regression_tutorial.pdf)
482
+ """
483
+ )
484
+ return
485
+
486
+
487
+ @app.cell(hide_code=True)
488
+ def _(mo):
489
+ mo.md(r"""Appendix (helper code)""")
490
+ return
491
+
492
+
493
+ @app.cell
494
+ def _():
495
+ import marimo as mo
496
+ return (mo,)
497
+
498
+
499
+ @app.cell
500
+ def init_imports():
501
+ # imports for our notebook
502
+ import numpy as np
503
+ import matplotlib.pyplot as plt
504
+ from drawdata import ScatterWidget
505
+ from sklearn.linear_model import LogisticRegression
506
+
507
+
508
+ # for consistent results
509
+ np.random.seed(42)
510
+
511
+ # nicer plots
512
+ plt.style.use('seaborn-v0_8-darkgrid')
513
+ return LogisticRegression, ScatterWidget, np, plt
514
+
515
+
516
+ @app.cell(hide_code=True)
517
+ def _(ScatterWidget, mo):
518
+ # drawing widget
519
+ widget = mo.ui.anywidget(ScatterWidget())
520
+
521
+ # run_button to run model
522
+ run_button = mo.ui.run_button(label="Run Logistic Regression", kind="success")
523
+
524
+ # stack controls
525
+ controls = mo.hstack([run_button])
526
+ return controls, run_button, widget
527
+
528
+
529
+ if __name__ == "__main__":
530
+ app.run()