Spaces:
Running
Running
Merge branch 'main' into fp/applicatives
Browse files- .github/workflows/typos.yaml +2 -0
- .typos.toml +22 -0
- functional_programming/README.md +4 -5
- polars/08_working_with_columns.py +605 -0
- polars/09_data_types.py +296 -0
- probability/20_naive_bayes.py +833 -0
- probability/21_logistic_regression.py +530 -0
.github/workflows/typos.yaml
CHANGED
@@ -13,3 +13,5 @@ jobs:
|
|
13 |
uses: styfle/[email protected]
|
14 |
- uses: actions/checkout@v4
|
15 |
- uses: crate-ci/[email protected]
|
|
|
|
|
|
13 |
uses: styfle/[email protected]
|
14 |
- uses: actions/checkout@v4
|
15 |
- uses: crate-ci/[email protected]
|
16 |
+
with:
|
17 |
+
config: .typos.toml
|
.typos.toml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[default]
|
2 |
+
extend-ignore-re = [
|
3 |
+
# LaTeX math expressions
|
4 |
+
"\\\\\\[.*?\\\\\\]",
|
5 |
+
"\\\\\\(.*?\\\\\\)",
|
6 |
+
"\\$\\$.*?\\$\\$",
|
7 |
+
"\\$.*?\\$",
|
8 |
+
# LaTeX commands
|
9 |
+
"\\\\[a-zA-Z]+\\{.*?\\}",
|
10 |
+
"\\\\[a-zA-Z]+",
|
11 |
+
# LaTeX subscripts and superscripts
|
12 |
+
"_\\{.*?\\}",
|
13 |
+
"\\^\\{.*?\\}"
|
14 |
+
]
|
15 |
+
|
16 |
+
# Words to explicitly accept
|
17 |
+
[default.extend-words]
|
18 |
+
pn = "pn"
|
19 |
+
|
20 |
+
# You can also exclude specific files or directories if needed
|
21 |
+
# [files]
|
22 |
+
# extend-exclude = ["*.tex", "docs/*.md"]
|
functional_programming/README.md
CHANGED
@@ -38,7 +38,7 @@ uvx marimo edit <URL>
|
|
38 |
For example, run the `Functor` tutorial with
|
39 |
|
40 |
```bash
|
41 |
-
uvx marimo edit https://github.com/marimo-team/learn/blob/main/
|
42 |
```
|
43 |
|
44 |
You can also open notebooks in our online playground by appending `marimo.app/`
|
@@ -50,10 +50,9 @@ to a notebook's URL:
|
|
50 |
Check [here](https://github.com/marimo-team/learn/issues/51) for current series
|
51 |
structure.
|
52 |
|
53 |
-
| Notebook
|
54 |
-
|
55 |
-
| [05. Functors](https://github.com/marimo-team/learn/blob/main/
|
56 |
-
| [06. Applicatives](https://github.com/marimo-team/learn/blob/main/Functional_programming/06_applicatives.py) | Applicative programming with effects | Learn how to apply functions within a context, combining multiple effects in a pure way. Learn about the `pure` and `apply` operations that make applicatives powerful for handling multiple computations. | Applicative Functors, Pure, Apply, Effectful programming | Functors |
|
57 |
|
58 |
**Authors.**
|
59 |
|
|
|
38 |
For example, run the `Functor` tutorial with
|
39 |
|
40 |
```bash
|
41 |
+
uvx marimo edit https://github.com/marimo-team/learn/blob/main/functional_programming/05_functors.py
|
42 |
```
|
43 |
|
44 |
You can also open notebooks in our online playground by appending `marimo.app/`
|
|
|
50 |
Check [here](https://github.com/marimo-team/learn/issues/51) for current series
|
51 |
structure.
|
52 |
|
53 |
+
| Notebook | Description | References |
|
54 |
+
| ----------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
55 |
+
| [05. Category and Functors](https://github.com/marimo-team/learn/blob/main/functional_programming/05_functors.py) | Learn why `len` is a _Functor_ from `list concatenation` to `integer addition`, how to _lift_ an ordinary function into a _computation context_, and how to write an _adapter_ between two categories. | - [The Trivial Monad](http://blog.sigfpe.com/2007/04/trivial-monad.html) <br> - [Haskellwiki. Category Theory](https://en.wikibooks.org/wiki/Haskell/Category_theory) <br> - [Haskellforall. The Category Design Pattern](https://www.haskellforall.com/2012/08/the-category-design-pattern.html) <br> - [Haskellforall. The Functor Design Pattern](https://www.haskellforall.com/2012/09/the-functor-design-pattern.html) <br> - [Haskellwiki. Functor](https://wiki.haskell.org/index.php?title=Functor) <br> - [Haskellwiki. Typeclassopedia#Functor](https://wiki.haskell.org/index.php?title=Typeclassopedia#Functor) <br> - [Haskellwiki. Typeclassopedia#Category](https://wiki.haskell.org/index.php?title=Typeclassopedia#Category) |
|
|
|
56 |
|
57 |
**Authors.**
|
58 |
|
polars/08_working_with_columns.py
ADDED
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.11"
|
3 |
+
# dependencies = [
|
4 |
+
# "polars==1.18.0",
|
5 |
+
# "marimo",
|
6 |
+
# ]
|
7 |
+
# ///
|
8 |
+
|
9 |
+
import marimo
|
10 |
+
|
11 |
+
__generated_with = "0.12.0"
|
12 |
+
app = marimo.App(width="medium")
|
13 |
+
|
14 |
+
|
15 |
+
@app.cell(hide_code=True)
|
16 |
+
def _(mo):
|
17 |
+
mo.md(
|
18 |
+
r"""
|
19 |
+
# Working with Columns
|
20 |
+
|
21 |
+
Author: [Deb Debnath](https://github.com/debajyotid2)
|
22 |
+
|
23 |
+
**Note**: The following tutorial has been adapted from the Polars [documentation](https://docs.pola.rs/user-guide/expressions/expression-expansion).
|
24 |
+
"""
|
25 |
+
)
|
26 |
+
return
|
27 |
+
|
28 |
+
|
29 |
+
@app.cell(hide_code=True)
|
30 |
+
def _(mo):
|
31 |
+
mo.md(
|
32 |
+
r"""
|
33 |
+
## Expressions
|
34 |
+
|
35 |
+
Data transformations are sometimes complicated, or involve massive computations which are time-consuming. You can make a small version of the dataset with the schema you are trying to work your transformation into. But there is a better way to do it in Polars.
|
36 |
+
|
37 |
+
A Polars expression is a lazy representation of a data transformation. "Lazy" means that the transformation is not eagerly (immediately) executed.
|
38 |
+
|
39 |
+
Expressions are modular and flexible. They can be composed to build more complex expressions. For example, to calculate speed from distance and time, you can have an expression as:
|
40 |
+
"""
|
41 |
+
)
|
42 |
+
return
|
43 |
+
|
44 |
+
|
45 |
+
@app.cell
|
46 |
+
def _(pl):
|
47 |
+
speed_expr = pl.col("distance") / (pl.col("time"))
|
48 |
+
speed_expr
|
49 |
+
return (speed_expr,)
|
50 |
+
|
51 |
+
|
52 |
+
@app.cell(hide_code=True)
|
53 |
+
def _(mo):
|
54 |
+
mo.md(
|
55 |
+
r"""
|
56 |
+
## Expression expansion
|
57 |
+
|
58 |
+
Expression expansion lets you write a single expression that can expand to multiple different expressions. So rather than repeatedly defining separate expressions, you can avoid redundancy while adhering to clean code principles (Do not Repeat Yourself - [DRY](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself)). Since expressions are reusable, they aid in writing concise code.
|
59 |
+
"""
|
60 |
+
)
|
61 |
+
return
|
62 |
+
|
63 |
+
|
64 |
+
@app.cell(hide_code=True)
|
65 |
+
def _(mo):
|
66 |
+
mo.md("""For the examples in this notebook, we will use a sliver of the *AI4I 2020 Predictive Maintenance Dataset*. This dataset comprises of measurements taken from sensors in industrial machinery undergoing preventive maintenance checks - basically being tested for failure conditions.""")
|
67 |
+
return
|
68 |
+
|
69 |
+
|
70 |
+
@app.cell
|
71 |
+
def _(StringIO, pl):
|
72 |
+
data_csv = """
|
73 |
+
Product ID,Type,Air temperature,Process temperature,Rotational speed,Tool wear,Machine failure,TWF,HDF,PWF,OSF,RNF
|
74 |
+
L51172,L,302.3,311.3,1614,129,0,0,1,0,0,0
|
75 |
+
M22586,M,300.8,311.9,1761,113,1,0,0,0,1,0
|
76 |
+
L51639,L,302.6,310.4,1743,191,0,1,0,0,0,1
|
77 |
+
L50250,L,300,309.1,1631,110,0,0,0,1,0,0
|
78 |
+
M20109,M,303.4,312.9,1422,63,1,0,0,0,0,0
|
79 |
+
"""
|
80 |
+
|
81 |
+
data = pl.read_csv(StringIO(data_csv))
|
82 |
+
data
|
83 |
+
return data, data_csv
|
84 |
+
|
85 |
+
|
86 |
+
@app.cell(hide_code=True)
|
87 |
+
def _(mo):
|
88 |
+
mo.md(
|
89 |
+
r"""
|
90 |
+
## Function `col`
|
91 |
+
|
92 |
+
The function `col` is used to refer to one column of a dataframe. It is one of the fundamental building blocks of expressions in Polars. `col` is also really handy in expression expansion.
|
93 |
+
"""
|
94 |
+
)
|
95 |
+
return
|
96 |
+
|
97 |
+
|
98 |
+
@app.cell(hide_code=True)
|
99 |
+
def _(mo):
|
100 |
+
mo.md(
|
101 |
+
r"""
|
102 |
+
### Explicit expansion by column name
|
103 |
+
|
104 |
+
The simplest form of expression expansion happens when you provide multiple column names to the function `col`.
|
105 |
+
|
106 |
+
Say you wish to convert all temperature values in deg. Kelvin (K) to deg. Fahrenheit (F). One way to do this would be to define individual expressions for each column as follows:
|
107 |
+
"""
|
108 |
+
)
|
109 |
+
return
|
110 |
+
|
111 |
+
|
112 |
+
@app.cell
|
113 |
+
def _(data, pl):
|
114 |
+
exprs = [
|
115 |
+
((pl.col("Air temperature") - 273.15) * 1.8 + 32).round(2),
|
116 |
+
((pl.col("Process temperature") - 273.15) * 1.8 + 32).round(2)
|
117 |
+
]
|
118 |
+
|
119 |
+
result = data.with_columns(exprs)
|
120 |
+
result
|
121 |
+
return exprs, result
|
122 |
+
|
123 |
+
|
124 |
+
@app.cell(hide_code=True)
|
125 |
+
def _(mo):
|
126 |
+
mo.md(r"""Expression expansion can reduce this verbosity when you list the column names you want the expression to expand to inside the `col` function. The result is the same as before.""")
|
127 |
+
return
|
128 |
+
|
129 |
+
|
130 |
+
@app.cell
|
131 |
+
def _(data, pl, result):
|
132 |
+
result_2 = data.with_columns(
|
133 |
+
(
|
134 |
+
(pl.col(
|
135 |
+
"Air temperature",
|
136 |
+
"Process temperature"
|
137 |
+
)
|
138 |
+
- 273.15) * 1.8 + 32
|
139 |
+
).round(2)
|
140 |
+
)
|
141 |
+
result_2.equals(result)
|
142 |
+
return (result_2,)
|
143 |
+
|
144 |
+
|
145 |
+
@app.cell(hide_code=True)
|
146 |
+
def _(mo):
|
147 |
+
mo.md(r"""In this case, the expression that does the temperature conversion is expanded to a list of two expressions. The expansion of the expression is predictable and intuitive.""")
|
148 |
+
return
|
149 |
+
|
150 |
+
|
151 |
+
@app.cell(hide_code=True)
|
152 |
+
def _(mo):
|
153 |
+
mo.md(
|
154 |
+
r"""
|
155 |
+
### Expansion by data type
|
156 |
+
|
157 |
+
Can we do better than explicitly writing the names of every columns we want transformed? Yes.
|
158 |
+
|
159 |
+
If you provide data types instead of column names, the expression is expanded to all columns that match one of the data types provided.
|
160 |
+
|
161 |
+
The example below performs the exact same computation as before:
|
162 |
+
"""
|
163 |
+
)
|
164 |
+
return
|
165 |
+
|
166 |
+
|
167 |
+
@app.cell
|
168 |
+
def _(data, pl, result):
|
169 |
+
result_3 = data.with_columns(((pl.col(pl.Float64) - 273.15) * 1.8 + 32).round(2))
|
170 |
+
result_3.equals(result)
|
171 |
+
return (result_3,)
|
172 |
+
|
173 |
+
|
174 |
+
@app.cell(hide_code=True)
|
175 |
+
def _(mo):
|
176 |
+
mo.md(
|
177 |
+
r"""
|
178 |
+
However, you should be careful to ensure that the transformation is only applied to the columns you want. For ensuring this it is important to know the schema of the data beforehand.
|
179 |
+
|
180 |
+
`col` accepts multiple data types in case the columns you need have more than one data type.
|
181 |
+
"""
|
182 |
+
)
|
183 |
+
return
|
184 |
+
|
185 |
+
|
186 |
+
@app.cell
|
187 |
+
def _(data, pl, result):
|
188 |
+
result_4 = data.with_columns(
|
189 |
+
(
|
190 |
+
(pl.col(
|
191 |
+
pl.Float32,
|
192 |
+
pl.Float64,
|
193 |
+
)
|
194 |
+
- 273.15) * 1.8 + 32
|
195 |
+
).round(2)
|
196 |
+
)
|
197 |
+
result.equals(result_4)
|
198 |
+
return (result_4,)
|
199 |
+
|
200 |
+
|
201 |
+
@app.cell(hide_code=True)
|
202 |
+
def _(mo):
|
203 |
+
mo.md(
|
204 |
+
r"""
|
205 |
+
### Expansion by pattern matching
|
206 |
+
|
207 |
+
`col` also accepts regular expressions for selecting columns by pattern matching. Regular expressions start and end with ^ and $, respectively.
|
208 |
+
"""
|
209 |
+
)
|
210 |
+
return
|
211 |
+
|
212 |
+
|
213 |
+
@app.cell
|
214 |
+
def _(data, pl):
|
215 |
+
data.select(pl.col("^.*temperature$"))
|
216 |
+
return
|
217 |
+
|
218 |
+
|
219 |
+
@app.cell(hide_code=True)
|
220 |
+
def _(mo):
|
221 |
+
mo.md(r"""Regular expressions can be combined with exact column names.""")
|
222 |
+
return
|
223 |
+
|
224 |
+
|
225 |
+
@app.cell
|
226 |
+
def _(data, pl):
|
227 |
+
data.select(pl.col("^.*temperature$", "Tool wear"))
|
228 |
+
return
|
229 |
+
|
230 |
+
|
231 |
+
@app.cell(hide_code=True)
|
232 |
+
def _(mo):
|
233 |
+
mo.md(r"""**Note**: You _cannot_ mix strings (exact names, regular expressions) and data types in a `col` function.""")
|
234 |
+
return
|
235 |
+
|
236 |
+
|
237 |
+
@app.cell
|
238 |
+
def _(data, pl):
|
239 |
+
try:
|
240 |
+
data.select(pl.col("Air temperature", pl.Float64))
|
241 |
+
except TypeError as err:
|
242 |
+
print("TypeError:", err)
|
243 |
+
return
|
244 |
+
|
245 |
+
|
246 |
+
@app.cell(hide_code=True)
|
247 |
+
def _(mo):
|
248 |
+
mo.md(
|
249 |
+
r"""
|
250 |
+
## Selecting all columns
|
251 |
+
|
252 |
+
To select all columns, you can use the `all` function.
|
253 |
+
"""
|
254 |
+
)
|
255 |
+
return
|
256 |
+
|
257 |
+
|
258 |
+
@app.cell
|
259 |
+
def _(data, pl):
|
260 |
+
result_6 = data.select(pl.all())
|
261 |
+
result_6.equals(data)
|
262 |
+
return (result_6,)
|
263 |
+
|
264 |
+
|
265 |
+
@app.cell(hide_code=True)
|
266 |
+
def _(mo):
|
267 |
+
mo.md(
|
268 |
+
r"""
|
269 |
+
## Excluding columns
|
270 |
+
|
271 |
+
There are scenarios where we might want to exclude specific columns from the ones selected by building expressions, e.g. by the `col` or `all` functions. For this purpose, we use the function `exclude`, which accepts exactly the same types of arguments as `col`:
|
272 |
+
"""
|
273 |
+
)
|
274 |
+
return
|
275 |
+
|
276 |
+
|
277 |
+
@app.cell
|
278 |
+
def _(data, pl):
|
279 |
+
data.select(pl.all().exclude("^.*F$"))
|
280 |
+
return
|
281 |
+
|
282 |
+
|
283 |
+
@app.cell(hide_code=True)
|
284 |
+
def _(mo):
|
285 |
+
mo.md(r"""`exclude` can also be used after the function `col`:""")
|
286 |
+
return
|
287 |
+
|
288 |
+
|
289 |
+
@app.cell
|
290 |
+
def _(data, pl):
|
291 |
+
data.select(pl.col(pl.Int64).exclude("^.*F$"))
|
292 |
+
return
|
293 |
+
|
294 |
+
|
295 |
+
@app.cell(hide_code=True)
|
296 |
+
def _(mo):
|
297 |
+
mo.md(
|
298 |
+
r"""
|
299 |
+
## Column renaming
|
300 |
+
|
301 |
+
When applying a transformation with an expression to a column, the data in the column gets overwritten with the transformed data. However, this might not be the intended outcome in all situations - ideally you would want to store transformed data in a new column. Applying multiple transformations to the same column at the same time without renaming leads to errors.
|
302 |
+
"""
|
303 |
+
)
|
304 |
+
return
|
305 |
+
|
306 |
+
|
307 |
+
@app.cell
|
308 |
+
def _(data, pl):
|
309 |
+
from polars.exceptions import DuplicateError
|
310 |
+
|
311 |
+
try:
|
312 |
+
data.select(
|
313 |
+
(pl.col("Air temperature") - 273.15) * 1.8 + 32, # This would be named "Air temperature"...
|
314 |
+
pl.col("Air temperature") - 273.15, # And so would this.
|
315 |
+
)
|
316 |
+
except DuplicateError as err:
|
317 |
+
print("DuplicateError:", err)
|
318 |
+
return (DuplicateError,)
|
319 |
+
|
320 |
+
|
321 |
+
@app.cell(hide_code=True)
|
322 |
+
def _(mo):
|
323 |
+
mo.md(
|
324 |
+
r"""
|
325 |
+
### Renaming a single column with `alias`
|
326 |
+
|
327 |
+
The function `alias` lets you rename a single column:
|
328 |
+
"""
|
329 |
+
)
|
330 |
+
return
|
331 |
+
|
332 |
+
|
333 |
+
@app.cell
|
334 |
+
def _(data, pl):
|
335 |
+
data.select(
|
336 |
+
((pl.col("Air temperature") - 273.15) * 1.8 + 32).round(2).alias("Air temperature [F]"),
|
337 |
+
(pl.col("Air temperature") - 273.15).round(2).alias("Air temperature [C]")
|
338 |
+
)
|
339 |
+
return
|
340 |
+
|
341 |
+
|
342 |
+
@app.cell(hide_code=True)
|
343 |
+
def _(mo):
|
344 |
+
mo.md(
|
345 |
+
r"""
|
346 |
+
### Prefixing and suffixing column names
|
347 |
+
|
348 |
+
As `alias` renames a single column at a time, it cannot be used during expression expansion. If it is sufficient add a static prefix or a static suffix to the existing names, you can use the functions `name.prefix` and `name.suffix` with `col`:
|
349 |
+
"""
|
350 |
+
)
|
351 |
+
return
|
352 |
+
|
353 |
+
|
354 |
+
@app.cell
|
355 |
+
def _(data, pl):
|
356 |
+
data.select(
|
357 |
+
((pl.col("Air temperature") - 273.15) * 1.8 + 32).round(2).name.prefix("deg F "),
|
358 |
+
(pl.col("Process temperature") - 273.15).round(2).name.suffix(" C"),
|
359 |
+
)
|
360 |
+
return
|
361 |
+
|
362 |
+
|
363 |
+
@app.cell(hide_code=True)
|
364 |
+
def _(mo):
|
365 |
+
mo.md(
|
366 |
+
r"""
|
367 |
+
### Dynamic name replacement
|
368 |
+
|
369 |
+
If a static prefix/suffix is not enough, use `name.map`. `name.map` requires a function that transforms column names to the desired. The transformation should lead to unique names to avoid `DuplicateError`.
|
370 |
+
"""
|
371 |
+
)
|
372 |
+
return
|
373 |
+
|
374 |
+
|
375 |
+
@app.cell
|
376 |
+
def _(data, pl):
|
377 |
+
# There is also `.name.to_lowercase`, so this usage of `.map` is moot.
|
378 |
+
data.select(pl.col("^.*F$").name.map(str.lower))
|
379 |
+
return
|
380 |
+
|
381 |
+
|
382 |
+
@app.cell(hide_code=True)
|
383 |
+
def _(mo):
|
384 |
+
mo.md(
|
385 |
+
r"""
|
386 |
+
## Programmatically generating expressions
|
387 |
+
|
388 |
+
For this example, we will first create four additional columns with the rolling mean temperatures of the two temperature columns. Such transformations are sometimes used to create additional features for machine learning models or data analysis.
|
389 |
+
"""
|
390 |
+
)
|
391 |
+
return
|
392 |
+
|
393 |
+
|
394 |
+
@app.cell
|
395 |
+
def _(data, pl):
|
396 |
+
ext_temp_data = data.with_columns(
|
397 |
+
pl.col("^.*temperature$").rolling_mean(window_size=2).round(2).name.prefix("Rolling mean ")
|
398 |
+
).select(pl.col("^.*temperature*$"))
|
399 |
+
ext_temp_data
|
400 |
+
return (ext_temp_data,)
|
401 |
+
|
402 |
+
|
403 |
+
@app.cell(hide_code=True)
|
404 |
+
def _(mo):
|
405 |
+
mo.md(r"""Now, suppose we want to calculate the difference between the rolling mean and actual temperatures. We cannot use expression expansion here as we want differences between specific columns.""")
|
406 |
+
return
|
407 |
+
|
408 |
+
|
409 |
+
@app.cell(hide_code=True)
|
410 |
+
def _(mo):
|
411 |
+
mo.md(r"""At first, you may think about using a `for` loop:""")
|
412 |
+
return
|
413 |
+
|
414 |
+
|
415 |
+
@app.cell
|
416 |
+
def _(ext_temp_data, pl):
|
417 |
+
_result = ext_temp_data
|
418 |
+
for col_name in ["Air", "Process"]:
|
419 |
+
_result = _result.with_columns(
|
420 |
+
(abs(pl.col(f"Rolling mean {col_name} temperature") - pl.col(f"{col_name} temperature")))
|
421 |
+
.round(2).alias(f"Delta {col_name} temperature")
|
422 |
+
)
|
423 |
+
_result
|
424 |
+
return (col_name,)
|
425 |
+
|
426 |
+
|
427 |
+
@app.cell(hide_code=True)
|
428 |
+
def _(mo):
|
429 |
+
mo.md(r"""Using a `for` loop is functional, but not scalable, as each expression needs to be defined in an iteration and executed serially. Instead we can use a generator in Python to programmatically create all expressions at once. In conjunction with the `with_columns` context, we can take advantage of parallel execution of computations and query optimization from Polars.""")
|
430 |
+
return
|
431 |
+
|
432 |
+
|
433 |
+
@app.cell
|
434 |
+
def _(ext_temp_data, pl):
|
435 |
+
def delta_expressions(colnames: list[str]) -> pl.Expr:
|
436 |
+
for col_name in colnames:
|
437 |
+
yield (abs(pl.col(f"Rolling mean {col_name} temperature") - pl.col(f"{col_name} temperature"))
|
438 |
+
.round(2).alias(f"Delta {col_name} temperature"))
|
439 |
+
|
440 |
+
|
441 |
+
ext_temp_data.with_columns(delta_expressions(["Air", "Process"]))
|
442 |
+
return (delta_expressions,)
|
443 |
+
|
444 |
+
|
445 |
+
@app.cell(hide_code=True)
|
446 |
+
def _(mo):
|
447 |
+
mo.md(
|
448 |
+
r"""
|
449 |
+
## More flexible column selections
|
450 |
+
|
451 |
+
For more flexible column selections, you can use column selectors from `selectors`. Column selectors allow for more expressiveness in the way you specify selections. For example, column selectors can perform the familiar set operations of union, intersection, difference, etc. We can use the union operation with the functions `string` and `ends_with` to select all string columns and the columns whose names end with "`_high`":
|
452 |
+
"""
|
453 |
+
)
|
454 |
+
return
|
455 |
+
|
456 |
+
|
457 |
+
@app.cell
|
458 |
+
def _(data):
|
459 |
+
import polars.selectors as cs
|
460 |
+
|
461 |
+
data.select(cs.string() | cs.ends_with("F"))
|
462 |
+
return (cs,)
|
463 |
+
|
464 |
+
|
465 |
+
@app.cell(hide_code=True)
|
466 |
+
def _(mo):
|
467 |
+
mo.md(r"""Likewise, you can pick columns based on the category of the type of data, offering more flexibility than the `col` function. As an example, `cs.numeric` selects numeric data types (including `pl.Float32`, `pl.Float64`, `pl.Int32`, etc.) or `cs.temporal` for all dates, times and similar data types.""")
|
468 |
+
return
|
469 |
+
|
470 |
+
|
471 |
+
@app.cell(hide_code=True)
|
472 |
+
def _(mo):
|
473 |
+
mo.md(
|
474 |
+
r"""
|
475 |
+
### Combining selectors with set operations
|
476 |
+
|
477 |
+
Multiple selectors can be combined using set operations and the usual Python operators:
|
478 |
+
|
479 |
+
|
480 |
+
| Operator | Operation |
|
481 |
+
|:--------:|:--------------------:|
|
482 |
+
| `A | B` | Union |
|
483 |
+
| `A & B` | Intersection |
|
484 |
+
| `A - B` | Difference |
|
485 |
+
| `A ^ B` | Symmetric difference |
|
486 |
+
| `~A` | Complement |
|
487 |
+
|
488 |
+
For example, to select all failure indicator variables excluding the failure variables due to wear, we can perform a set difference between the column selectors.
|
489 |
+
"""
|
490 |
+
)
|
491 |
+
return
|
492 |
+
|
493 |
+
|
494 |
+
@app.cell
|
495 |
+
def _(cs, data):
|
496 |
+
data.select(cs.contains("F") - cs.contains("W"))
|
497 |
+
return
|
498 |
+
|
499 |
+
|
500 |
+
@app.cell(hide_code=True)
|
501 |
+
def _(mo):
|
502 |
+
mo.md(
|
503 |
+
r"""
|
504 |
+
### Resolving operator ambiguity
|
505 |
+
|
506 |
+
Expression functions can be chained on top of selectors:
|
507 |
+
"""
|
508 |
+
)
|
509 |
+
return
|
510 |
+
|
511 |
+
|
512 |
+
@app.cell
|
513 |
+
def _(cs, data, pl):
|
514 |
+
ext_failure_data = data.select(cs.contains("F")).cast(pl.Boolean)
|
515 |
+
ext_failure_data
|
516 |
+
return (ext_failure_data,)
|
517 |
+
|
518 |
+
|
519 |
+
@app.cell(hide_code=True)
|
520 |
+
def _(mo):
|
521 |
+
mo.md(
|
522 |
+
r"""
|
523 |
+
However, operators that perform set operations on column selectors operate on both selectors and on expressions. For example, the operator `~` on a selector represents the set operation “complement” and on an expression represents the Boolean operation of negation.
|
524 |
+
|
525 |
+
For instance, if you want to negate the Boolean values in the columns “HDF”, “OSF”, and “RNF”, at first you would think about using the `~` operator with the column selector to choose all failure variables containing "W". Because of the operator ambiguity here, the columns that are not of interest are selected here.
|
526 |
+
"""
|
527 |
+
)
|
528 |
+
return
|
529 |
+
|
530 |
+
|
531 |
+
@app.cell
|
532 |
+
def _(cs, ext_failure_data):
|
533 |
+
ext_failure_data.select((~cs.ends_with("WF")).name.prefix("No"))
|
534 |
+
return
|
535 |
+
|
536 |
+
|
537 |
+
@app.cell(hide_code=True)
|
538 |
+
def _(mo):
|
539 |
+
mo.md(r"""To resolve the operator ambiguity, we use `as_expr`:""")
|
540 |
+
return
|
541 |
+
|
542 |
+
|
543 |
+
@app.cell
|
544 |
+
def _(cs, ext_failure_data):
|
545 |
+
ext_failure_data.select((~cs.ends_with("WF").as_expr()).name.prefix("No"))
|
546 |
+
return
|
547 |
+
|
548 |
+
|
549 |
+
@app.cell(hide_code=True)
|
550 |
+
def _(mo):
|
551 |
+
mo.md(
|
552 |
+
r"""
|
553 |
+
### Debugging selectors
|
554 |
+
|
555 |
+
The function `cs.is_selector` helps check whether a complex chain of selectors and operators ultimately results in a selector. For example, to resolve any ambiguity with the selector in the last example, we can do:
|
556 |
+
"""
|
557 |
+
)
|
558 |
+
return
|
559 |
+
|
560 |
+
|
561 |
+
@app.cell
|
562 |
+
def _(cs):
|
563 |
+
cs.is_selector(~cs.ends_with("WF").as_expr())
|
564 |
+
return
|
565 |
+
|
566 |
+
|
567 |
+
@app.cell(hide_code=True)
|
568 |
+
def _(mo):
|
569 |
+
mo.md(r"""Additionally we can use `expand_selector` to see what columns a selector expands into. Note that for this function we need to provide additional context in the form of the dataframe.""")
|
570 |
+
return
|
571 |
+
|
572 |
+
|
573 |
+
@app.cell
|
574 |
+
def _(cs, ext_failure_data):
|
575 |
+
cs.expand_selector(
|
576 |
+
ext_failure_data,
|
577 |
+
cs.ends_with("WF"),
|
578 |
+
)
|
579 |
+
return
|
580 |
+
|
581 |
+
|
582 |
+
@app.cell(hide_code=True)
|
583 |
+
def _(mo):
|
584 |
+
mo.md(
|
585 |
+
r"""
|
586 |
+
### References
|
587 |
+
|
588 |
+
1. AI4I 2020 Predictive Maintenance Dataset [Dataset]. (2020). UCI Machine Learning Repository. ([link](https://doi.org/10.24432/C5HS5C)).
|
589 |
+
2. Polars documentation ([link](https://docs.pola.rs/user-guide/expressions/expression-expansion/#more-flexible-column-selections))
|
590 |
+
"""
|
591 |
+
)
|
592 |
+
return
|
593 |
+
|
594 |
+
|
595 |
+
@app.cell(hide_code=True)
|
596 |
+
def _():
|
597 |
+
import csv
|
598 |
+
import marimo as mo
|
599 |
+
import polars as pl
|
600 |
+
from io import StringIO
|
601 |
+
return StringIO, csv, mo, pl
|
602 |
+
|
603 |
+
|
604 |
+
if __name__ == "__main__":
|
605 |
+
app.run()
|
polars/09_data_types.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.11"
|
3 |
+
# dependencies = [
|
4 |
+
# "polars==1.18.0",
|
5 |
+
# "marimo",
|
6 |
+
# ]
|
7 |
+
# ///
|
8 |
+
|
9 |
+
import marimo
|
10 |
+
|
11 |
+
__generated_with = "0.12.0"
|
12 |
+
app = marimo.App(width="medium")
|
13 |
+
|
14 |
+
|
15 |
+
@app.cell(hide_code=True)
|
16 |
+
def _(mo):
|
17 |
+
mo.md(
|
18 |
+
r"""
|
19 |
+
# Data Types
|
20 |
+
|
21 |
+
Author: [Deb Debnath](https://github.com/debajyotid2)
|
22 |
+
|
23 |
+
**Note**: The following tutorial has been adapted from the Polars [documentation](https://docs.pola.rs/user-guide/concepts/data-types-and-structures/).
|
24 |
+
"""
|
25 |
+
)
|
26 |
+
return
|
27 |
+
|
28 |
+
|
29 |
+
@app.cell(hide_code=True)
|
30 |
+
def _(mo):
|
31 |
+
mo.md(
|
32 |
+
r"""
|
33 |
+
Polars supports a variety of data types that fall broadly under the following categories:
|
34 |
+
|
35 |
+
- Numeric data types: integers and floating point numbers.
|
36 |
+
- Nested data types: lists, structs, and arrays.
|
37 |
+
- Temporal: dates, datetimes, times, and time deltas.
|
38 |
+
- Miscellaneous: strings, binary data, Booleans, categoricals, enums, and objects.
|
39 |
+
|
40 |
+
All types support missing values represented by `null` which is different from `NaN` used in floating point data types. The numeric datatypes in Polars loosely follow the type system of the Rust language, since its core functionalities are built in Rust.
|
41 |
+
|
42 |
+
[Here](https://docs.pola.rs/api/python/stable/reference/datatypes.html) is a full list of all data types Polars supports.
|
43 |
+
"""
|
44 |
+
)
|
45 |
+
return
|
46 |
+
|
47 |
+
|
48 |
+
@app.cell(hide_code=True)
|
49 |
+
def _(mo):
|
50 |
+
mo.md(
|
51 |
+
r"""
|
52 |
+
## Series
|
53 |
+
|
54 |
+
A series is a 1-dimensional data structure that can hold only one data type.
|
55 |
+
"""
|
56 |
+
)
|
57 |
+
return
|
58 |
+
|
59 |
+
|
60 |
+
@app.cell
|
61 |
+
def _(pl):
|
62 |
+
s = pl.Series("emojis", ["😀", "🤣", "🥶", "💀", "🤖"])
|
63 |
+
s
|
64 |
+
return (s,)
|
65 |
+
|
66 |
+
|
67 |
+
@app.cell(hide_code=True)
|
68 |
+
def _(mo):
|
69 |
+
mo.md(r"""Unless specified, Polars infers the datatype from the supplied values.""")
|
70 |
+
return
|
71 |
+
|
72 |
+
|
73 |
+
@app.cell
|
74 |
+
def _(pl):
|
75 |
+
s1 = pl.Series("friends", ["Евгений", "अभिषेक", "秀良", "Federico", "Bob"])
|
76 |
+
s2 = pl.Series("uints", [0x00, 0x01, 0x10, 0x11], dtype=pl.UInt8)
|
77 |
+
s1.dtype, s2.dtype
|
78 |
+
return s1, s2
|
79 |
+
|
80 |
+
|
81 |
+
@app.cell(hide_code=True)
|
82 |
+
def _(mo):
|
83 |
+
mo.md(
|
84 |
+
r"""
|
85 |
+
## Dataframe
|
86 |
+
|
87 |
+
A dataframe is a 2-dimensional data structure that contains uniquely named series and can hold multiple data types. Dataframes are more commonly used for data manipulation using the functionality of Polars.
|
88 |
+
|
89 |
+
The snippet below shows how to create a dataframe from a dictionary of lists:
|
90 |
+
"""
|
91 |
+
)
|
92 |
+
return
|
93 |
+
|
94 |
+
|
95 |
+
@app.cell
|
96 |
+
def _(pl):
|
97 |
+
data = pl.DataFrame(
|
98 |
+
{
|
99 |
+
"Product ID": ["L51172", "M22586", "L51639", "L50250", "M20109"],
|
100 |
+
"Type": ["L", "M", "L", "L", "M"],
|
101 |
+
"Air temperature": [302.3, 300.8, 302.6, 300, 303.4], # (K)
|
102 |
+
"Machine Failure": [False, True, False, False, True]
|
103 |
+
}
|
104 |
+
)
|
105 |
+
data
|
106 |
+
return (data,)
|
107 |
+
|
108 |
+
|
109 |
+
@app.cell(hide_code=True)
|
110 |
+
def _(mo):
|
111 |
+
mo.md(
|
112 |
+
r"""
|
113 |
+
### Inspecting a dataframe
|
114 |
+
|
115 |
+
Polars has various functions to explore the data in a dataframe. We will use the dataframe `data` defined above in our examples. Alongside we can also see a view of the dataframe rendered by `marimo` as the cells are executed.
|
116 |
+
|
117 |
+
///note
|
118 |
+
We can also use `marimo`'s built in data-inspection elements/features such as [`mo.ui.dataframe`](https://docs.marimo.io/api/inputs/dataframe/#marimo.ui.dataframe) & [`mo.ui.data_explorer`](https://docs.marimo.io/api/inputs/data_explorer/). For more check out our Polars tutorials at [`marimo learn`](https://marimo-team.github.io/learn/)!
|
119 |
+
"""
|
120 |
+
)
|
121 |
+
return
|
122 |
+
|
123 |
+
|
124 |
+
@app.cell(hide_code=True)
|
125 |
+
def _(mo):
|
126 |
+
mo.md(
|
127 |
+
"""
|
128 |
+
#### Head
|
129 |
+
|
130 |
+
The function `head` shows the first rows of a dataframe. Unless specified, it shows the first 5 rows.
|
131 |
+
"""
|
132 |
+
)
|
133 |
+
return
|
134 |
+
|
135 |
+
|
136 |
+
@app.cell
|
137 |
+
def _(data):
|
138 |
+
data.head(3)
|
139 |
+
return
|
140 |
+
|
141 |
+
|
142 |
+
@app.cell(hide_code=True)
|
143 |
+
def _(mo):
|
144 |
+
mo.md(
|
145 |
+
r"""
|
146 |
+
#### Glimpse
|
147 |
+
|
148 |
+
The function `glimpse` is an alternative to `head` to view the first few columns, but displays each line of the output corresponding to a single column. That way, it makes inspecting wider dataframes easier.
|
149 |
+
"""
|
150 |
+
)
|
151 |
+
return
|
152 |
+
|
153 |
+
|
154 |
+
@app.cell
|
155 |
+
def _(data):
|
156 |
+
print(data.glimpse(return_as_string=True))
|
157 |
+
return
|
158 |
+
|
159 |
+
|
160 |
+
@app.cell(hide_code=True)
|
161 |
+
def _(mo):
|
162 |
+
mo.md(
|
163 |
+
r"""
|
164 |
+
#### Tail
|
165 |
+
|
166 |
+
The `tail` function, just like its name suggests, shows the last rows of a dataframe. Unless the number of rows is specified, it will show the last 5 rows.
|
167 |
+
"""
|
168 |
+
)
|
169 |
+
return
|
170 |
+
|
171 |
+
|
172 |
+
@app.cell
|
173 |
+
def _(data):
|
174 |
+
data.tail(3)
|
175 |
+
return
|
176 |
+
|
177 |
+
|
178 |
+
@app.cell(hide_code=True)
|
179 |
+
def _(mo):
|
180 |
+
mo.md(
|
181 |
+
r"""
|
182 |
+
#### Sample
|
183 |
+
|
184 |
+
`sample` can be used to show a specified number of randomly selected rows from the dataframe. Unless the number of rows is specified, it will show a single row. `sample` does not preserve order of the rows.
|
185 |
+
"""
|
186 |
+
)
|
187 |
+
return
|
188 |
+
|
189 |
+
|
190 |
+
@app.cell
|
191 |
+
def _(data):
|
192 |
+
import random
|
193 |
+
|
194 |
+
random.seed(42) # For reproducibility.
|
195 |
+
|
196 |
+
data.sample(3)
|
197 |
+
return (random,)
|
198 |
+
|
199 |
+
|
200 |
+
@app.cell(hide_code=True)
|
201 |
+
def _(mo):
|
202 |
+
mo.md(
|
203 |
+
r"""
|
204 |
+
#### Describe
|
205 |
+
|
206 |
+
The function `describe` describes the summary statistics for all columns of a dataframe.
|
207 |
+
"""
|
208 |
+
)
|
209 |
+
return
|
210 |
+
|
211 |
+
|
212 |
+
@app.cell
|
213 |
+
def _(data):
|
214 |
+
data.describe()
|
215 |
+
return
|
216 |
+
|
217 |
+
|
218 |
+
@app.cell(hide_code=True)
|
219 |
+
def _(mo):
|
220 |
+
mo.md(
|
221 |
+
r"""
|
222 |
+
## Schema
|
223 |
+
|
224 |
+
A schema is a mapping showing the datatype corresponding to every column of a dataframe. The schema of a dataframe can be viewed using the attribute `schema`.
|
225 |
+
"""
|
226 |
+
)
|
227 |
+
return
|
228 |
+
|
229 |
+
|
230 |
+
@app.cell
|
231 |
+
def _(data):
|
232 |
+
data.schema
|
233 |
+
return
|
234 |
+
|
235 |
+
|
236 |
+
@app.cell(hide_code=True)
|
237 |
+
def _(mo):
|
238 |
+
mo.md(r"""Since a schema is a mapping, it can be specified in the form of a Python dictionary. Then this dictionary can be used to specify the schema of a dataframe on definition. If not specified or the entry is `None`, Polars infers the datatype from the contents of the column. Note that if the schema is not specified, it will be inferred automatically by default.""")
|
239 |
+
return
|
240 |
+
|
241 |
+
|
242 |
+
@app.cell
|
243 |
+
def _(pl):
|
244 |
+
pl.DataFrame(
|
245 |
+
{
|
246 |
+
"Product ID": ["L51172", "M22586", "L51639", "L50250", "M20109"],
|
247 |
+
"Type": ["L", "M", "L", "L", "M"],
|
248 |
+
"Air temperature": [302.3, 300.8, 302.6, 300, 303.4], # (K)
|
249 |
+
"Machine Failure": [False, True, False, False, True]
|
250 |
+
},
|
251 |
+
schema={"Product ID": pl.String, "Type": pl.String, "Air temperature": None, "Machine Failure": None},
|
252 |
+
)
|
253 |
+
return
|
254 |
+
|
255 |
+
|
256 |
+
@app.cell(hide_code=True)
|
257 |
+
def _(mo):
|
258 |
+
mo.md(r"""Sometimes the automatically inferred schema is enough for some columns, but we might wish to override the inference of only some columns. We can specify the schema for those columns using `schema_overrides`.""")
|
259 |
+
return
|
260 |
+
|
261 |
+
|
262 |
+
@app.cell
|
263 |
+
def _(pl):
|
264 |
+
pl.DataFrame(
|
265 |
+
{
|
266 |
+
"Product ID": ["L51172", "M22586", "L51639", "L50250", "M20109"],
|
267 |
+
"Type": ["L", "M", "L", "L", "M"],
|
268 |
+
"Air temperature": [302.3, 300.8, 302.6, 300, 303.4], # (K)
|
269 |
+
"Machine Failure": [False, True, False, False, True]
|
270 |
+
},
|
271 |
+
schema_overrides={"Air temperature": pl.Float32},
|
272 |
+
)
|
273 |
+
return
|
274 |
+
|
275 |
+
|
276 |
+
@app.cell(hide_code=True)
|
277 |
+
def _(mo):
|
278 |
+
mo.md(
|
279 |
+
r"""
|
280 |
+
### References
|
281 |
+
|
282 |
+
1. Polars documentation ([link](https://docs.pola.rs/api/python/stable/reference/datatypes.html))
|
283 |
+
"""
|
284 |
+
)
|
285 |
+
return
|
286 |
+
|
287 |
+
|
288 |
+
@app.cell(hide_code=True)
|
289 |
+
def _():
|
290 |
+
import marimo as mo
|
291 |
+
import polars as pl
|
292 |
+
return mo, pl
|
293 |
+
|
294 |
+
|
295 |
+
if __name__ == "__main__":
|
296 |
+
app.run()
|
probability/20_naive_bayes.py
ADDED
@@ -0,0 +1,833 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.1",
|
6 |
+
# "scipy==1.15.2",
|
7 |
+
# "numpy==2.2.4",
|
8 |
+
# "polars==1.26.0",
|
9 |
+
# "plotly==5.18.0",
|
10 |
+
# "scikit-learn==1.6.1",
|
11 |
+
# ]
|
12 |
+
# ///
|
13 |
+
|
14 |
+
import marimo
|
15 |
+
|
16 |
+
__generated_with = "0.12.0"
|
17 |
+
app = marimo.App(width="medium", app_title="Naive Bayes Classification")
|
18 |
+
|
19 |
+
|
20 |
+
@app.cell(hide_code=True)
|
21 |
+
def _(mo):
|
22 |
+
mo.md(
|
23 |
+
r"""
|
24 |
+
# Naive Bayes Classification
|
25 |
+
|
26 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part5/naive_bayes/), by Stanford professor Chris Piech._
|
27 |
+
|
28 |
+
Naive Bayes is one of those classic machine learning algorithms that seems almost too simple to work, yet it's surprisingly effective for many classification tasks. I've always found it fascinating how this algorithm applies Bayes' theorem with a strong (but knowingly incorrect) "naive" assumption that all features are independent of each other.
|
29 |
+
|
30 |
+
In this notebook, we'll dive into why this supposedly "wrong" assumption still leads to good results. We'll walk through the training process, learn how to make predictions, and see some interactive visualizations that helped me understand the concept better when I was first learning it. We'll also explore why Naive Bayes excels particularly in text classification problems like spam filtering.
|
31 |
+
|
32 |
+
If you're new to Naive Bayes, I highly recommend checking out [this excellent explanation by Mahesh Huddar](https://youtu.be/XzSlEA4ck2I?si=AASeh_KP68BAbzy5), which provides a step-by-step walkthrough with a helpful example (which we take a dive into, down below).
|
33 |
+
"""
|
34 |
+
)
|
35 |
+
return
|
36 |
+
|
37 |
+
|
38 |
+
@app.cell(hide_code=True)
|
39 |
+
def _(mo):
|
40 |
+
mo.md(
|
41 |
+
r"""
|
42 |
+
## Why "Naive"?
|
43 |
+
|
44 |
+
So why is it called "naive"? It's because the algorithm makes an assumption — it assumes all features are completely independent of each other when given the class label.
|
45 |
+
|
46 |
+
The math way of saying this is:
|
47 |
+
|
48 |
+
$$P(X_1, X_2, \ldots, X_n | Y) = P(X_1 | Y) \times P(X_2 | Y) \times \ldots \times P(X_n | Y) = \prod_{i=1}^{n} P(X_i | Y)$$
|
49 |
+
|
50 |
+
This independence assumption is almost always wrong in real data. Think about text classification — if you see the word "cloudy" in a weather report, you're much more likely to also see "rain" than you would be to see "sunshine". These words clearly depend on each other! Or in medical diagnosis, symptoms often occur together as part of syndromes.
|
51 |
+
|
52 |
+
But here's the cool part — even though we know this assumption is _technically_ wrong, the algorithm still works remarkably well in practice. By making this simplifying assumption, we:
|
53 |
+
|
54 |
+
- Make the math way easier to compute
|
55 |
+
- Need way less training data to get decent results
|
56 |
+
- Can handle thousands of features without blowing up computationally
|
57 |
+
"""
|
58 |
+
)
|
59 |
+
return
|
60 |
+
|
61 |
+
|
62 |
+
@app.cell(hide_code=True)
|
63 |
+
def _(mo):
|
64 |
+
mo.md(
|
65 |
+
r"""
|
66 |
+
## The Math Behind Naive Bayes
|
67 |
+
|
68 |
+
At its core, Naive Bayes is just an application of Bayes' theorem from our earlier probability notebooks. Let's break it down:
|
69 |
+
|
70 |
+
We have some features $\mathbf{X} = [X_1, X_2, \ldots, X_m]$ (like words in an email or symptoms of a disease) and we want to predict a class label $Y$ (like "spam/not spam" or "has disease/doesn't have disease").
|
71 |
+
|
72 |
+
What we're really trying to find is:
|
73 |
+
|
74 |
+
$$P(Y|\mathbf{X})$$
|
75 |
+
|
76 |
+
In other words, "what's the probability of a certain class given the features we observed?" Once we have these probabilities, we simply pick the class with the highest probability:
|
77 |
+
|
78 |
+
$$\hat{y} = \underset{y}{\operatorname{argmax}} \text{ } P(Y=y|\mathbf{X}=\mathbf{x})$$
|
79 |
+
|
80 |
+
Applying Bayes' theorem (from our earlier probability work), we get:
|
81 |
+
|
82 |
+
$$P(Y=y|\mathbf{X}=\mathbf{x}) = \frac{P(Y=y) \times P(\mathbf{X}=\mathbf{x}|Y=y)}{P(\mathbf{X}=\mathbf{x})}$$
|
83 |
+
|
84 |
+
Since we're comparing different possible classes for the same input features, the denominator $P(\mathbf{X}=\mathbf{x})$ is the same for all classes. So we can drop it and just compare:
|
85 |
+
|
86 |
+
$$\hat{y} = \underset{y}{\operatorname{argmax}} \text{ } P(Y=y) \times P(\mathbf{X}=\mathbf{x}|Y=y)$$
|
87 |
+
|
88 |
+
Here's where the "naive" part comes in. Calculating $P(\mathbf{X}=\mathbf{x}|Y=y)$ directly would be a computational nightmare - we'd need counts for every possible combination of feature values. Instead, we make that simplifying "naive" assumption that features are independent of each other:
|
89 |
+
|
90 |
+
$$P(\mathbf{X}=\mathbf{x}|Y=y) = \prod_{i=1}^{m} P(X_i=x_i|Y=y)$$
|
91 |
+
|
92 |
+
Which gives us our final formula:
|
93 |
+
|
94 |
+
$$\hat{y} = \underset{y}{\operatorname{argmax}} \text{ } P(Y=y) \times \prod_{i=1}^{m} P(X_i=x_i|Y=y)$$
|
95 |
+
|
96 |
+
In actual implementations, we usually use logarithms to avoid the numerical problems that come with multiplying many small probabilities (they can _underflow_ to zero):
|
97 |
+
|
98 |
+
$$\hat{y} = \underset{y}{\operatorname{argmax}} \text{ } \log P(Y=y) + \sum_{i=1}^{m} \log P(X_i=x_i|Y=y)$$
|
99 |
+
|
100 |
+
That's it! The really cool thing is that despite this massive simplification, the algorithm often gives surprisingly good results.
|
101 |
+
"""
|
102 |
+
)
|
103 |
+
return
|
104 |
+
|
105 |
+
|
106 |
+
@app.cell(hide_code=True)
|
107 |
+
def _(mo):
|
108 |
+
mo.md(r"""## Example Problem""")
|
109 |
+
return
|
110 |
+
|
111 |
+
|
112 |
+
@app.cell(hide_code=True)
|
113 |
+
def _(mo):
|
114 |
+
mo.md(r"""Let's apply Naive Bayes principles to this data (Tennis Training Dataset):""")
|
115 |
+
return
|
116 |
+
|
117 |
+
|
118 |
+
@app.cell(hide_code=True)
|
119 |
+
def _(mo):
|
120 |
+
mo.md(
|
121 |
+
r"""
|
122 |
+
## A Simple Example: Play Tennis
|
123 |
+
|
124 |
+
Let's understand Naive Bayes with a classic example: predicting whether someone will play tennis based on weather conditions. This is the same example used in Mahesh Huddar's excellent video.
|
125 |
+
|
126 |
+
Our dataset has these features:
|
127 |
+
- **Outlook**: Sunny, Overcast, Rainy
|
128 |
+
- **Temperature**: Hot, Mild, Cool
|
129 |
+
- **Humidity**: High, Normal
|
130 |
+
- **Wind**: Strong, Weak
|
131 |
+
|
132 |
+
And the target variable:
|
133 |
+
- **Play Tennis**: Yes, No
|
134 |
+
|
135 |
+
### Example Dataset
|
136 |
+
"""
|
137 |
+
)
|
138 |
+
|
139 |
+
# Create a dataset matching the image (in dict format for proper table rendering)
|
140 |
+
example_data = [
|
141 |
+
{"Day": "D1", "Outlook": "Sunny", "Temperature": "Hot", "Humidity": "High", "Wind": "Weak", "Play Tennis": "No"},
|
142 |
+
{"Day": "D2", "Outlook": "Sunny", "Temperature": "Hot", "Humidity": "High", "Wind": "Strong", "Play Tennis": "No"},
|
143 |
+
{"Day": "D3", "Outlook": "Overcast", "Temperature": "Hot", "Humidity": "High", "Wind": "Weak", "Play Tennis": "Yes"},
|
144 |
+
{"Day": "D4", "Outlook": "Rain", "Temperature": "Mild", "Humidity": "High", "Wind": "Weak", "Play Tennis": "Yes"},
|
145 |
+
{"Day": "D5", "Outlook": "Rain", "Temperature": "Cool", "Humidity": "Normal", "Wind": "Weak", "Play Tennis": "Yes"},
|
146 |
+
{"Day": "D6", "Outlook": "Rain", "Temperature": "Cool", "Humidity": "Normal", "Wind": "Strong", "Play Tennis": "No"},
|
147 |
+
{"Day": "D7", "Outlook": "Overcast", "Temperature": "Cool", "Humidity": "Normal", "Wind": "Strong", "Play Tennis": "Yes"},
|
148 |
+
{"Day": "D8", "Outlook": "Sunny", "Temperature": "Mild", "Humidity": "High", "Wind": "Weak", "Play Tennis": "No"},
|
149 |
+
{"Day": "D9", "Outlook": "Sunny", "Temperature": "Cool", "Humidity": "Normal", "Wind": "Weak", "Play Tennis": "Yes"},
|
150 |
+
{"Day": "D10", "Outlook": "Rain", "Temperature": "Mild", "Humidity": "Normal", "Wind": "Weak", "Play Tennis": "Yes"},
|
151 |
+
{"Day": "D11", "Outlook": "Sunny", "Temperature": "Mild", "Humidity": "Normal", "Wind": "Strong", "Play Tennis": "Yes"},
|
152 |
+
{"Day": "D12", "Outlook": "Overcast", "Temperature": "Mild", "Humidity": "High", "Wind": "Strong", "Play Tennis": "Yes"},
|
153 |
+
{"Day": "D13", "Outlook": "Overcast", "Temperature": "Hot", "Humidity": "Normal", "Wind": "Weak", "Play Tennis": "Yes"},
|
154 |
+
{"Day": "D14", "Outlook": "Rain", "Temperature": "Mild", "Humidity": "High", "Wind": "Strong", "Play Tennis": "No"}
|
155 |
+
]
|
156 |
+
|
157 |
+
# Display the tennis dataset using a table
|
158 |
+
example_table = mo.ui.table(
|
159 |
+
data=example_data,
|
160 |
+
selection=None
|
161 |
+
)
|
162 |
+
|
163 |
+
mo.vstack([
|
164 |
+
mo.md("#### Tennis Training Dataset"),
|
165 |
+
example_table
|
166 |
+
])
|
167 |
+
return example_data, example_table
|
168 |
+
|
169 |
+
|
170 |
+
@app.cell(hide_code=True)
|
171 |
+
def _(mo):
|
172 |
+
mo.md(
|
173 |
+
r"""
|
174 |
+
Let's predict whether someone will play tennis given these weather conditions:
|
175 |
+
|
176 |
+
- Outlook: Sunny
|
177 |
+
- Temperature: Cool
|
178 |
+
- Humidity: High
|
179 |
+
- Wind: Strong
|
180 |
+
|
181 |
+
Let's walk through the calculations step by step:
|
182 |
+
|
183 |
+
#### Step 1: Calculate Prior Probabilities
|
184 |
+
|
185 |
+
First, we calculate $P(Y=\text{Yes})$ and $P(Y=\text{No})$:
|
186 |
+
|
187 |
+
- $P(Y=\text{Yes}) = \frac{9}{14} = 0.64$
|
188 |
+
- $P(Y=\text{No}) = \frac{5}{14} = 0.36$
|
189 |
+
|
190 |
+
#### Step 2: Calculate Conditional Probabilities
|
191 |
+
|
192 |
+
Next, we calculate the conditional probabilities for each feature value given each class:
|
193 |
+
"""
|
194 |
+
)
|
195 |
+
return
|
196 |
+
|
197 |
+
|
198 |
+
@app.cell(hide_code=True)
|
199 |
+
def _(humidity_data, mo, outlook_data, summary_table, temp_data, wind_data):
|
200 |
+
# Display tables with appropriate styling
|
201 |
+
mo.vstack([
|
202 |
+
mo.md("#### Class Distribution"),
|
203 |
+
summary_table,
|
204 |
+
mo.md("#### Conditional Probabilities"),
|
205 |
+
mo.hstack([
|
206 |
+
mo.vstack([
|
207 |
+
mo.md("**Outlook**"),
|
208 |
+
mo.ui.table(
|
209 |
+
data=outlook_data,
|
210 |
+
selection=None
|
211 |
+
)
|
212 |
+
]),
|
213 |
+
mo.vstack([
|
214 |
+
mo.md("**Temperature**"),
|
215 |
+
mo.ui.table(
|
216 |
+
data=temp_data,
|
217 |
+
selection=None
|
218 |
+
)
|
219 |
+
])
|
220 |
+
]),
|
221 |
+
mo.hstack([
|
222 |
+
mo.vstack([
|
223 |
+
mo.md("**Humidity**"),
|
224 |
+
mo.ui.table(
|
225 |
+
data=humidity_data,
|
226 |
+
selection=None
|
227 |
+
)
|
228 |
+
]),
|
229 |
+
mo.vstack([
|
230 |
+
mo.md("**Wind**"),
|
231 |
+
mo.ui.table(
|
232 |
+
data=wind_data,
|
233 |
+
selection=None
|
234 |
+
)
|
235 |
+
])
|
236 |
+
])
|
237 |
+
])
|
238 |
+
return
|
239 |
+
|
240 |
+
|
241 |
+
@app.cell
|
242 |
+
def _():
|
243 |
+
# DIY
|
244 |
+
return
|
245 |
+
|
246 |
+
|
247 |
+
@app.cell(hide_code=True)
|
248 |
+
def _(mo, solution_accordion):
|
249 |
+
# Display the accordion
|
250 |
+
mo.accordion(solution_accordion)
|
251 |
+
return
|
252 |
+
|
253 |
+
|
254 |
+
@app.cell(hide_code=True)
|
255 |
+
def _(mo):
|
256 |
+
mo.md(
|
257 |
+
r"""
|
258 |
+
### Try a Different Example
|
259 |
+
|
260 |
+
What if the conditions were different? Let's say:
|
261 |
+
|
262 |
+
- Outlook: Overcast
|
263 |
+
- Temperature: Hot
|
264 |
+
- Humidity: Normal
|
265 |
+
- Wind: Weak
|
266 |
+
|
267 |
+
Try working through this example on your own. If you get stuck, you can use the tables above and apply the same method we used in the solution.
|
268 |
+
"""
|
269 |
+
)
|
270 |
+
return
|
271 |
+
|
272 |
+
|
273 |
+
@app.cell
|
274 |
+
def _():
|
275 |
+
# DIY
|
276 |
+
return
|
277 |
+
|
278 |
+
|
279 |
+
@app.cell(hide_code=True)
|
280 |
+
def _(mo):
|
281 |
+
mo.md(
|
282 |
+
r"""
|
283 |
+
## Interactive Naive Bayes
|
284 |
+
|
285 |
+
Let's explore Naive Bayes with an interactive visualization. This will help build intuition about how the algorithm makes predictions and how the naive independence assumption affects results.
|
286 |
+
"""
|
287 |
+
)
|
288 |
+
return
|
289 |
+
|
290 |
+
|
291 |
+
@app.cell(hide_code=True)
|
292 |
+
def gaussian_viz(
|
293 |
+
Ellipse,
|
294 |
+
GaussianNB,
|
295 |
+
ListedColormap,
|
296 |
+
class_sep_slider,
|
297 |
+
controls,
|
298 |
+
make_classification,
|
299 |
+
mo,
|
300 |
+
n_samples_slider,
|
301 |
+
noise_slider,
|
302 |
+
np,
|
303 |
+
pl,
|
304 |
+
plt,
|
305 |
+
regenerate_button,
|
306 |
+
train_test_split,
|
307 |
+
):
|
308 |
+
# get values from sliders
|
309 |
+
class_sep = class_sep_slider.value
|
310 |
+
noise_val = noise_slider.value
|
311 |
+
n_samples = int(n_samples_slider.value)
|
312 |
+
|
313 |
+
# check if regenerate button was clicked
|
314 |
+
regenerate_state = regenerate_button.value
|
315 |
+
|
316 |
+
# make a dataset with current settings
|
317 |
+
X, y = make_classification(
|
318 |
+
n_samples=n_samples,
|
319 |
+
n_features=2,
|
320 |
+
n_redundant=0,
|
321 |
+
n_informative=2,
|
322 |
+
n_clusters_per_class=1,
|
323 |
+
class_sep=class_sep * (1 - noise_val), # use noise to reduce separation
|
324 |
+
random_state=42 if not regenerate_state else np.random.randint(1000)
|
325 |
+
)
|
326 |
+
|
327 |
+
# put data in a dataframe
|
328 |
+
viz_df = pl.DataFrame({
|
329 |
+
"Feature1": X[:, 0],
|
330 |
+
"Feature2": X[:, 1],
|
331 |
+
"Class": y
|
332 |
+
})
|
333 |
+
|
334 |
+
# split into train/test
|
335 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
|
336 |
+
|
337 |
+
# create naive bayes classifier
|
338 |
+
gnb = GaussianNB()
|
339 |
+
gnb.fit(X_train, y_train)
|
340 |
+
|
341 |
+
# setup grid for boundary visualization
|
342 |
+
h = 0.1 # step size
|
343 |
+
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
|
344 |
+
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
|
345 |
+
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
|
346 |
+
|
347 |
+
# predict on grid points
|
348 |
+
grid_points = np.c_[xx.ravel(), yy.ravel()]
|
349 |
+
Z = gnb.predict(grid_points).reshape(xx.shape)
|
350 |
+
|
351 |
+
# calculate class stats
|
352 |
+
class0_mean = np.mean(X_train[y_train == 0], axis=0)
|
353 |
+
class1_mean = np.mean(X_train[y_train == 1], axis=0)
|
354 |
+
class0_var = np.var(X_train[y_train == 0], axis=0)
|
355 |
+
class1_var = np.var(X_train[y_train == 1], axis=0)
|
356 |
+
|
357 |
+
# format for display
|
358 |
+
class_stats = [
|
359 |
+
{"Class": "Class 0", "Feature1_Mean": f"{class0_mean[0]:.4f}", "Feature1_Variance": f"{class0_var[0]:.4f}",
|
360 |
+
"Feature2_Mean": f"{class0_mean[1]:.4f}", "Feature2_Variance": f"{class0_var[1]:.4f}"},
|
361 |
+
{"Class": "Class 1", "Feature1_Mean": f"{class1_mean[0]:.4f}", "Feature1_Variance": f"{class1_var[0]:.4f}",
|
362 |
+
"Feature2_Mean": f"{class1_mean[1]:.4f}", "Feature2_Variance": f"{class1_var[1]:.4f}"}
|
363 |
+
]
|
364 |
+
|
365 |
+
# setup plot with two panels
|
366 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
|
367 |
+
|
368 |
+
# colors for our plots
|
369 |
+
cmap_light = ListedColormap(['#FFAAAA', '#AAAAFF']) # bg colors
|
370 |
+
cmap_bold = ListedColormap(['#FF0000', '#0000FF']) # point colors
|
371 |
+
|
372 |
+
# left: decision boundary
|
373 |
+
ax1.contourf(xx, yy, Z, alpha=0.3, cmap=cmap_light)
|
374 |
+
scatter1 = ax1.scatter(X_train[:, 0], X_train[:, 1], c=y_train,
|
375 |
+
cmap=cmap_bold, edgecolor='k', s=50, alpha=0.8)
|
376 |
+
scatter2 = ax1.scatter(X_test[:, 0], X_test[:, 1], c=y_test,
|
377 |
+
cmap=cmap_bold, edgecolor='k', s=25, alpha=0.5)
|
378 |
+
|
379 |
+
ax1.set_xlabel('Feature 1')
|
380 |
+
ax1.set_ylabel('Feature 2')
|
381 |
+
ax1.set_title('Gaussian Naive Bayes Decision Boundary')
|
382 |
+
ax1.legend([scatter1.legend_elements()[0][0], scatter2.legend_elements()[0][0]],
|
383 |
+
['Training Data', 'Test Data'], loc='upper right')
|
384 |
+
|
385 |
+
# right: distribution visualization
|
386 |
+
class0_data = viz_df.filter(pl.col("Class") == 0)
|
387 |
+
class1_data = viz_df.filter(pl.col("Class") == 1)
|
388 |
+
|
389 |
+
ax2.scatter(class0_data["Feature1"], class0_data["Feature2"],
|
390 |
+
color='red', edgecolor='k', s=50, alpha=0.8, label='Class 0')
|
391 |
+
ax2.scatter(class1_data["Feature1"], class1_data["Feature2"],
|
392 |
+
color='blue', edgecolor='k', s=50, alpha=0.8, label='Class 1')
|
393 |
+
|
394 |
+
# draw ellipses function
|
395 |
+
def plot_ellipse(ax, mean, cov, color):
|
396 |
+
vals, vecs = np.linalg.eigh(cov)
|
397 |
+
order = vals.argsort()[::-1]
|
398 |
+
vals = vals[order]
|
399 |
+
vecs = vecs[:, order]
|
400 |
+
|
401 |
+
theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
|
402 |
+
width, height = 2 * np.sqrt(5.991 * vals)
|
403 |
+
|
404 |
+
ellip = Ellipse(xy=mean, width=width, height=height, angle=theta,
|
405 |
+
edgecolor=color, fc='None', lw=2, alpha=0.7)
|
406 |
+
ax.add_patch(ellip)
|
407 |
+
|
408 |
+
# add ellipses for each class accordingly
|
409 |
+
class0_cov = np.diag(np.var(X_train[y_train == 0], axis=0))
|
410 |
+
class1_cov = np.diag(np.var(X_train[y_train == 1], axis=0))
|
411 |
+
|
412 |
+
plot_ellipse(ax2, class0_mean, class0_cov, 'red')
|
413 |
+
plot_ellipse(ax2, class1_mean, class1_cov, 'blue')
|
414 |
+
|
415 |
+
ax2.set_xlabel('Feature 1')
|
416 |
+
ax2.set_ylabel('Feature 2')
|
417 |
+
ax2.set_title('Class-Conditional Distributions (Gaussian)')
|
418 |
+
ax2.legend(loc='upper right')
|
419 |
+
|
420 |
+
plt.tight_layout()
|
421 |
+
plt.gca()
|
422 |
+
|
423 |
+
# show interactive plot
|
424 |
+
mpl_fig = mo.mpl.interactive(fig)
|
425 |
+
|
426 |
+
# show parameters info
|
427 |
+
mo.md(
|
428 |
+
r"""
|
429 |
+
### gaussian parameters by class
|
430 |
+
|
431 |
+
each feature follows a normal distribution per class. here are the parameters:
|
432 |
+
"""
|
433 |
+
)
|
434 |
+
|
435 |
+
# make stats table
|
436 |
+
stats_table = mo.ui.table(
|
437 |
+
data=class_stats,
|
438 |
+
selection="single"
|
439 |
+
)
|
440 |
+
|
441 |
+
mo.md(
|
442 |
+
r"""
|
443 |
+
### how it works
|
444 |
+
|
445 |
+
1. calculate mean & variance for each feature per class
|
446 |
+
2. use gaussian pdf to get probabilities for new points
|
447 |
+
3. apply bayes' theorem to pick most likely class
|
448 |
+
|
449 |
+
ellipses show the distributions, decision boundary is where probabilities equal
|
450 |
+
"""
|
451 |
+
)
|
452 |
+
|
453 |
+
# stack everything together
|
454 |
+
mo.vstack([
|
455 |
+
controls.center(),
|
456 |
+
mpl_fig,
|
457 |
+
stats_table
|
458 |
+
])
|
459 |
+
return (
|
460 |
+
X,
|
461 |
+
X_test,
|
462 |
+
X_train,
|
463 |
+
Z,
|
464 |
+
ax1,
|
465 |
+
ax2,
|
466 |
+
class0_cov,
|
467 |
+
class0_data,
|
468 |
+
class0_mean,
|
469 |
+
class0_var,
|
470 |
+
class1_cov,
|
471 |
+
class1_data,
|
472 |
+
class1_mean,
|
473 |
+
class1_var,
|
474 |
+
class_sep,
|
475 |
+
class_stats,
|
476 |
+
cmap_bold,
|
477 |
+
cmap_light,
|
478 |
+
fig,
|
479 |
+
gnb,
|
480 |
+
grid_points,
|
481 |
+
h,
|
482 |
+
mpl_fig,
|
483 |
+
n_samples,
|
484 |
+
noise_val,
|
485 |
+
plot_ellipse,
|
486 |
+
regenerate_state,
|
487 |
+
scatter1,
|
488 |
+
scatter2,
|
489 |
+
stats_table,
|
490 |
+
viz_df,
|
491 |
+
x_max,
|
492 |
+
x_min,
|
493 |
+
xx,
|
494 |
+
y,
|
495 |
+
y_max,
|
496 |
+
y_min,
|
497 |
+
y_test,
|
498 |
+
y_train,
|
499 |
+
yy,
|
500 |
+
)
|
501 |
+
|
502 |
+
|
503 |
+
@app.cell(hide_code=True)
|
504 |
+
def _(mo):
|
505 |
+
mo.md(
|
506 |
+
r"""
|
507 |
+
### what's going on in this demo?
|
508 |
+
|
509 |
+
Playing with the sliders changes how our data looks and how the classifier behaves. Class separation controls how far apart the two classes are — higher values make them easier to tell apart. The noise slider adds randomness by reducing that separation, making boundaries fuzzier and classification harder. More samples just gives you more data points to work with.
|
510 |
+
|
511 |
+
The left graph shows the decision boundary — that curved line where the classifier switches from predicting one class to another. Red and blue regions show where naive bayes would classify new points. The right graph shows the actual distribution of both classes, with those ellipses representing the gaussian distributions naive bayes is using internally.
|
512 |
+
|
513 |
+
Try cranking up the noise and watch how the boundary gets messier. increase separation and see how confident the classifier becomes. This is basically what's happening inside naive bayes — it's looking at each feature's distribution per class and making the best guess based on probabilities. The table below shows the actual parameters (means and variances) the model calculates.
|
514 |
+
"""
|
515 |
+
)
|
516 |
+
return
|
517 |
+
|
518 |
+
|
519 |
+
@app.cell(hide_code=True)
|
520 |
+
def _(mo):
|
521 |
+
mo.md(
|
522 |
+
r"""
|
523 |
+
## Types of Naive Bayes Classifiers
|
524 |
+
|
525 |
+
### Multinomial Naive Bayes
|
526 |
+
Ideal for text classification where features represent word counts or frequencies.
|
527 |
+
|
528 |
+
Mathematical form:
|
529 |
+
|
530 |
+
\[P(x_i|y) = \frac{\text{count}(x_i, y) + \alpha}{\sum_{i=1}^{|V|} \text{count}(x_i, y) + \alpha|V|}\]
|
531 |
+
|
532 |
+
where:
|
533 |
+
|
534 |
+
- \(\alpha\) is the smoothing parameter
|
535 |
+
- \(|V|\) is the size of the vocabulary
|
536 |
+
- \(\text{count}(x_i, y)\) is the count of feature \(i\) in class \(y\)
|
537 |
+
|
538 |
+
### Bernoulli Naive Bayes
|
539 |
+
Best for binary features (0/1) — either a word appears or it doesn't.
|
540 |
+
|
541 |
+
Mathematical form:
|
542 |
+
|
543 |
+
\[P(x_i|y) = p_{iy}^{x_i}(1-p_{iy})^{(1-x_i)}\]
|
544 |
+
|
545 |
+
where:
|
546 |
+
|
547 |
+
- \(p_{iy}\) is the probability of feature \(i\) occurring in class \(y\)
|
548 |
+
- \(x_i\) is 1 if the feature is present, 0 otherwise
|
549 |
+
|
550 |
+
### Gaussian Naive Bayes
|
551 |
+
Designed for continuous features, assuming they follow a normal distribution.
|
552 |
+
|
553 |
+
Mathematical form:
|
554 |
+
|
555 |
+
\[P(x_i|y) = \frac{1}{\sqrt{2\pi\sigma_y^2}} \exp\left(-\frac{(x_i - \mu_y)^2}{2\sigma_y^2}\right)\]
|
556 |
+
|
557 |
+
where:
|
558 |
+
|
559 |
+
- \(\mu_y\) is the mean of feature values for class \(y\)
|
560 |
+
- \(\sigma_y^2\) is the variance of feature values for class \(y\)
|
561 |
+
|
562 |
+
### Complement Naive Bayes
|
563 |
+
Particularly effective for imbalanced datasets.
|
564 |
+
|
565 |
+
Mathematical form:
|
566 |
+
|
567 |
+
\[P(x_i|y) = \frac{\text{count}(x_i, \bar{y}) + \alpha}{\sum_{i=1}^{|V|} \text{count}(x_i, \bar{y}) + \alpha|V|}\]
|
568 |
+
|
569 |
+
where:
|
570 |
+
|
571 |
+
- \(\bar{y}\) represents all classes except \(y\)
|
572 |
+
- Other parameters are similar to Multinomial Naive Bayes
|
573 |
+
"""
|
574 |
+
)
|
575 |
+
return
|
576 |
+
|
577 |
+
|
578 |
+
@app.cell(hide_code=True)
|
579 |
+
def _(mo):
|
580 |
+
mo.md(
|
581 |
+
r"""
|
582 |
+
## 🤔 Test Your Understanding
|
583 |
+
|
584 |
+
Test your understanding of Naive Bayes with these statements:
|
585 |
+
|
586 |
+
/// details | Multiplying small probabilities in Naive Bayes can lead to numerical underflow.
|
587 |
+
✅ **Correct!** Multiplying many small probabilities can indeed lead to numerical underflow.
|
588 |
+
|
589 |
+
That's why in practice, we often use log probabilities and add them instead of multiplying the original probabilities. This prevents numerical underflow and improves computational stability.
|
590 |
+
///
|
591 |
+
|
592 |
+
/// details | Laplace smoothing is unnecessary if your training data covers all possible feature values.
|
593 |
+
❌ **Incorrect.** Laplace smoothing is still beneficial even with complete feature coverage.
|
594 |
+
|
595 |
+
While Laplace smoothing is crucial for handling unseen feature values, it also helps with small sample sizes by preventing overfitting to the training data. Even with complete feature coverage, some combinations might have very few examples, leading to unreliable probability estimates.
|
596 |
+
///
|
597 |
+
|
598 |
+
/// details | Naive Bayes performs poorly on high-dimensional data compared to other classifiers.
|
599 |
+
❌ **Incorrect.** Naive Bayes actually excels with high-dimensional data.
|
600 |
+
|
601 |
+
Due to its simplicity and the independence assumption, Naive Bayes scales very well to high-dimensional data. It's particularly effective for text classification where each word is a dimension and there can be thousands of dimensions. Other classifiers might overfit in such high-dimensional spaces.
|
602 |
+
///
|
603 |
+
|
604 |
+
/// details | For text classification, Multinomial Naive Bayes typically outperforms Gaussian Naive Bayes.
|
605 |
+
✅ **Correct!** Multinomial NB is better suited for text classification than Gaussian NB.
|
606 |
+
|
607 |
+
Text data typically involves discrete counts (word frequencies) which align better with a multinomial distribution. Gaussian Naive Bayes assumes features follow a normal distribution, which doesn't match the distribution of word frequencies in text documents.
|
608 |
+
///
|
609 |
+
"""
|
610 |
+
)
|
611 |
+
return
|
612 |
+
|
613 |
+
|
614 |
+
@app.cell(hide_code=True)
|
615 |
+
def _(mo):
|
616 |
+
mo.md(
|
617 |
+
r"""
|
618 |
+
## Summary
|
619 |
+
|
620 |
+
Throughout this notebook, we've explored Naive Bayes classification. What makes this algorithm particularly interesting is its elegant simplicity combined with surprising effectiveness. Despite making what seems like an overly simplistic assumption — that features are independent given the class — it consistently delivers reasonable performance across a wide range of applications.
|
621 |
+
|
622 |
+
The algorithm's power lies in its probabilistic foundation, built upon Bayes' theorem. During training, it simply learns probability distributions: the likelihood of seeing each class (prior probabilities) and the probability of feature values within each class (conditional probabilities). When making predictions, it combines these probabilities using the naive independence assumption, which dramatically simplifies the computation while still maintaining remarkable predictive power.
|
623 |
+
|
624 |
+
We've seen how different variants of Naive Bayes adapt to various types of data. Multinomial Naive Bayes excels at text classification by modeling word frequencies, Bernoulli Naive Bayes handles binary features elegantly, and Gaussian Naive Bayes tackles continuous data through normal distributions. Each variant maintains the core simplicity of the algorithm while adapting its probability calculations to match the data's characteristics.
|
625 |
+
|
626 |
+
Perhaps most importantly, we've learned that sometimes the most straightforward approaches can be the most practical. Naive Bayes demonstrates that a simple model, well-understood and properly applied, can often outperform more complex alternatives, especially in domains like text classification or when working with limited computational resources or training data.
|
627 |
+
"""
|
628 |
+
)
|
629 |
+
return
|
630 |
+
|
631 |
+
|
632 |
+
@app.cell(hide_code=True)
|
633 |
+
def _(mo):
|
634 |
+
mo.md(r"""appendix (helper code)""")
|
635 |
+
return
|
636 |
+
|
637 |
+
|
638 |
+
@app.cell
|
639 |
+
def _():
|
640 |
+
import marimo as mo
|
641 |
+
return (mo,)
|
642 |
+
|
643 |
+
|
644 |
+
@app.cell
|
645 |
+
def init_imports():
|
646 |
+
# imports for our notebook
|
647 |
+
import numpy as np
|
648 |
+
import matplotlib.pyplot as plt
|
649 |
+
import polars as pl
|
650 |
+
from scipy import stats
|
651 |
+
from sklearn.naive_bayes import GaussianNB
|
652 |
+
from sklearn.datasets import make_classification
|
653 |
+
from sklearn.model_selection import train_test_split
|
654 |
+
from matplotlib.colors import ListedColormap
|
655 |
+
from matplotlib.patches import Ellipse
|
656 |
+
|
657 |
+
# for consistent results
|
658 |
+
np.random.seed(42)
|
659 |
+
|
660 |
+
# nicer plots
|
661 |
+
plt.style.use('seaborn-v0_8-darkgrid')
|
662 |
+
return (
|
663 |
+
Ellipse,
|
664 |
+
GaussianNB,
|
665 |
+
ListedColormap,
|
666 |
+
make_classification,
|
667 |
+
np,
|
668 |
+
pl,
|
669 |
+
plt,
|
670 |
+
stats,
|
671 |
+
train_test_split,
|
672 |
+
)
|
673 |
+
|
674 |
+
|
675 |
+
@app.cell(hide_code=True)
|
676 |
+
def _(example_data, mo):
|
677 |
+
# occurrences count in example data
|
678 |
+
yes_count = sum(1 for row in example_data if row["Play Tennis"] == "Yes")
|
679 |
+
no_count = sum(1 for row in example_data if row["Play Tennis"] == "No")
|
680 |
+
total = len(example_data)
|
681 |
+
|
682 |
+
# summary table with dict format
|
683 |
+
summary_data = [
|
684 |
+
{"Class": "Yes", "Count": f"{yes_count}", "Probability": f"{yes_count/total:.2f}"},
|
685 |
+
{"Class": "No", "Count": f"{no_count}", "Probability": f"{no_count/total:.2f}"},
|
686 |
+
{"Class": "Total", "Count": f"{total}", "Probability": "1.00"}
|
687 |
+
]
|
688 |
+
|
689 |
+
summary_table = mo.ui.table(
|
690 |
+
data=summary_data,
|
691 |
+
selection=None
|
692 |
+
)
|
693 |
+
|
694 |
+
# tables for conditional probabilities matching the image (in dict format)
|
695 |
+
outlook_data = [
|
696 |
+
{"Outlook": "Sunny", "Y": "2/9", "N": "3/5"},
|
697 |
+
{"Outlook": "Overcast", "Y": "4/9", "N": "0"},
|
698 |
+
{"Outlook": "Rain", "Y": "3/9", "N": "2/5"}
|
699 |
+
]
|
700 |
+
|
701 |
+
temp_data = [
|
702 |
+
{"Temperature": "Hot", "Y": "2/9", "N": "2/5"},
|
703 |
+
{"Temperature": "Mild", "Y": "4/9", "N": "2/5"},
|
704 |
+
{"Temperature": "Cool", "Y": "3/9", "N": "1/5"}
|
705 |
+
]
|
706 |
+
|
707 |
+
humidity_data = [
|
708 |
+
{"Humidity": "High", "Y": "3/9", "N": "4/5"},
|
709 |
+
{"Humidity": "Normal", "Y": "6/9", "N": "1/5"}
|
710 |
+
]
|
711 |
+
|
712 |
+
wind_data = [
|
713 |
+
{"Wind": "Strong", "Y": "3/9", "N": "3/5"},
|
714 |
+
{"Wind": "Weak", "Y": "6/9", "N": "2/5"}
|
715 |
+
]
|
716 |
+
return (
|
717 |
+
humidity_data,
|
718 |
+
no_count,
|
719 |
+
outlook_data,
|
720 |
+
summary_data,
|
721 |
+
summary_table,
|
722 |
+
temp_data,
|
723 |
+
total,
|
724 |
+
wind_data,
|
725 |
+
yes_count,
|
726 |
+
)
|
727 |
+
|
728 |
+
|
729 |
+
@app.cell(hide_code=True)
|
730 |
+
def _(mo):
|
731 |
+
# accordion with solution (step-by-step)
|
732 |
+
solution_accordion = {
|
733 |
+
"step-by-step solution (click to expand)": mo.md(r"""
|
734 |
+
#### step 1: gather probabilities
|
735 |
+
|
736 |
+
from our tables:
|
737 |
+
|
738 |
+
**prior probabilities:**
|
739 |
+
|
740 |
+
- $P(Yes) = 9/14 = 0.64$
|
741 |
+
- $P(No) = 5/14 = 0.36$
|
742 |
+
|
743 |
+
**conditional probabilities:**
|
744 |
+
|
745 |
+
- $P(Outlook=Sunny|Yes) = 2/9$
|
746 |
+
- $P(Outlook=Sunny|No) = 3/5$
|
747 |
+
- $P(Temperature=Cool|Yes) = 3/9$
|
748 |
+
- $P(Temperature=Cool|No) = 1/5$
|
749 |
+
- $P(Humidity=High|Yes) = 3/9$
|
750 |
+
- $P(Humidity=High|No) = 4/5$
|
751 |
+
- $P(Wind=Strong|Yes) = 3/9$
|
752 |
+
- $P(Wind=Strong|No) = 3/5$
|
753 |
+
|
754 |
+
#### step 2: calculate for yes
|
755 |
+
|
756 |
+
$P(Yes) \times P(Sunny|Yes) \times P(Cool|Yes) \times P(High|Yes) \times P(Strong|Yes)$
|
757 |
+
|
758 |
+
$= \frac{9}{14} \times \frac{2}{9} \times \frac{3}{9} \times \frac{3}{9} \times \frac{3}{9}$
|
759 |
+
|
760 |
+
$= \frac{9}{14} \times \frac{2 \times 3 \times 3 \times 3}{9^4}$
|
761 |
+
|
762 |
+
$= \frac{9}{14} \times \frac{54}{6561}$
|
763 |
+
|
764 |
+
$= \frac{9 \times 54}{14 \times 6561}$
|
765 |
+
|
766 |
+
$= \frac{486}{91854}$
|
767 |
+
|
768 |
+
$= 0.0053$
|
769 |
+
|
770 |
+
#### step 3: calculate for no
|
771 |
+
|
772 |
+
$P(No) \times P(Sunny|No) \times P(Cool|No) \times P(High|No) \times P(Strong|No)$
|
773 |
+
|
774 |
+
$= \frac{5}{14} \times \frac{3}{5} \times \frac{1}{5} \times \frac{4}{5} \times \frac{3}{5}$
|
775 |
+
|
776 |
+
$= \frac{5}{14} \times \frac{3 \times 1 \times 4 \times 3}{5^4}$
|
777 |
+
|
778 |
+
$= \frac{5}{14} \times \frac{36}{625}$
|
779 |
+
|
780 |
+
$= \frac{5 \times 36}{14 \times 625}$
|
781 |
+
|
782 |
+
$= \frac{180}{8750}$
|
783 |
+
|
784 |
+
$= 0.0206$
|
785 |
+
|
786 |
+
#### step 4: normalize
|
787 |
+
|
788 |
+
sum of probabilities: $0.0053 + 0.0206 = 0.0259$
|
789 |
+
|
790 |
+
normalizing:
|
791 |
+
|
792 |
+
- $P(Yes|evidence) = \frac{0.0053}{0.0259} = 0.205$ (20.5%)
|
793 |
+
- $P(No|evidence) = \frac{0.0206}{0.0259} = 0.795$ (79.5%)
|
794 |
+
|
795 |
+
#### step 5: predict
|
796 |
+
|
797 |
+
since $P(No|evidence) > P(Yes|evidence)$, prediction: **No**
|
798 |
+
|
799 |
+
person would **not play tennis** under these conditions.
|
800 |
+
""")
|
801 |
+
}
|
802 |
+
return (solution_accordion,)
|
803 |
+
|
804 |
+
|
805 |
+
@app.cell(hide_code=True)
|
806 |
+
def create_gaussian_controls(mo):
|
807 |
+
# sliders for controlling viz parameters
|
808 |
+
class_sep_slider = mo.ui.slider(1.0, 3.0, value=1.5, label="Class Separation")
|
809 |
+
noise_slider = mo.ui.slider(0.1, 0.5, step=0.1, value=0.1, label="Noise (reduces class separation)")
|
810 |
+
n_samples_slider = mo.ui.slider(50, 200, value=100, step=10, label="Number of Samples")
|
811 |
+
|
812 |
+
# Create a run button to regenerate data
|
813 |
+
regenerate_button = mo.ui.run_button(label="Regenerate Data", kind="success")
|
814 |
+
|
815 |
+
# stack controls vertically
|
816 |
+
controls = mo.vstack([
|
817 |
+
mo.md("### visualization controls"),
|
818 |
+
class_sep_slider,
|
819 |
+
noise_slider,
|
820 |
+
n_samples_slider,
|
821 |
+
regenerate_button
|
822 |
+
])
|
823 |
+
return (
|
824 |
+
class_sep_slider,
|
825 |
+
controls,
|
826 |
+
n_samples_slider,
|
827 |
+
noise_slider,
|
828 |
+
regenerate_button,
|
829 |
+
)
|
830 |
+
|
831 |
+
|
832 |
+
if __name__ == "__main__":
|
833 |
+
app.run()
|
probability/21_logistic_regression.py
ADDED
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.1",
|
6 |
+
# "numpy==2.2.4",
|
7 |
+
# "drawdata==0.3.7",
|
8 |
+
# "scikit-learn==1.6.1",
|
9 |
+
# "polars==1.26.0",
|
10 |
+
# ]
|
11 |
+
# ///
|
12 |
+
|
13 |
+
import marimo
|
14 |
+
|
15 |
+
__generated_with = "0.12.5"
|
16 |
+
app = marimo.App(width="medium", app_title="Logistic Regression")
|
17 |
+
|
18 |
+
|
19 |
+
@app.cell(hide_code=True)
|
20 |
+
def _(mo):
|
21 |
+
mo.md(
|
22 |
+
r"""
|
23 |
+
# Logistic Regression
|
24 |
+
|
25 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part5/log_regression/), by Stanford professor Chris Piech._
|
26 |
+
|
27 |
+
Logistic regression learns a function approximating $P(y|x)$, and can be used to make a classifier. It makes the central assumption that $P(y|x)$ can be approximated as a sigmoid function applied to a linear combination of input features. It is particularly important to learn because logistic regression is the basic building block of artificial neural networks.
|
28 |
+
"""
|
29 |
+
)
|
30 |
+
return
|
31 |
+
|
32 |
+
|
33 |
+
@app.cell(hide_code=True)
|
34 |
+
def _(mo):
|
35 |
+
mo.md(
|
36 |
+
r"""
|
37 |
+
## The Binary Classification Problem
|
38 |
+
|
39 |
+
Imagine situations where we would like to know:
|
40 |
+
|
41 |
+
- The eligibility of getting a bank loan given the value of credit score ($x_{credit\_score}$) and monthly income ($x_{income}$)
|
42 |
+
- Identifying a tumor as benign or malignant given its size ($x_{tumor\_size}$)
|
43 |
+
- Classifying an email as promotional given the number of occurrences for some keywords like {'win', 'gift', 'discount'} ($x_{n\_win}$, $x_{n\_gift}$, $x_{n\_discount}$)
|
44 |
+
- Finding a monetary transaction as fraudulent given the time of occurrence ($x_{time\_stamp}$) and amount ($x_{amount}$)
|
45 |
+
|
46 |
+
These problems occur frequently in real life & can be dealt with machine learning. All such problems come under the umbrella of what is known as Classification. In each scenario, only one of the two possible outcomes can occur, hence these are specifically known as Binary Classification problems.
|
47 |
+
|
48 |
+
### How Does A Machine Perform Classification?
|
49 |
+
|
50 |
+
During the inference, the goal is to have the ML model predict the class label for a given set of feature values.
|
51 |
+
|
52 |
+
Specifically, a binary classification model estimates two probabilities $p_0$ & $p_1$ for 'class-0' and 'class-1' respectively where $p_0 + p_1 = 1$.
|
53 |
+
|
54 |
+
The predicted label depends on $\max(p_0, p_1)$ i.e., it's the one which is most probable based on the given features.
|
55 |
+
|
56 |
+
In logistic regression, $p_1$ (i.e., success probability) is compared with a predefined threshold $p$ to predict the class label like below:
|
57 |
+
|
58 |
+
$$\text{predicted class} =
|
59 |
+
\begin{cases}
|
60 |
+
1, & \text{if } p_1 \geq p \\
|
61 |
+
0, & \text{otherwise}
|
62 |
+
\end{cases}$$
|
63 |
+
|
64 |
+
To keep the notation simple and consistent, we will denote the success probability as $p$, and failure probability as $(1-p)$ instead of $p_1$ and $p_0$ respectively.
|
65 |
+
"""
|
66 |
+
)
|
67 |
+
return
|
68 |
+
|
69 |
+
|
70 |
+
@app.cell(hide_code=True)
|
71 |
+
def _(mo):
|
72 |
+
mo.md(
|
73 |
+
r"""
|
74 |
+
## Why NOT Linear Regression?
|
75 |
+
|
76 |
+
Can't we really use linear regression to address classification? The answer is NO! The key issue is that probabilities must be between 0 and 1 and linear regression can output any real number.
|
77 |
+
|
78 |
+
If we tried using linear regression directly:
|
79 |
+
$$p = \beta_0 + \beta_1 \cdot x_{feature}$$
|
80 |
+
|
81 |
+
This creates a problem: the right side can produce any value in $\mathbb{R}$ (all real numbers), but a probability $p$ must be confined to the range $(0,1)$.
|
82 |
+
|
83 |
+
Can we convert $(\beta_0 + \beta_1 \cdot x_{tumor\_size})$ to something belonging to $(0,1)$? That may work as an estimate of a probability! The answer is YES!
|
84 |
+
|
85 |
+
We need a converter (a function), say, $g()$ that will connect $p \in (0,1)$ to $(\beta_0 + \beta_1 \cdot x_{tumor\_size}) \in \mathbb{R}$.
|
86 |
+
|
87 |
+
The solution is to use a "link function" that maps from any real number to a valid probability range. This is where the sigmoid function comes in.
|
88 |
+
"""
|
89 |
+
)
|
90 |
+
return
|
91 |
+
|
92 |
+
|
93 |
+
@app.cell(hide_code=True)
|
94 |
+
def _(mo, np, plt):
|
95 |
+
# plot sigmoid to evidentiate above statements
|
96 |
+
_fig, ax = plt.subplots(figsize=(10, 6))
|
97 |
+
|
98 |
+
# x values
|
99 |
+
x = np.linspace(-10, 10, 1000)
|
100 |
+
|
101 |
+
# sigmoid formula
|
102 |
+
def sigmoid(z):
|
103 |
+
return 1 / (1 + np.exp(-z))
|
104 |
+
|
105 |
+
y = sigmoid(x)
|
106 |
+
|
107 |
+
# plot
|
108 |
+
ax.plot(x, y, 'b-', linewidth=2)
|
109 |
+
|
110 |
+
ax.axhline(y=0, color='k', linestyle='-', alpha=0.3)
|
111 |
+
ax.axhline(y=1, color='k', linestyle='-', alpha=0.3)
|
112 |
+
ax.axhline(y=0.5, color='r', linestyle='--', alpha=0.5)
|
113 |
+
|
114 |
+
# vertical line at x=0
|
115 |
+
ax.axvline(x=0, color='k', linestyle='-', alpha=0.3)
|
116 |
+
|
117 |
+
# annotations
|
118 |
+
ax.text(1, 0.85, r'$\sigma(z) = \frac{1}{1 + e^{-z}}$', fontsize=14)
|
119 |
+
ax.text(-9, 0.1, 'As z → -∞, σ(z) → 0', fontsize=12)
|
120 |
+
ax.text(3, 0.9, 'As z → ∞, σ(z) → 1', fontsize=12)
|
121 |
+
ax.text(0.5, 0.4, 'σ(0) = 0.5', fontsize=12)
|
122 |
+
|
123 |
+
# labels and title
|
124 |
+
ax.set_xlabel('z', fontsize=14)
|
125 |
+
ax.set_ylabel('σ(z)', fontsize=14)
|
126 |
+
ax.set_title('Sigmoid Function', fontsize=16)
|
127 |
+
|
128 |
+
# axis limits set
|
129 |
+
ax.set_xlim(-10, 10)
|
130 |
+
ax.set_ylim(-0.1, 1.1)
|
131 |
+
|
132 |
+
# grid
|
133 |
+
ax.grid(True, alpha=0.3)
|
134 |
+
|
135 |
+
mo.mpl.interactive(_fig)
|
136 |
+
return ax, sigmoid, x, y
|
137 |
+
|
138 |
+
|
139 |
+
@app.cell(hide_code=True)
|
140 |
+
def _(mo):
|
141 |
+
mo.md(
|
142 |
+
r"""
|
143 |
+
**Figure**: The sigmoid function maps any real number to a value between 0 and 1, making it perfect for representing probabilities.
|
144 |
+
|
145 |
+
/// note
|
146 |
+
For more information about the sigmoid function, head over to [this detailed notebook](http://marimo.app/https://github.com/marimo-team/deepml-notebooks/blob/main/problems/problem-22/notebook.py) for more insights.
|
147 |
+
///
|
148 |
+
"""
|
149 |
+
)
|
150 |
+
return
|
151 |
+
|
152 |
+
|
153 |
+
@app.cell(hide_code=True)
|
154 |
+
def _(mo):
|
155 |
+
mo.md(
|
156 |
+
r"""
|
157 |
+
## The Core Concept (math)
|
158 |
+
|
159 |
+
Logistic regression models the probability of class 1 using the sigmoid function:
|
160 |
+
|
161 |
+
$$P(Y=1|X=x) = \sigma(z) \text{ where } z = \theta_0 + \sum_{i=1}^m \theta_i x_i$$
|
162 |
+
|
163 |
+
The sigmoid function $\sigma(z)$ transforms any real number into a probability between 0 and 1:
|
164 |
+
|
165 |
+
$$\sigma(z) = \frac{1}{1+ e^{-z}}$$
|
166 |
+
|
167 |
+
This can be written more compactly using vector notation:
|
168 |
+
|
169 |
+
$$P(Y=1|\mathbf{X}=\mathbf{x}) =\sigma(\mathbf{\theta}^T\mathbf{x}) \quad \text{ where we always set $x_0$ to be 1}$$
|
170 |
+
|
171 |
+
$$P(Y=0|\mathbf{X}=\mathbf{x}) =1-\sigma(\mathbf{\theta}^T\mathbf{x}) \quad \text{ by total law of probability}$$
|
172 |
+
|
173 |
+
Where $\theta$ represents the model parameters that need to be learned from data, and $x$ is the feature vector (with $x_0=1$ to account for the intercept term).
|
174 |
+
|
175 |
+
> **Note:** For the detailed mathematical derivation of how these parameters are learned through Maximum Likelihood Estimation (MLE) and Gradient Descent (GD), please refer to [Chris Piech's original material](https://chrispiech.github.io/probabilityForComputerScientists/en/part5/log_regression/). The mathematical details are elegant but beyond the scope of this notebook topic (which is confined to Logistic Regression).
|
176 |
+
"""
|
177 |
+
)
|
178 |
+
return
|
179 |
+
|
180 |
+
|
181 |
+
@app.cell(hide_code=True)
|
182 |
+
def _(mo):
|
183 |
+
mo.md(
|
184 |
+
r"""
|
185 |
+
### Linear Decision Boundary
|
186 |
+
|
187 |
+
A key characteristic of logistic regression is that it creates a linear decision boundary. When the model predicts, it's effectively dividing the feature space with a straight line (in 2D) or hyperplane (in higher dimensions). It is actually a straight line (of the form $y = mx + c$).
|
188 |
+
|
189 |
+
Recall the prediction rule:
|
190 |
+
$$\text{predicted class} =
|
191 |
+
\begin{cases}
|
192 |
+
1, & \text{if } p \geq \theta_0 + \theta_1 \cdot x_{tumor\_size} \Rightarrow \log\frac{p}{1-p} \\
|
193 |
+
0, & \text{otherwise}
|
194 |
+
\end{cases}$$
|
195 |
+
|
196 |
+
For a two-feature model, the decision boundary where $P(Y=1|X=x) = 0.5$ occurs at:
|
197 |
+
$$\theta_0 + \theta_1 x_1 + \theta_2 x_2 = 0$$
|
198 |
+
|
199 |
+
A simple logistic regression predicts the class label by identifying the regions on either side of a straight line (or hyperplane in general), hence it's a _linear_ classifier.
|
200 |
+
|
201 |
+
This linear nature makes logistic regression effective for linearly separable classes but limited when dealing with more complex patterns.
|
202 |
+
"""
|
203 |
+
)
|
204 |
+
return
|
205 |
+
|
206 |
+
|
207 |
+
@app.cell(hide_code=True)
|
208 |
+
def _(mo):
|
209 |
+
mo.md("""### Visual: Linear Separability and Classification""")
|
210 |
+
return
|
211 |
+
|
212 |
+
|
213 |
+
@app.cell(hide_code=True)
|
214 |
+
def _(mo, np, plt):
|
215 |
+
# show relevant comparison to the above concepts/statements
|
216 |
+
|
217 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
218 |
+
|
219 |
+
# Linear separable data
|
220 |
+
np.random.seed(42)
|
221 |
+
X1 = np.random.randn(100, 2) - 2
|
222 |
+
X2 = np.random.randn(100, 2) + 2
|
223 |
+
|
224 |
+
ax1.scatter(X1[:, 0], X1[:, 1], color='blue', alpha=0.5)
|
225 |
+
ax1.scatter(X2[:, 0], X2[:, 1], color='red', alpha=0.5)
|
226 |
+
|
227 |
+
# Decision boundary (line)
|
228 |
+
ax1.plot([-5, 5], [5, -5], 'k--', linewidth=2)
|
229 |
+
ax1.set_xlim(-5, 5)
|
230 |
+
ax1.set_ylim(-5, 5)
|
231 |
+
ax1.set_title('Linearly Separable Classes')
|
232 |
+
|
233 |
+
# non-linear separable data
|
234 |
+
radius = 2
|
235 |
+
theta = np.linspace(0, 2*np.pi, 100)
|
236 |
+
|
237 |
+
# Outer circle points (class 1)
|
238 |
+
outer_x = 3 * np.cos(theta)
|
239 |
+
outer_y = 3 * np.sin(theta)
|
240 |
+
# Inner circle points (class 2)
|
241 |
+
inner_x = 1.5 * np.cos(theta) + np.random.randn(100) * 0.2
|
242 |
+
inner_y = 1.5 * np.sin(theta) + np.random.randn(100) * 0.2
|
243 |
+
|
244 |
+
ax2.scatter(outer_x, outer_y, color='blue', alpha=0.5)
|
245 |
+
ax2.scatter(inner_x, inner_y, color='red', alpha=0.5)
|
246 |
+
|
247 |
+
# Attempt to draw a linear boundary (which won't work well) proving the point
|
248 |
+
ax2.plot([-5, 5], [2, 2], 'k--', linewidth=2)
|
249 |
+
|
250 |
+
ax2.set_xlim(-5, 5)
|
251 |
+
ax2.set_ylim(-5, 5)
|
252 |
+
ax2.set_title('Non-Linearly Separable Classes')
|
253 |
+
|
254 |
+
fig.tight_layout()
|
255 |
+
mo.mpl.interactive(fig)
|
256 |
+
return (
|
257 |
+
X1,
|
258 |
+
X2,
|
259 |
+
ax1,
|
260 |
+
ax2,
|
261 |
+
fig,
|
262 |
+
inner_x,
|
263 |
+
inner_y,
|
264 |
+
outer_x,
|
265 |
+
outer_y,
|
266 |
+
radius,
|
267 |
+
theta,
|
268 |
+
)
|
269 |
+
|
270 |
+
|
271 |
+
@app.cell(hide_code=True)
|
272 |
+
def _(mo):
|
273 |
+
mo.md(r"""**Figure**: On the left, the classes are linearly separable as the boundary is a straight line. However, they are not linearly separable on the right, where no straight line can properly separate the two classes.""")
|
274 |
+
return
|
275 |
+
|
276 |
+
|
277 |
+
@app.cell(hide_code=True)
|
278 |
+
def _(mo):
|
279 |
+
mo.md(
|
280 |
+
r"""
|
281 |
+
Logistic regression is typically trained using MLE - finding the parameters $\theta$ that make our observed data most probable.
|
282 |
+
|
283 |
+
The optimization process generally uses GD (or its variants) to iteratively improve the parameters. The gradient has a surprisingly elegant form:
|
284 |
+
|
285 |
+
$$\frac{\partial LL(\theta)}{\partial \theta_j} = \sum_{i=1}^n \left[
|
286 |
+
y^{(i)} - \sigma(\theta^T x^{(i)})
|
287 |
+
\right] x_j^{(i)}$$
|
288 |
+
|
289 |
+
This shows that the update to each parameter depends on the prediction error (actual - predicted) multiplied by the feature value.
|
290 |
+
|
291 |
+
For those interested in the complete mathematical derivation, including log likelihood calculation and the detailed steps of GD (and relevant pseudocode followed for training), please see the [original lecture notes](https://chrispiech.github.io/probabilityForComputerScientists/en/part5/log_regression/).
|
292 |
+
"""
|
293 |
+
)
|
294 |
+
return
|
295 |
+
|
296 |
+
|
297 |
+
@app.cell(hide_code=True)
|
298 |
+
def _(controls, mo, widget):
|
299 |
+
# create the layout
|
300 |
+
mo.vstack([
|
301 |
+
mo.md("## Interactive drawing demo\nDraw points of two different classes and see how logistic regression separates them. _The interactive demo was adapted and improvised from [Vincent Warmerdam's](https://github.com/koaning) code [here](https://github.com/probabl-ai/youtube-appendix/blob/main/04-drawing-data/notebook.ipynb)_."),
|
302 |
+
controls,
|
303 |
+
widget
|
304 |
+
])
|
305 |
+
return
|
306 |
+
|
307 |
+
|
308 |
+
@app.cell(hide_code=True)
|
309 |
+
def _(LogisticRegression, mo, np, plt, run_button, widget):
|
310 |
+
warning_msg = mo.md(""" /// warning
|
311 |
+
Need more data, please draw points of at least two different colors in the scatter widget
|
312 |
+
""")
|
313 |
+
|
314 |
+
# mo.stop if button isn't clicked yet
|
315 |
+
mo.stop(
|
316 |
+
not run_button.value,
|
317 |
+
mo.md(""" /// tip
|
318 |
+
click 'Run Logistic Regression' to see the model
|
319 |
+
""")
|
320 |
+
)
|
321 |
+
|
322 |
+
# get data from widget (can also use as_pandas)
|
323 |
+
df = widget.data_as_polars
|
324 |
+
|
325 |
+
# display appropriate warning
|
326 |
+
mo.stop(
|
327 |
+
df.is_empty() or df['color'].n_unique() < 2,
|
328 |
+
warning_msg
|
329 |
+
)
|
330 |
+
|
331 |
+
# extract features and labels
|
332 |
+
X = df[['x', 'y']].to_numpy()
|
333 |
+
y_colors = df['color'].to_numpy()
|
334 |
+
|
335 |
+
# fit logistic regression model
|
336 |
+
model = LogisticRegression()
|
337 |
+
model.fit(X, y_colors)
|
338 |
+
|
339 |
+
# create grid for the viz
|
340 |
+
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
|
341 |
+
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
|
342 |
+
xx, yy = np.meshgrid(
|
343 |
+
np.linspace(x_min, x_max, 100),
|
344 |
+
np.linspace(y_min, y_max, 100)
|
345 |
+
)
|
346 |
+
|
347 |
+
# get probability predictions
|
348 |
+
Z = model.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1]
|
349 |
+
Z = Z.reshape(xx.shape)
|
350 |
+
|
351 |
+
# create figure
|
352 |
+
_fig, ax_fig = plt.subplots(figsize=(12, 8))
|
353 |
+
|
354 |
+
# plot decision boundary (probability contours)
|
355 |
+
contour = ax_fig.contourf(
|
356 |
+
xx, yy, Z,
|
357 |
+
levels=np.linspace(0, 1, 11),
|
358 |
+
alpha=0.7,
|
359 |
+
cmap="RdBu_r"
|
360 |
+
)
|
361 |
+
|
362 |
+
# plot decision boundary line (probability = 0.5)
|
363 |
+
ax_fig.contour(
|
364 |
+
xx, yy, Z,
|
365 |
+
levels=[0.5],
|
366 |
+
colors='k',
|
367 |
+
linewidths=2
|
368 |
+
)
|
369 |
+
|
370 |
+
# plot the data points (use same colors as in the widget)
|
371 |
+
ax_fig.scatter(X[:, 0], X[:, 1], c=y_colors, edgecolor='k', s=80)
|
372 |
+
|
373 |
+
# colorbar
|
374 |
+
plt.colorbar(contour, ax=ax_fig)
|
375 |
+
|
376 |
+
# labels and title
|
377 |
+
ax_fig.set_xlabel('x')
|
378 |
+
ax_fig.set_ylabel('y')
|
379 |
+
ax_fig.set_title('Logistic Regression')
|
380 |
+
|
381 |
+
# model params
|
382 |
+
coef = model.coef_[0]
|
383 |
+
intercept = model.intercept_[0]
|
384 |
+
equation = f"log(p/(1-p)) = {intercept:.2f} + {coef[0]:.3f}x₁ + {coef[1]:.3f}x₂"
|
385 |
+
|
386 |
+
# relevant info in regards to regression
|
387 |
+
model_info = mo.md(f"""
|
388 |
+
### Logistic regression model
|
389 |
+
|
390 |
+
**Equation**: {equation}
|
391 |
+
|
392 |
+
**Decision boundary**: probability = 0.5
|
393 |
+
|
394 |
+
**Accuracy**: {model.score(X, y_colors):.2f}
|
395 |
+
""")
|
396 |
+
|
397 |
+
# show results vertically stacked
|
398 |
+
mo.vstack([
|
399 |
+
mo.mpl.interactive(_fig),
|
400 |
+
model_info
|
401 |
+
])
|
402 |
+
return (
|
403 |
+
X,
|
404 |
+
Z,
|
405 |
+
ax_fig,
|
406 |
+
coef,
|
407 |
+
contour,
|
408 |
+
df,
|
409 |
+
equation,
|
410 |
+
intercept,
|
411 |
+
model,
|
412 |
+
model_info,
|
413 |
+
warning_msg,
|
414 |
+
x_max,
|
415 |
+
x_min,
|
416 |
+
xx,
|
417 |
+
y_colors,
|
418 |
+
y_max,
|
419 |
+
y_min,
|
420 |
+
yy,
|
421 |
+
)
|
422 |
+
|
423 |
+
|
424 |
+
@app.cell(hide_code=True)
|
425 |
+
def _(mo):
|
426 |
+
mo.md(
|
427 |
+
r"""
|
428 |
+
## 🤔 Key Takeaways
|
429 |
+
|
430 |
+
Click on the statements below that you think are correct to verify your understanding:
|
431 |
+
|
432 |
+
/// details | Logistic regression tries to find parameters (θ) that minimize the error between predicted and actual values using ordinary least squares.
|
433 |
+
❌ **Incorrect.** Logistic regression uses maximum likelihood estimation (MLE), not ordinary least squares. It finds parameters that maximize the probability of observing the training data, which is different from minimizing squared errors as in linear regression.
|
434 |
+
///
|
435 |
+
|
436 |
+
/// details | The sigmoid function maps any real number to a value between 0 and 1, which allows logistic regression to output probabilities.
|
437 |
+
✅ **Correct!** The sigmoid function σ(z) = 1/(1+e^(-z)) takes any real number as input and outputs a value between 0 and 1. This is perfect for representing probabilities and is a key component of logistic regression.
|
438 |
+
///
|
439 |
+
|
440 |
+
/// details | The decision boundary in logistic regression is always a straight line, regardless of the data's complexity.
|
441 |
+
✅ **Correct!** Standard logistic regression produces a linear decision boundary (a straight line in 2D or a hyperplane in higher dimensions). This is why it works well for linearly separable data but struggles with more complex patterns, like concentric circles (as you might've noticed from the interactive demo).
|
442 |
+
///
|
443 |
+
|
444 |
+
/// details | The logistic regression model params are typically initialized to random values and refined through gradient descent.
|
445 |
+
✅ **Correct!** Parameters are often initialized to zeros or small random values, then updated iteratively using gradient descent (or ascent for maximizing likelihood) until convergence.
|
446 |
+
///
|
447 |
+
|
448 |
+
/// details | Logistic regression can naturally handle multi-class classification problems without any modifications.
|
449 |
+
❌ **Incorrect.** Standard logistic regression is inherently a binary classifier. To handle multi-class classification, techniques like one-vs-rest or softmax regression are typically used.
|
450 |
+
///
|
451 |
+
"""
|
452 |
+
)
|
453 |
+
return
|
454 |
+
|
455 |
+
|
456 |
+
@app.cell(hide_code=True)
|
457 |
+
def _(mo):
|
458 |
+
mo.md(
|
459 |
+
r"""
|
460 |
+
## Summary
|
461 |
+
|
462 |
+
So we've just explored logistic regression. Despite its name (seriously though, why not call it "logistic classification"?), it's actually quite elegant in how it transforms a simple linear model into a powerful decision _boundary_ maker.
|
463 |
+
|
464 |
+
The training process boils down to finding the values of θ that maximize the likelihood of seeing our training data. What's super cool is that even though the math looks _scary_ at first, the gradient has this surprisingly simple form: just the error (y - predicted) multiplied by the feature values.
|
465 |
+
|
466 |
+
Two key insights to remember:
|
467 |
+
|
468 |
+
- Logistic regression creates a _linear_ decision boundary, so it works great for linearly separable classes but struggles with more _complex_ patterns
|
469 |
+
- It directly gives you probabilities, not just classifications, which is incredibly useful when you need confidence measures
|
470 |
+
"""
|
471 |
+
)
|
472 |
+
return
|
473 |
+
|
474 |
+
|
475 |
+
@app.cell(hide_code=True)
|
476 |
+
def _(mo):
|
477 |
+
mo.md(
|
478 |
+
r"""
|
479 |
+
Additional resources referred to:
|
480 |
+
|
481 |
+
- [Logistic Regression Tutorial by _Koushik Khan_](https://koushikkhan.github.io/resources/pdf/tutorials/logistic_regression_tutorial.pdf)
|
482 |
+
"""
|
483 |
+
)
|
484 |
+
return
|
485 |
+
|
486 |
+
|
487 |
+
@app.cell(hide_code=True)
|
488 |
+
def _(mo):
|
489 |
+
mo.md(r"""Appendix (helper code)""")
|
490 |
+
return
|
491 |
+
|
492 |
+
|
493 |
+
@app.cell
|
494 |
+
def _():
|
495 |
+
import marimo as mo
|
496 |
+
return (mo,)
|
497 |
+
|
498 |
+
|
499 |
+
@app.cell
|
500 |
+
def init_imports():
|
501 |
+
# imports for our notebook
|
502 |
+
import numpy as np
|
503 |
+
import matplotlib.pyplot as plt
|
504 |
+
from drawdata import ScatterWidget
|
505 |
+
from sklearn.linear_model import LogisticRegression
|
506 |
+
|
507 |
+
|
508 |
+
# for consistent results
|
509 |
+
np.random.seed(42)
|
510 |
+
|
511 |
+
# nicer plots
|
512 |
+
plt.style.use('seaborn-v0_8-darkgrid')
|
513 |
+
return LogisticRegression, ScatterWidget, np, plt
|
514 |
+
|
515 |
+
|
516 |
+
@app.cell(hide_code=True)
|
517 |
+
def _(ScatterWidget, mo):
|
518 |
+
# drawing widget
|
519 |
+
widget = mo.ui.anywidget(ScatterWidget())
|
520 |
+
|
521 |
+
# run_button to run model
|
522 |
+
run_button = mo.ui.run_button(label="Run Logistic Regression", kind="success")
|
523 |
+
|
524 |
+
# stack controls
|
525 |
+
controls = mo.hstack([run_button])
|
526 |
+
return controls, run_button, widget
|
527 |
+
|
528 |
+
|
529 |
+
if __name__ == "__main__":
|
530 |
+
app.run()
|