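# Spaces: Sleeping
# (Hugging Face Spaces status banner captured together with the file; it is
# page-scrape residue and not part of the program.)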
import sys | |
from pathlib import Path | |
import string | |
import random | |
import torch | |
import numpy as np | |
import pickle | |
import gradio as gr | |
import pandas as pd | |
from scipy.special import softmax | |
import numpy as np | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import hydra | |
from omegaconf import open_dict, DictConfig | |
import matplotlib.pyplot as plt | |
import matplotlib | |
from matplotlib.patches import Patch | |
sns.set() | |
sns.set_style("darkgrid") | |
from utils.data import * | |
from utils.metrics import * | |
def user_interface(Ufile, Pfile, Sfile=None, job_meta_file=None, user_meta_file=None, user_groups=None):
    """Build the Gradio Blocks demo for the end-user view.

    All arguments are forwarded unchanged to ``Data`` (from ``utils.data``),
    which defines their expected formats.

    Returns:
        The configured ``gr.Blocks`` object. It is not launched here.
    """
    recdata = Data(Ufile, Pfile, Sfile, job_meta_file, user_meta_file, user_groups)

    def calculate_user_item_metrics(res, S, U, k=10):
        """Compute the per-user quantities both plots need for a top-``k`` cut-off.

        Args:
            res: 2-D score matrix (tensor or NumPy array) from which the
                top-``k`` columns per row are taken as recommendations.
            S: score matrix handed to the ranking/score helpers.
            U: utility matrix (tensor or NumPy array) handed to the envy helper.
            k: number of recommendations per row.

        Returns:
            dict with keys ``'rec'``, ``'envy'``, ``'competitors'``,
            ``'ranks'`` and ``'scores_job'``.
        """
        if not torch.is_tensor(res):
            res = torch.from_numpy(res)
        if not torch.is_tensor(U):
            U = torch.from_numpy(U)
        _, rec = torch.topk(res, k, dim=1)
        rec_onehot = slow_onehot(rec, res)
        # ``.cpu()`` is a no-op for tensors already on the CPU, so no
        # try/except fallback is needed to support GPU tensors.
        rec_per_job = rec_onehot.sum(axis=0).cpu().numpy()
        rec = rec.cpu()
        if torch.is_tensor(S):
            S = S.cpu()
        # envy
        envy = expected_envy_torch_vec(U, rec_onehot, k=1).numpy()
        # competitors for each recommended job
        competitors = get_competitors(rec_per_job, rec)
        # rank
        better_competitors = get_num_better_competitors(rec, S)
        # scores per job, used later to zoom in on individual jobs
        scores = get_scores_per_job(rec, S)
        return {'rec': rec, 'envy': envy, 'competitors': competitors,
                'ranks': better_competitors, 'scores_job': scores}

    def get_metrics(k):
        """Return the metrics dict for ``k``, computing and caching it on first use."""
        if k not in recdata.lookup_dict:
            recdata.lookup_dict[k] = calculate_user_item_metrics(
                recdata.P_sub, recdata.S_sub, recdata.U_sub, k=k)
        return recdata.lookup_dict[k]

    def plot_user_envy(user=0, k=2):
        """Plot the histogram of per-user envy and highlight ``user``'s bin in red."""
        plt.close('all')
        user = int(user)
        k = int(k)  # UI widgets may deliver a float; topk needs an int
        ret_dict = get_metrics(k)

        fig, ax1 = plt.subplots(figsize=(10, 5))
        fig.subplots_adjust(bottom=0.2)
        envy = ret_dict['envy'].sum(-1)
        envy_user = envy[user]
        # envy histogram
        _, bins, patches = ax1.hist(envy, bins=30, color='grey', alpha=0.5)
        ax1.set_yscale('symlog')
        sns.kdeplot(envy, color='grey', bw_adjust=0.3, cut=0, ax=ax1)
        # Index of the bar containing this user's envy. np.digitize returns
        # len(bins) for a value equal to the last edge (the maximum envy),
        # which would index one past the final bar, so clamp to a valid bar.
        idx = int(np.clip(np.digitize(envy_user, bins) - 1, 0, len(patches) - 1))
        patches[idx].set_fc('r')
        ax1.legend(handles=[Patch(facecolor='r', edgecolor='r', alpha=0.5,
                                  label='Your envy level')], fontsize=15)
        ax1.set_xlabel('Envy', fontsize=18)
        ax1.set_ylabel('Number of users (log scale)', fontsize=18)
        return fig

    def plot_user_scores(user=0, k=2):
        """Plot, per recommended job, all scores for that job and mark ``user``'s own."""
        # Close earlier figures so repeated clicks do not accumulate open
        # matplotlib figures (same policy as plot_user_envy).
        plt.close('all')
        user = int(user)
        k = int(k)
        ret_dict = get_metrics(k)
        # this user's recommended jobs
        users_rec = ret_dict['rec'][user].numpy()
        scores = ret_dict['scores_job']
        scores = [scores[jb] for jb in users_rec]
        my_ranks = [1 + int(i) for i in ret_dict['ranks'][user]]
        my_scores = [recdata.S_sub[user, job_id].item() for job_id in users_rec]

        # Long-format frame (one row per score), cached per (user, k)
        if (user, k) in recdata.user_temp_data:
            df = recdata.user_temp_data[(user, k)]
        else:
            rank_xs = [list(range(1, len(s) + 1)) for s in scores]
            ys = np.arange(len(users_rec))
            df = pd.DataFrame({'x': rank_xs, 's': scores, 'y': ys})
            df = df.explode(list('xs'))
            recdata.user_temp_data[(user, k)] = df

        fig, ax = plt.subplots(figsize=(10, 5))
        fig.subplots_adjust(bottom=0.3)
        sns.scatterplot(data=df, x="y", y="s", ax=ax, alpha=0.6,
                        legend=False, s=100, hue='y', palette="summer")  # monotone color palette
        sns.scatterplot(y=my_scores, x=range(k), ax=ax,
                        alpha=0.8, s=200, ec='r', fc='none', label='Your rank')
        # Vertical offset for the rank labels: mean gap between consecutive
        # scores of the first job. A single score has no gaps (the mean of an
        # empty diff would be NaN), so fall back to no offset.
        first_job_scores = np.sort(scores[0])
        gaps = np.diff(first_job_scores).mean() if first_job_scores.size > 1 else 0.0
        for i, (y, x) in enumerate(zip(my_scores, range(k))):
            ax.text(x - 0.3, y + gaps, my_ranks[i], color='r', fontsize=15)
        ax.set_xticks(range(k))
        # shorten the job titles
        titles = [recdata.job_metadata[jb] for jb in users_rec]
        titles = [t[:15] + '...' if len(t) > 15 else t for t in titles]
        ax.set_xticklabels(titles, rotation=25, ha='right', fontsize=15)
        ax.set_xlabel('')
        ax.set_xlim(-1, k)
        ax.set_ylabel('Score', fontsize=18)
        ax.legend(fontsize=15)
        return fig

    demo = gr.Blocks(theme=gr.themes.Soft())
    with demo:
        # callbacks
        def submit0(user, k):
            fig = plot_user_envy(user, k)
            return {
                hist_plot: gr.update(value=fig, visible=True),
            }

        def submit2(user, k):
            bar = plot_user_scores(user, k)
            return {
                bar_plot2: gr.update(value=bar, visible=True)
            }

        def submit(user):
            new_job_num = random.randint(1, 6)
            # if new_job_num == 0, do nothing but clear the plots
            if new_job_num > 0:
                print(f'adding {new_job_num} new jobs')
                recdata.update(new_user_num=0, new_job_num=new_job_num)
                recdata.tweak_P(user)
            return {
                hist_plot: gr.update(visible=False),
                bar_plot2: gr.update(visible=False)
            }

        # layout
        gr.Markdown("## Job Recommendation Inferiority and Envy Monitor Demo")
        with gr.Row():
            with gr.Column(scale=1):
                # `value` (not `default`) is the initial-value parameter of
                # the Blocks-era components.
                user = gr.Textbox(label='User ID', value='0',
                                  placeholder='Enter a random integer user ID')
            with gr.Column(scale=1):
                k = gr.Slider(minimum=1, maximum=20,
                              value=4, step=1, label='Number of Jobs', visible=True)
            with gr.Column(scale=1):
                btn = gr.Button("Refresh to see new jobs", visible=True)
        with gr.Tab('Envy'):
            btn0 = gr.Button("User envy distribution", visible=True)
            hist_plot = gr.Plot(visible=False)
        with gr.Tab('Inferiority'):
            with gr.Row():
                btn2 = gr.Button("User scores/ranks for the recommended jobs", visible=True)
            bar_plot2 = gr.Plot(visible=False)
        btn.click(submit, inputs=[user], outputs=[hist_plot, bar_plot2])
        btn0.click(submit0, inputs=[user, k], outputs=[hist_plot])
        btn2.click(submit2, inputs=[user, k], outputs=[bar_plot2])
    return demo
def developer_interface(Ufile, Pfile, Sfile=None, job_meta_file=None, user_meta_file=None, user_groups=None):
    """Build the Gradio Blocks demo for the developer view.

    The file arguments and ``user_groups`` are forwarded to ``Data`` (from
    ``utils.data``) together with ``sub_sample_size=500``. ``user_groups`` is
    also used as the choices of the "User group" radio button; ``None`` is
    treated as an empty list of choices.

    Returns:
        The configured ``gr.Blocks`` object. It is not launched here.
    """
    recdata = Data(Ufile, Pfile, Sfile, job_meta_file, user_meta_file, user_groups, sub_sample_size=500)
    # The signature allows user_groups=None; the radio widget needs a list.
    group_choices = list(user_groups) if user_groups else []

    def calculate_all_metrics(k, S_sub, U_sub, P_sub):
        """Return the metrics dict for cut-off ``k``.

        A cached entry in ``recdata.lookup_dict`` is returned when present;
        this function itself does not write to the cache (``plot_scatter``
        does).

        Returns:
            dict with keys ``'rec'``, ``'envy'``, ``'inferiority'``,
            ``'utility'`` and ``'rec_per_job'``.
        """
        print('calculating all metrics')
        k = int(k)  # UI widgets may deliver a float; topk needs an int
        if k in recdata.lookup_dict:
            print('Found in lookup dict')
            return recdata.lookup_dict[k]
        if not torch.is_tensor(P_sub):
            P_sub = torch.from_numpy(P_sub)
        envy, inferiority, utility = eiu_cut_off2(
            (S_sub, U_sub), P_sub, k=k, agg=False)
        envy = envy.sum(-1)
        inferiority = inferiority.sum(-1)
        _, rec = torch.topk(P_sub, k=k, dim=1)
        rec_onehot = slow_onehot(rec, P_sub)
        # ``.cpu()`` is a no-op for CPU tensors, so no try/except is needed.
        rec_per_job = rec_onehot.sum(axis=0).cpu().numpy()
        rec = rec.cpu()
        metrics_at_k = {'rec': rec, 'envy': envy, 'inferiority': inferiority,
                        'utility': utility, 'rec_per_job': rec_per_job}
        print('Finished calculating all metrics')
        return metrics_at_k

    def plot_user_box(metrics_dict):
        """Draw side-by-side box plots of per-user envy and inferiority."""
        print('plotting user box')
        plt.close('all')
        envy = metrics_dict['envy'].numpy()
        inferiority = metrics_dict['inferiority'].numpy()
        fig, (ax1, ax2) = plt.subplots(ncols=2, constrained_layout=True)
        ax1.boxplot(envy)
        ax1.set_ylabel('Envy', fontsize=18)
        ax1.set_xticks([])
        ax2.boxplot(inferiority)
        ax2.yaxis.set_label_position("right")
        ax2.yaxis.tick_right()
        ax2.set_ylabel('Inferiority', fontsize=18)
        ax2.set_xticks([])
        return fig

    def plot_scatter(k, group=None):
        """Scatter log(envy+1) against inferiority, optionally coloured by ``group``.

        ``group`` names a column of ``recdata.user_metadata``; ``None``, the
        string ``'None'`` and the empty string all mean "no grouping".
        """
        print('plotting scatter')
        plt.close('all')
        k = int(k)
        if not group or group == 'None':
            group = None
        if k in recdata.lookup_dict:
            metrics_dict = recdata.lookup_dict[k]
        else:
            metrics_dict = calculate_all_metrics(k, recdata.S_sub, recdata.U_sub, recdata.P_sub)
            recdata.lookup_dict[k] = metrics_dict
        data = {'log(envy+1)': np.log(metrics_dict['envy'] + 1),
                'inferiority': metrics_dict['inferiority']}
        data = pd.DataFrame(data)
        data = data.join(recdata.user_metadata)
        fig, ax = plt.subplots(constrained_layout=True)
        sns.scatterplot(data=data, x='log(envy+1)', y='inferiority', hue=group, ax=ax)
        ax.set_xlabel('Log(envy+1)', fontsize=18)
        ax.set_ylabel('Inferiority', fontsize=18)
        # Without a hue there are no labelled artists, so a legend would only
        # produce a matplotlib warning.
        if group is not None:
            ax.legend(fontsize=15)
        return fig

    def lorenz_curve(X, ax, label):
        """Plot the Lorenz curve of ``X`` on ``ax`` next to the line of equality.

        ref: https://zhiyzuo.github.io/Plot-Lorenz/
        """
        # Sort a copy: X is the cached ``rec_per_job`` array and must not be
        # reordered in place.
        X = np.sort(X)
        X_lorenz = X.cumsum() / X.sum()
        X_lorenz = np.insert(X_lorenz, 0, 0)
        ax.plot(np.arange(X_lorenz.size) / (X_lorenz.size - 1), X_lorenz, label=label)
        # line of equality
        ax.plot([0, 1], [0, 1], linestyle='dashed', color='k', label='Line of Equality')
        ax.legend(fontsize=15)
        ax.set_xlabel('Percentage of jobs', fontsize=18)
        ax.set_ylabel('Percentage of job exposure', fontsize=18)
        return ax

    def plot_item(rec_per_job):
        """Plot the job-exposure distribution (bar chart) and its Lorenz curve."""
        print('plotting item')
        plt.close('all')
        fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(10, 10))
        fig.tight_layout(pad=5.0)
        labels, counts = np.unique(rec_per_job, return_counts=True)
        ax1.bar(labels, counts, align='center')
        ax1.set_xlabel('Number of times a job is recommended', fontsize=18)
        ax1.set_ylabel('Number of jobs', fontsize=18)
        ax1.set_title('Distribution of job exposure', fontsize=18)
        lorenz_curve(rec_per_job, ax2, 'Lorenz Curve')
        return fig

    # build the interface
    demo = gr.Blocks(theme=gr.themes.Soft())
    with demo:
        # callbacks
        def submit_u():
            # two random integers (0 included): numbers of new users and jobs
            user_num = np.random.randint(0, 5)
            job_num = np.random.randint(0, 5)
            if user_num > 0 or job_num > 0:
                recdata.update(user_num, job_num)
            return {
                info: gr.update(value='New {} users and {} jobs'.format(user_num, job_num), visible=True),
            }

        def submit1(k):
            metrics_dict = calculate_all_metrics(k, recdata.S_sub, recdata.U_sub, recdata.P_sub)
            return {
                user_box_plot: plot_user_box(metrics_dict),
                scatter_plot: plot_scatter(k),
                btn2: gr.update(visible=True)
            }

        def submit2():
            return {
                radio: gr.update(visible=True)
            }

        def submit3(k):
            metrics_dict = calculate_all_metrics(k, recdata.S_sub, recdata.U_sub, recdata.P_sub)
            return {
                item_plots: plot_item(metrics_dict['rec_per_job'])
            }

        # layout
        gr.Markdown("## Envy & Inferiority Monitor for Developers Demo")
        # 1. accept k
        with gr.Row():
            with gr.Column(scale=1):
                # gr.Slider(value=...) replaces the deprecated
                # gr.inputs.Slider(default=...) with the same initial value.
                k = gr.Slider(minimum=1, maximum=min(30, len(recdata.P[0])),
                              value=1, step=1, label='Number of Jobs')
            with gr.Column(scale=1):
                btn = gr.Button('Refresh')
            with gr.Column(scale=1):
                info = gr.Textbox('', label='Updated info', visible=False)
            btn.click(submit_u, inputs=[], outputs=[info])
        with gr.Tab('User'):
            plt.close('all')
            btn1 = gr.Button('Visualize user-side fairness')
            user_box_plot = gr.Plot()
            scatter_plot = gr.Plot()
            btn2 = gr.Button('Visualize intra-group fairness', visible=False)
            radio = gr.Radio(choices=group_choices,
                             value=group_choices[0] if group_choices else "",
                             interactive=True, label="User group", visible=False)
            btn1.click(submit1, inputs=[k], outputs=[
                user_box_plot, scatter_plot, btn2])
            btn2.click(submit2, inputs=[], outputs=[radio])
            radio.change(fn=plot_scatter, inputs=[
                k, radio], outputs=[scatter_plot])
        with gr.Tab('Item'):
            plt.close('all')
            btn3 = gr.Button('Visualize item-side fairness')
            item_plots = gr.Plot()
            btn3.click(submit3, inputs=[k], outputs=[item_plots])
    return demo
def main(config: DictConfig):
    """Build the demo selected by ``config.role`` and launch it.

    Args:
        config: configuration object providing the attributes ``Ufile``,
            ``Sfile``, ``Pfile``, ``user_meta_file``, ``job_meta_file``,
            ``user_groups``, ``server_name`` and ``role``.

    Raises:
        ValueError: if ``config.role`` is neither ``'user'`` nor
            ``'developer'``.
    """
    print(config)
    Ufile = config.Ufile
    Sfile = config.Sfile
    Pfile = config.Pfile
    user_meta_file = config.user_meta_file
    job_meta_file = config.job_meta_file
    # 'None' is always the first choice, followed by any configured groups.
    user_groups = ['None'] + (list(config.user_groups) if config.user_groups else [])
    # Read but currently unused: only the commented-out launch call below uses it.
    server_name = config.server_name
    role = config.role
    if role == 'user':
        demo = user_interface(Ufile, Pfile, Sfile,
                              job_meta_file, user_meta_file, user_groups)
    elif role == 'developer':
        demo = developer_interface(
            Ufile, Pfile, Sfile, job_meta_file, user_meta_file, user_groups)
    else:
        # Previously an unknown role fell through and failed later with an
        # unbound ``demo``; fail early with a clear message instead.
        raise ValueError(
            f"Unknown role {role!r}; expected 'user' or 'developer'")
    # demo.launch(server_name=server_name, server_port=config.server_port)
    demo.launch()


if __name__ == "__main__":
    # NOTE(review): main() is invoked without arguments although it requires
    # `config`. This only works if main is wrapped by a Hydra entry-point
    # decorator (e.g. @hydra.main), which is not visible in this source;
    # confirm against the original file.
    main()