Spaces:
Runtime error
Runtime error
Add range in heatmap
Browse files
- pages/clustering.py +53 -30
- recommender.py +1 -1
- recommender_system.py +0 -1
pages/clustering.py
CHANGED
@@ -7,7 +7,7 @@ import altair as alt
|
|
7 |
from sklearn.mixture import GaussianMixture
|
8 |
import plotly.express as px
|
9 |
import itertools
|
10 |
-
from typing import Dict, List
|
11 |
|
12 |
|
13 |
SIDEBAR_DESCRIPTION = """
|
@@ -15,10 +15,10 @@ SIDEBAR_DESCRIPTION = """
|
|
15 |
|
16 |
To cluster a client, we adopt the RFM metrics. They stand for:
|
17 |
|
18 |
-
- R = recency, that is the number of days since the last purchase
|
19 |
in the store
|
20 |
- F = frequency, that is the number of times a customer has ordered something
|
21 |
-
- M = monetary value, that is how much a customer has spent buying
|
22 |
from your business.
|
23 |
|
24 |
Given these 3 metrics, we can cluster the customers and find a suitable
|
@@ -28,8 +28,8 @@ we're using right now has about 5000 distinct customers, we identify
|
|
28 |
|
29 |
## How we compute the clusters
|
30 |
|
31 |
-
We resort to a GaussianMixture algorithm. We can think of GaussianMixture
|
32 |
-
as generalized k-means clustering that incorporates information about
|
33 |
the covariance structure of the data as well as the centers of the clusters.
|
34 |
""".lstrip()
|
35 |
|
@@ -46,7 +46,7 @@ There 3 available clusters for this metric:
|
|
46 |
""".lstrip()
|
47 |
|
48 |
RECENCY_CLUSTERS_EXPLAIN = """
|
49 |
-
The **recency** refers to how recently a customer has bought;
|
50 |
|
51 |
There 3 available clusters for this metric:
|
52 |
|
@@ -58,7 +58,7 @@ There 3 available clusters for this metric:
|
|
58 |
""".lstrip()
|
59 |
|
60 |
MONETARY_CLUSTERS_EXPLAIN = """
|
61 |
-
The **revenue** refers to how much a customer has spent buying
|
62 |
from your business.
|
63 |
|
64 |
There 3 available clusters for this metric:
|
@@ -115,7 +115,7 @@ def cluster_clients(df: pd.DataFrame):
|
|
115 |
|
116 |
|
117 |
def _order_cluster(cluster_model: GaussianMixture, clusters, order="ascending"):
|
118 |
-
"""Orders the cluster by order
|
119 |
centroids = cluster_model.means_.sum(axis=1)
|
120 |
|
121 |
if order.lower() == "descending":
|
@@ -191,7 +191,10 @@ def explain_cluster(cluster_info):
|
|
191 |
" and values"
|
192 |
)
|
193 |
for cluster, info in cluster_info.items():
|
194 |
-
|
|
|
|
|
|
|
195 |
|
196 |
|
197 |
def categorize_user(recency_cluster, frequency_cluster, monetary_cluster):
|
@@ -231,7 +234,9 @@ def categorize_user(recency_cluster, frequency_cluster, monetary_cluster):
|
|
231 |
st.write(f"The customer can be described as: **{description}**")
|
232 |
|
233 |
|
234 |
-
def plot_rfm_distribution(
|
|
|
|
|
235 |
"""Plots 3 histograms for the RFM metrics."""
|
236 |
|
237 |
for x, to_reverse in zip(("Revenue", "Frequency", "Recency"), (False, False, True)):
|
@@ -241,20 +246,21 @@ def plot_rfm_distribution(df_rfm: pd.DataFrame, cluster_info: Dict[str, List[int
|
|
241 |
log_y=True,
|
242 |
title=f"{x} metric",
|
243 |
)
|
244 |
-
# Get the max value in the cluster info. The
|
245 |
-
# values
|
246 |
-
values
|
|
|
247 |
print(values)
|
248 |
# Add vertical bar on each cluster end. But skip the last cluster.
|
249 |
-
loop_range =
|
250 |
if to_reverse:
|
251 |
-
#
|
252 |
-
loop_range =
|
253 |
-
for n_cluster
|
254 |
print(x)
|
255 |
-
print(values[
|
256 |
fig.add_vline(
|
257 |
-
x=values[
|
258 |
annotation_text=f"End of cluster {n_cluster+1}",
|
259 |
line_dash="dot",
|
260 |
annotation=dict(textangle=90, font_color="red"),
|
@@ -267,13 +273,20 @@ def plot_rfm_distribution(df_rfm: pd.DataFrame, cluster_info: Dict[str, List[int
|
|
267 |
st.plotly_chart(fig)
|
268 |
|
269 |
|
270 |
-
def display_dataframe_heatmap(df_rfm: pd.DataFrame):
|
271 |
"""Displays an heatmap of how many clients lay in the clusters.
|
272 |
|
273 |
This method uses some black magic coming from the dataframe
|
274 |
styling guide.
|
275 |
"""
|
276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
277 |
# Create a dataframe with the count of clients for each group
|
278 |
# of cluster.
|
279 |
|
@@ -291,6 +304,13 @@ def display_dataframe_heatmap(df_rfm: pd.DataFrame):
|
|
291 |
["Revenue_cluster", "Frequency_cluster", "Recency_cluster"]
|
292 |
)
|
293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
# Use the count column as values, then index with the clusters.
|
295 |
count = count.pivot(
|
296 |
index=["Revenue_cluster", "Frequency_cluster"],
|
@@ -301,15 +321,15 @@ def display_dataframe_heatmap(df_rfm: pd.DataFrame):
|
|
301 |
# Style manipulation
|
302 |
cell_hover = {
|
303 |
"selector": "td",
|
304 |
-
"props": "font-size:1.
|
305 |
}
|
306 |
index_names = {
|
307 |
"selector": ".index_name",
|
308 |
-
"props": "font-style: italic; color: Black; font-weight:normal;font-size:1.
|
309 |
}
|
310 |
headers = {
|
311 |
"selector": "th:not(.index_name)",
|
312 |
-
"props": "background-color: White; color: black; font-size:1.
|
313 |
}
|
314 |
|
315 |
# Finally, display
|
@@ -336,7 +356,7 @@ def main():
|
|
336 |
"# Dataset "
|
337 |
"\nThis is the processed dataset with information about the clients, such as"
|
338 |
" the RFM values and the clusters they belong to."
|
339 |
-
|
340 |
st.dataframe(df_rfm.style.format(formatter={"Revenue": "{:.2f}"}))
|
341 |
|
342 |
cluster_info_dict = defaultdict(list)
|
@@ -351,15 +371,14 @@ def main():
|
|
351 |
)
|
352 |
min_cluster = cluster_info["min"].astype(int)
|
353 |
max_cluster = cluster_info["max"].astype(int)
|
354 |
-
|
355 |
-
cluster_info_dict[cluster].extend(min_max_interlieved)
|
356 |
st.dataframe(cluster_info)
|
357 |
|
358 |
st.markdown("## RFM metric distribution")
|
359 |
|
360 |
plot_rfm_distribution(df_rfm, cluster_info_dict)
|
361 |
|
362 |
-
display_dataframe_heatmap(df_rfm)
|
363 |
|
364 |
st.markdown("## Interactive exploration")
|
365 |
|
@@ -369,9 +388,13 @@ def main():
|
|
369 |
)
|
370 |
|
371 |
client_to_select = (
|
372 |
-
df_rfm.groupby(["Recency_cluster", "Frequency_cluster", "Revenue_cluster"])[
|
373 |
-
|
374 |
-
|
|
|
|
|
|
|
|
|
375 |
)
|
376 |
|
377 |
# Let the user select the user to investigate
|
|
|
7 |
from sklearn.mixture import GaussianMixture
|
8 |
import plotly.express as px
|
9 |
import itertools
|
10 |
+
from typing import Dict, List, Tuple
|
11 |
|
12 |
|
13 |
SIDEBAR_DESCRIPTION = """
|
|
|
15 |
|
16 |
To cluster a client, we adopt the RFM metrics. They stand for:
|
17 |
|
18 |
+
- R = recency, that is the number of days since the last purchase
|
19 |
in the store
|
20 |
- F = frequency, that is the number of times a customer has ordered something
|
21 |
+
- M = monetary value, that is how much a customer has spent buying
|
22 |
from your business.
|
23 |
|
24 |
Given these 3 metrics, we can cluster the customers and find a suitable
|
|
|
28 |
|
29 |
## How we compute the clusters
|
30 |
|
31 |
+
We resort to a GaussianMixture algorithm. We can think of GaussianMixture
|
32 |
+
as generalized k-means clustering that incorporates information about
|
33 |
the covariance structure of the data as well as the centers of the clusters.
|
34 |
""".lstrip()
|
35 |
|
|
|
46 |
""".lstrip()
|
47 |
|
48 |
RECENCY_CLUSTERS_EXPLAIN = """
|
49 |
+
The **recency** refers to how recently a customer has bought;
|
50 |
|
51 |
There 3 available clusters for this metric:
|
52 |
|
|
|
58 |
""".lstrip()
|
59 |
|
60 |
MONETARY_CLUSTERS_EXPLAIN = """
|
61 |
+
The **revenue** refers to how much a customer has spent buying
|
62 |
from your business.
|
63 |
|
64 |
There 3 available clusters for this metric:
|
|
|
115 |
|
116 |
|
117 |
def _order_cluster(cluster_model: GaussianMixture, clusters, order="ascending"):
|
118 |
+
"""Orders the cluster by `order`."""
|
119 |
centroids = cluster_model.means_.sum(axis=1)
|
120 |
|
121 |
if order.lower() == "descending":
|
|
|
191 |
" and values"
|
192 |
)
|
193 |
for cluster, info in cluster_info.items():
|
194 |
+
# Transform the (mins, maxs) tuple into
|
195 |
+
# [min_1, max_1, min_2, max_2, ...] list.
|
196 |
+
min_max_interleaved = list(itertools.chain(*zip(info[0], info[1])))
|
197 |
+
st.write(EXPLANATION_DICT[cluster].format(*min_max_interleaved))
|
198 |
|
199 |
|
200 |
def categorize_user(recency_cluster, frequency_cluster, monetary_cluster):
|
|
|
234 |
st.write(f"The customer can be described as: **{description}**")
|
235 |
|
236 |
|
237 |
+
def plot_rfm_distribution(
|
238 |
+
df_rfm: pd.DataFrame, cluster_info: Dict[str, Tuple[List[int], List[int]]]
|
239 |
+
):
|
240 |
"""Plots 3 histograms for the RFM metrics."""
|
241 |
|
242 |
for x, to_reverse in zip(("Revenue", "Frequency", "Recency"), (False, False, True)):
|
|
|
246 |
log_y=True,
|
247 |
title=f"{x} metric",
|
248 |
)
|
249 |
+
# Get the max values from the cluster info. Each cluster_info entry is a
|
250 |
+
# tuple whose first element holds the min values of the clusters and whose
|
251 |
+
# second element holds the max values of the clusters.
|
252 |
+
values = cluster_info[f"{x}_cluster"][1] # get max values
|
253 |
print(values)
|
254 |
# Add vertical bar on each cluster end. But skip the last cluster.
|
255 |
+
loop_range = range(len(values) - 1)
|
256 |
if to_reverse:
|
257 |
+
# Skip the last element
|
258 |
+
loop_range = range(len(values) - 1, 0, -1)
|
259 |
+
for n_cluster in loop_range:
|
260 |
print(x)
|
261 |
+
print(values[n_cluster])
|
262 |
fig.add_vline(
|
263 |
+
x=values[n_cluster],
|
264 |
annotation_text=f"End of cluster {n_cluster+1}",
|
265 |
line_dash="dot",
|
266 |
annotation=dict(textangle=90, font_color="red"),
|
|
|
273 |
st.plotly_chart(fig)
|
274 |
|
275 |
|
276 |
+
def display_dataframe_heatmap(df_rfm: pd.DataFrame, cluster_info_dict):
|
277 |
"""Displays an heatmap of how many clients lay in the clusters.
|
278 |
|
279 |
This method uses some black magic coming from the dataframe
|
280 |
styling guide.
|
281 |
"""
|
282 |
|
283 |
+
def style_with_limits(x, column, cluster_limit_dict):
|
284 |
+
"""Simple function to transform the cluster number into
|
285 |
+
a cluster + range string."""
|
286 |
+
min_v = cluster_limit_dict[column][0][x - 1]
|
287 |
+
max_v = cluster_limit_dict[column][1][x - 1]
|
288 |
+
return f"{x}: [{int(min_v)}, {int(max_v)}]"
|
289 |
+
|
290 |
# Create a dataframe with the count of clients for each group
|
291 |
# of cluster.
|
292 |
|
|
|
304 |
["Revenue_cluster", "Frequency_cluster", "Recency_cluster"]
|
305 |
)
|
306 |
|
307 |
+
# Add limits to the cells. In this way, we can better display
|
308 |
+
# the heatmap.
|
309 |
+
for cluster in ["Revenue_cluster", "Frequency_cluster", "Recency_cluster"]:
|
310 |
+
count[cluster] = count[cluster].apply(
|
311 |
+
lambda x: style_with_limits(x, cluster, cluster_info_dict)
|
312 |
+
)
|
313 |
+
|
314 |
# Use the count column as values, then index with the clusters.
|
315 |
count = count.pivot(
|
316 |
index=["Revenue_cluster", "Frequency_cluster"],
|
|
|
321 |
# Style manipulation
|
322 |
cell_hover = {
|
323 |
"selector": "td",
|
324 |
+
"props": "font-size:1.2em",
|
325 |
}
|
326 |
index_names = {
|
327 |
"selector": ".index_name",
|
328 |
+
"props": "font-style: italic; color: Black; font-weight:normal;font-size:1.2em;",
|
329 |
}
|
330 |
headers = {
|
331 |
"selector": "th:not(.index_name)",
|
332 |
+
"props": "background-color: White; color: black; font-size:1.2em",
|
333 |
}
|
334 |
|
335 |
# Finally, display
|
|
|
356 |
"# Dataset "
|
357 |
"\nThis is the processed dataset with information about the clients, such as"
|
358 |
" the RFM values and the clusters they belong to."
|
359 |
+
)
|
360 |
st.dataframe(df_rfm.style.format(formatter={"Revenue": "{:.2f}"}))
|
361 |
|
362 |
cluster_info_dict = defaultdict(list)
|
|
|
371 |
)
|
372 |
min_cluster = cluster_info["min"].astype(int)
|
373 |
max_cluster = cluster_info["max"].astype(int)
|
374 |
+
cluster_info_dict[cluster] = (min_cluster, max_cluster)
|
|
|
375 |
st.dataframe(cluster_info)
|
376 |
|
377 |
st.markdown("## RFM metric distribution")
|
378 |
|
379 |
plot_rfm_distribution(df_rfm, cluster_info_dict)
|
380 |
|
381 |
+
display_dataframe_heatmap(df_rfm, cluster_info_dict)
|
382 |
|
383 |
st.markdown("## Interactive exploration")
|
384 |
|
|
|
388 |
)
|
389 |
|
390 |
client_to_select = (
|
391 |
+
df_rfm.groupby(["Recency_cluster", "Frequency_cluster", "Revenue_cluster"])[
|
392 |
+
"CustomerID"
|
393 |
+
]
|
394 |
+
.first()
|
395 |
+
.values
|
396 |
+
if filter_by_cluster
|
397 |
+
else df["CustomerID"].unique()
|
398 |
)
|
399 |
|
400 |
# Let the user select the user to investigate
|
recommender.py
CHANGED
@@ -82,7 +82,7 @@ class Recommender:
|
|
82 |
def recommend_products(
|
83 |
self,
|
84 |
user_id,
|
85 |
-
items_to_recommend
|
86 |
):
|
87 |
"""Finds the recommended items for the user.
|
88 |
|
|
|
82 |
def recommend_products(
|
83 |
self,
|
84 |
user_id,
|
85 |
+
items_to_recommend=5,
|
86 |
):
|
87 |
"""Finds the recommended items for the user.
|
88 |
|
recommender_system.py
CHANGED
@@ -242,7 +242,6 @@ def display_recommendation_plots(
|
|
242 |
items_other_description = _extract_description(df, bought_by_similar_users)
|
243 |
suggestion_description = _extract_description(df, suggestions)
|
244 |
|
245 |
-
|
246 |
# Plot the scatterplot
|
247 |
|
248 |
fig = go.Figure()
|
|
|
242 |
items_other_description = _extract_description(df, bought_by_similar_users)
|
243 |
suggestion_description = _extract_description(df, suggestions)
|
244 |
|
|
|
245 |
# Plot the scatterplot
|
246 |
|
247 |
fig = go.Figure()
|