etrotta commited on
Commit
192b61d
·
1 Parent(s): cbef791

Add data exploration viz draft

Browse files
Files changed (2) hide show
  1. .gitignore +4 -1
  2. polars/05_reactive_plots.py +419 -0
.gitignore CHANGED
@@ -168,4 +168,7 @@ cython_debug/
168
  #.idea/
169
 
170
  # PyPI configuration file
171
- .pypirc
 
 
 
 
168
  #.idea/
169
 
170
  # PyPI configuration file
171
+ .pypirc
172
+
173
+ # Marimo specific
174
+ __marimo__
polars/05_reactive_plots.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "numpy==2.2.3",
6
+ # "plotly[express]==6.0.0",
7
+ # "polars==1.23.0",
8
+ # "statsmodels==0.14.4",
9
+ # ]
10
+ # ///
11
+
12
+ import marimo
13
+
14
+ __generated_with = "0.11.13"
15
+ app = marimo.App(width="medium")
16
+
17
+
18
+ @app.cell
19
+ def _():
20
+ import marimo as mo
21
+ import polars as pl
22
+ import plotly.express as px
23
+ return mo, pl, px
24
+
25
+
26
+ @app.cell(hide_code=True)
27
+ def _(mo):
28
+ mo.md(
29
+ """
30
+ For this tutorial, we will be using a [Spotify Tracks dataset](https://huggingface.co/datasets/maharshipandya/spotify-tracks-dataset).
31
+
32
+ Note that it does not contain data about ***all*** tracks; you can try using a larger dataset such as [bigdata-pw/Spotify](https://huggingface.co/datasets/bigdata-pw/Spotify), but I'm sticking with the smaller one to keep the notebook size manageable for most users.
33
+
34
+ You should always take a look at the data you are working on before actually doing any operations on it - for data coming from sources such as HuggingFace or Kaggle you may want to look in their websites, then filter or do some transformations before downloading.
35
+
36
+ Let's say that looking at it in the Data Viewer, we decided we do not want the Unnamed column (which appears to be the row index), nor do we care about the original ID, and we only want non-explicit tracks.
37
+ """
38
+ )
39
+ return
40
+
41
+
42
@app.cell
def _(pl):
    # Build the Hugging Face Hub URL (repository, revision and file) of the Parquet file to scan.
    repo_id, branch, file_path = (
        "maharshipandya/spotify-tracks-dataset",
        "~parquet",
        "default/train/0000.parquet",
    )
    URL = f"hf://datasets/{repo_id}@{branch}/{file_path}"
    lz = pl.scan_parquet(URL)
    df = (
        lz
        # Filter data we consider relevant (somewhat arbitrary in this example):
        # keep only the non-explicit tracks. Negating the boolean column is the
        # idiomatic form of `== False`.
        .filter(~pl.col("explicit"))
        .drop("Unnamed: 0", "track_id", "explicit")
        .with_columns(
            # Some random transformations for example,
            # Transform a String column with few unique values into Categorical to occupy less memory
            pl.col("track_genre").cast(pl.Categorical()),
            # Convert the duration from milliseconds to seconds (int)
            pl.col("duration_ms").floordiv(1_000).alias("duration_seconds"),
            # Convert the popularity from an integer 0 ~ 100 to a percentage 0 ~ 1.0
            pl.col("popularity").truediv(100),
        )
        # lastly, download and collect into memory
        .collect()
    )
    df
    return URL, branch, df, file_path, lz, repo_id
70
+
71
+
72
+ @app.cell(hide_code=True)
73
+ def _(mo):
74
+ mo.md(
75
+ r"""
76
+ We may want to start by investigating any values that seem weird, to verify if there could be issues in the data, bugs in our pipelines, or if our understanding of it is wrong.
77
+
78
+ For example, the "min" value for the duration column is zero, and the max is over an hour. Why is that?
79
+ """
80
+ )
81
+ return
82
+
83
+
84
@app.cell(disabled=True)
def _(df, pl):
    # We *could* just filter some of the rows and look at them as a table, for example
    # the five shortest and the five longest tracks...
    _shortest = df.sort("duration_ms").head(5)
    _longest = df.sort("duration_ms", descending=True).head(5)
    pl.concat([_shortest, _longest])
    # But creating a visualisation for this helps paint the full picture of how the data is distributed, rather than focusing *only* on some outliers
    return
90
+
91
+
92
@app.cell
def _(df, mo, pl, px):
    # Let's visualize it and get a feel for which region makes sense to focus on for our analysis
    duration_counts = df.group_by("duration_seconds").len("count")
    # Height of the tallest bar, so the default selection box spans the full plot height.
    _tallest_bar = duration_counts.select(pl.max("count")).item()

    fig = px.bar(duration_counts, x="duration_seconds", y="count")
    # Only allow horizontal selections: we are picking a range of durations.
    fig.update_layout(selectdirection="h")
    # Pre-draw a selection covering 120 to 360 seconds.
    fig.add_selection(x0=120, y0=0, x1=360, y1=_tallest_bar)

    plot = mo.ui.plotly(fig)
    plot
    return duration_counts, fig, plot
102
+
103
+
104
+ @app.cell(hide_code=True)
105
+ def _(mo):
106
+ mo.md(
107
+ """
108
+ The previous cell set a default, but you can and should try moving it around a bit.
109
+
110
+ Note how there are a few outliers with extremely short duration (less than 2 minutes) and a few with extremely long duration (more than 6 minutes)
111
+
112
+ We will focus on those within that middle ground from around 120 seconds to 360 seconds, but you can play around with it a bit and see how the results change if you move the Selection region. Perhaps you can even find some Classical songs?
113
+ """
114
+ )
115
+ return
116
+
117
+
118
+ @app.cell
119
+ def _(pl, plot):
120
+ # We can see our selection and use it as a filter:
121
+ pl.DataFrame(plot.value)
122
+ return
123
+
124
+
125
@app.cell
def _(df, pl, plot):
    # Durations (in seconds) of the bars inside the current plot selection.
    _selected_durations = [row["duration_seconds"] for row in plot.value]
    # With nothing selected the list is empty and a bare min()/max() would raise
    # ValueError, so fall back to the default 120s-360s window drawn on the plot.
    min_dur, max_dur = (
        min(_selected_durations, default=120),
        max(_selected_durations, default=360),
    )

    # Calculate how many we are keeping vs throwing away with the filter
    duration_in_range = pl.col("duration_seconds").is_between(min_dur, max_dur)
    print(
        f"Filtering to keep rows between {min_dur}s and {max_dur}s duration - Throwing away {df.select(1 - duration_in_range.mean()).item():.2%} of the rows"
    )

    # Actually filter
    filtered_duration = df.filter(duration_in_range)
    filtered_duration
    return duration_in_range, filtered_duration, max_dur, min_dur
142
+
143
+
144
+ @app.cell(hide_code=True)
145
+ def _(mo):
146
+ mo.md(
147
+ r"""
148
+ Now that our data is clean, let's start running some more analyses on it. Some example questions:
149
+
150
+ - Which tracks or artists are the most popular? (Both globally as well as for each genre)
151
+ - Which genres are the most popular? The loudest?
152
+ - What are some common combinations of different artists?
153
+ - Can we infer anything based on the track's title or artist name?
154
+ - How popular is some specific song you like?
155
+ - How much does the mode and key affect other attributes?
156
+ - Can you classify a song's genre based on its attributes?
157
+
158
+ For brevity, we will not explore all of them - feel free to try some of the others yourself, or go deeper into the ones we do explore.
159
+ """
160
+ )
161
+ return
162
+
163
+
164
@app.cell
def _(filter_genre, filtered_duration, mo, pl):
    # Now, if you saw the Dataset description or looked closely at the Artists column you may notice
    # there are some rows with multiple artists separated by ";". We will have to separate each of these.
    most_popular_artists = (
        filtered_duration.lazy()
        .with_columns(pl.col("artists").str.split(";"))
        # Spoiler for the next cell! Remember that in marimo you can do things 'out of order'
        .filter(True if filter_genre.value is None else pl.col("track_genre").eq(filter_genre.value))
        .explode("artists")
        .group_by("artists")
        .agg(
            # Now, how we aggregate it is also a question.
            # Do we take the sum of each of their songs popularity?
            # Do we just take their most popular song?
            # Do we take an average of their songs popularity?
            # We'll proceed with the average of their top 10 most popular songs for now,
            # but that is something you may want to modify and experiment with.
            pl.col("popularity").top_k(10).mean(),
            # Let's also take some of their most popular songs and albums for reference.
            # Order by popularity (highest first), drop repeated names while keeping that
            # order, then keep the first five. `top_k` on the names themselves would rank
            # the strings alphabetically instead of by popularity.
            pl.col("track_name").sort_by("popularity", descending=True).unique(maintain_order=True).head(5),
            pl.col("album_name").sort_by("popularity", descending=True).unique(maintain_order=True).head(5),
            pl.col("track_genre").top_k_by("popularity", k=1).alias("Most popular genre"),
            # And for good measure, see how many total tracks they have
            pl.col("track_name").n_unique().alias("tracks_count"),
        )
        .collect()
    )
    mo.md("Let's start with the Most popular artists")
    return (most_popular_artists,)
193
+
194
+
195
@app.cell
def _(most_popular_artists, pl):
    # Columns holding lists of strings are shown one value per line to keep the table readable.
    _string_list_columns = pl.col(pl.List(pl.String()))
    most_popular_artists.with_columns(_string_list_columns.list.join("\n")).sort("popularity", descending=True)
    return
200
+
201
+
202
@app.cell
def _(filtered_duration, mo):
    # Recognize any of your favourite songs? Me neither. Let's try adding a filter by genre
    _genre_options = filtered_duration["track_genre"].unique().sort().to_list()
    filter_genre = mo.ui.dropdown(
        options=_genre_options,
        allow_select_none=True,
        value=None,
        searchable=True,
        label="Filter by Track Genre:",
    )
    filter_genre
    return (filter_genre,)
208
+
209
+
210
+ @app.cell(hide_code=True)
211
+ def _(mo):
212
+ mo.md(
213
+ r"""
214
+ So far so good - but there's been a distinct lack of visualisations, so let's fix that.
215
+
216
+ Let's start simple, just some metrics for each genre:
217
+ """
218
+ )
219
+ return
220
+
221
+
222
@app.cell
def _(filtered_duration, pl, px):
    # One row per genre: mean duration and mean popularity, rounded to two decimals.
    _genre_means = (
        filtered_duration.group_by("track_genre")
        .agg(pl.col("duration_seconds", "popularity").mean().round(2))
        .sort("track_genre", descending=True)
    )
    fig_dur_per_genre = px.scatter(
        _genre_means,
        hover_name="track_genre",
        y="duration_seconds",
        x="popularity",
    )
    fig_dur_per_genre
    return (fig_dur_per_genre,)
234
+
235
+
236
+ @app.cell(hide_code=True)
237
+ def _(mo):
238
+ mo.md(
239
+ r"""
240
+ Now, why don't we play a bit with marimo's UI elements?
241
+
242
+ We will use Dropdowns to allow for the user to select any column to use for the visualisation, and throw in some extras
243
+
244
+ - A slider for the transparency to help understand dense clusters
245
+ - Add a Trendline to the scatterplot (requires statsmodels)
246
+ - Filter by some specific Genre
247
+ """
248
+ )
249
+ return
250
+
251
+
252
@app.cell
def _(filtered_duration, mo):
    # Let's start by making some comparisons, scatter plots are a nice way to get a feel for how dependent a variable is on another
    # Columns offered for the axes and the colour scale.
    options = [
        "duration_seconds",
        "popularity",
        "danceability",
        "energy",
        "key",
        "loudness",
        "mode",
        "speechiness",
        "acousticness",
        "instrumentalness",
        "liveness",
        "valence",
        "tempo",
    ]
    _genre_options = filtered_duration["track_genre"].unique().sort().to_list()

    x_axis = mo.ui.dropdown(options, value="energy", label="X")
    y_axis = mo.ui.dropdown(options, value="danceability", label="Y")
    color = mo.ui.dropdown(
        options,
        value="loudness",
        allow_select_none=True,
        searchable=True,
        label="Color column",
    )
    alpha = mo.ui.slider(
        start=0.01,
        stop=1.0,
        step=0.01,
        value=0.1,
        label="Alpha",
        show_value=True,
    )
    include_trendline = mo.ui.checkbox(label="Trendline")
    # We *could* reuse the same filter_genre as above, but it would cause marimo to rerun both the table and the graph whenever we change either
    filter_genre2 = mo.ui.dropdown(
        options=_genre_options,
        allow_select_none=True,
        value=None,
        searchable=True,
        label="Filter by Track Genre:",
    )
    x_axis, y_axis, color, alpha, include_trendline, filter_genre2
    return (
        alpha,
        color,
        filter_genre2,
        include_trendline,
        options,
        x_axis,
        y_axis,
    )
287
+
288
+
289
@app.cell
def _(
    alpha,
    color,
    filter_genre2,
    filtered_duration,
    include_trendline,
    mo,
    pl,
    px,
    x_axis,
    y_axis,
):
    # Restrict to the chosen genre, or keep every row when no genre is selected.
    _genre = filter_genre2.value
    _rows = filtered_duration.filter(True if _genre is None else (pl.col("track_genre") == _genre))
    _trendline = "lowess" if include_trendline.value else None

    fig2 = px.scatter(
        _rows,
        x=x_axis.value,
        y=y_axis.value,
        color=color.value,
        opacity=alpha.value,
        trendline=_trendline,
    )
    # Wrap the figure so that selections made on it are readable from `chart2.value`.
    chart2 = mo.ui.plotly(fig2)
    chart2
    return chart2, fig2
313
+
314
+
315
+ @app.cell(hide_code=True)
316
+ def _(mo):
317
+ mo.md(
318
+ r"""
319
+ As we have seen before, we can also use the plot as an input to select a region and look at it in more detail.
320
+
321
+ Try selecting a region then performing some explorations of your own with the data inside of it.
322
+ """
323
+ )
324
+ return
325
+
326
+
327
@app.cell
def _(chart2, filtered_duration, mo, pl):
    # Let's look at which sort of songs were included in that region
    if chart2.value:
        # The keys of a selected point are the columns currently plotted.
        active_columns = list(chart2.value[0].keys())
        # Show the track name first, then the plotted columns, then album and artists,
        # followed by every remaining column.
        column_order = ["track_name", *active_columns, "album_name", "artists"]
        _selected_points = pl.DataFrame(chart2.value).unique()
        out = filtered_duration.join(_selected_points, on=active_columns).select(
            pl.col(column_order),
            pl.exclude(*column_order),
        )
    else:
        out = mo.md("No data found in selection")
    out
    return active_columns, column_order, out
338
+
339
+
340
+ @app.cell(hide_code=True)
341
+ def _():
342
+ # Appendix : Some other examples
343
+ return
344
+
345
+
346
@app.cell
def _(mo):
    # Free-form text inputs used to search for a specific artist and/or track
    filter_artist, filter_track = (
        mo.ui.text(label="Artist: "),
        mo.ui.text(label="Track name: "),
    )
    return filter_artist, filter_track
352
+
353
+
354
@app.cell(disabled=True)
def _(filtered_duration, mo, pl):
    # Note that we cannot use dropdown due to the sheer number of elements being enormous:
    all_artists = filtered_duration.select(pl.col("artists").str.split(";").explode().unique().sort())["artists"].to_list()
    all_tracks = filtered_duration["track_name"].unique().sort().to_list()
    # These use their own names: `filter_artist` / `filter_track` are already defined by the
    # text-box cell, and marimo does not allow two cells to define the same global name.
    filter_artist_dropdown = mo.ui.dropdown(all_artists, value=None, searchable=True)
    filter_track_dropdown = mo.ui.dropdown(all_tracks, value=None, searchable=True)
    # So we just provide freeform text boxes and filter ourselves later
    return all_artists, all_tracks, filter_artist_dropdown, filter_track_dropdown
363
+
364
+
365
@app.cell
def _(filter_artist, filter_track, filtered_duration, mo, pl):
    def score_match_text(col: pl.Expr, string: str | None) -> pl.Expr:
        """Build an expression scoring how well `col` matches the search text.

        An empty or missing search text scores 0. Otherwise the score is the
        negated character length of the value, plus 50 when the value contains
        the search text and another 50 when it starts with it. Both sides are
        lower-cased before comparing.
        """
        if not string:
            return pl.lit(0)
        col = col.str.to_lowercase()
        # Lower-case (rather than casefold) the query so it is normalised the same way as the column.
        string = string.lower()
        return (
            # For a more professional use case, you might want to look into string distance functions
            # in the polars-ds package or other polars plugins
            -col.str.len_chars().cast(pl.Int32())
            # literal=True: the query is user-typed text, not a regular expression, so
            # characters such as "(" or "." must be matched as-is.
            + pl.when(col.str.contains(string, literal=True)).then(50).otherwise(0)
            + pl.when(col.str.starts_with(string)).then(50).otherwise(0)
        )

    filtered_artist_track = (
        filtered_duration.select(
            pl.col("artists"),
            pl.col("track_name"),
            (
                score_match_text(pl.col("track_name"), filter_track.value)
                # Score every artist of the track individually and add the scores up.
                + pl.col("artists").str.split(";").list.eval(score_match_text(pl.element(), filter_artist.value)).list.sum()
            ).alias("match_score"),
            pl.col("album_name"),
            pl.col("track_genre"),
            pl.col("popularity"),
            pl.col("duration_seconds"),
        )
        # Keep only rows with a positive score, best matches first.
        .filter(pl.col("match_score") > 0)
        .sort("match_score", descending=True)
    )

    mo.md("Filter a track based on its name or artist"), filter_artist, filter_track, filtered_artist_track
    return filtered_artist_track, score_match_text
393
+
394
+
395
@app.cell
def _(filter_genre2, filtered_duration, mo, pl):
    # Artists combinations
    _genre = filter_genre2.value
    artist_combinations = (
        filtered_duration.lazy()
        # Restrict to the chosen genre, or keep every row when no genre is selected.
        .filter(True if _genre is None else (pl.col("track_genre") == _genre))
        .with_columns(pl.col("artists").str.split(";"))
        # Duplicate the artist list so that exploding both columns yields every pair of artists on a track.
        .with_columns(pl.col("artists").alias("other_artist"))
        .explode("artists")
        .explode("other_artist")
        # Filter to:
        # 1) Remove an artist with themselves
        # 2) Remove duplicate combinations, otherwise we would have one row for (A, B) and one for (B, A)
        .filter(pl.col("artists") > pl.col("other_artist"))
        .group_by("artists", "other_artist")
        .len("count")
        .collect()
    )
    mo.md("Check which artists collaborate with others most often (reuses the last genre filter)"), filter_genre2, artist_combinations.sort("count", descending=True)
    return (artist_combinations,)
416
+
417
+
418
+ if __name__ == "__main__":
419
+ app.run()