import sys
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import rankdata

np.set_printoptions(precision=4)


"""Information Retrieval metrics
Useful Resources:
http://www.cs.utexas.edu/~mooney/ir-course/slides/Evaluation.ppt
http://www.nii.ac.jp/TechReports/05-014E.pdf
http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf
http://hal.archives-ouvertes.fr/docs/00/72/67/60/PDF/07-busa-fekete.pdf
Learning to Rank for Information Retrieval (Tie-Yan Liu)
"""


def mean_reciprocal_rank(rs):
    """Score is reciprocal of the rank of the first relevant item
    First element is 'rank 1'.  Relevance is binary (nonzero is relevant).
    Example from http://en.wikipedia.org/wiki/Mean_reciprocal_rank
    >>> rs = [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
    >>> mean_reciprocal_rank(rs)
    0.61111111111111105
    >>> rs = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]])
    >>> mean_reciprocal_rank(rs)
    0.5
    >>> rs = [[0, 0, 0, 1], [1, 0, 0], [1, 0, 0]]
    >>> mean_reciprocal_rank(rs)
    0.75
    Args:
        rs: Iterator of relevance scores (list or numpy) in rank order
            (first element is the first item)
    Returns:
        Mean reciprocal rank
    """
    rs = (np.asarray(r).nonzero()[0] for r in rs)
    return np.mean([1. / (r[0] + 1) if r.size else 0. for r in rs])


def r_precision(r):
    """Score is precision after all relevant documents have been retrieved
    Relevance is binary (nonzero is relevant).
    >>> r = [0, 0, 1]
    >>> r_precision(r)
    0.33333333333333331
    >>> r = [0, 1, 0]
    >>> r_precision(r)
    0.5
    >>> r = [1, 0, 0]
    >>> r_precision(r)
    1.0
    Args:
        r: Relevance scores (list or numpy) in rank order
            (first element is the first item)
    Returns:
        R Precision
    """
    r = np.asarray(r) != 0
    z = r.nonzero()[0]
    if not z.size:
        return 0.
    return np.mean(r[:z[-1] + 1])


def precision_at_k(r, k):
    """Score is precision @ k
    Relevance is binary (nonzero is relevant).
    >>> r = [0, 0, 1]
    >>> precision_at_k(r, 1)
    0.0
    >>> precision_at_k(r, 2)
    0.0
    >>> precision_at_k(r, 3)
    0.33333333333333331
    >>> precision_at_k(r, 4)
    Traceback (most recent call last):
        File "<stdin>", line 1, in ?
    ValueError: Relevance score length < k
    Args:
        r: Relevance scores (list or numpy) in rank order
            (first element is the first item)
    Returns:
        Precision @ k
    Raises:
        ValueError: len(r) must be >= k
    """
    assert k >= 1
    r = np.asarray(r)[:k] != 0
    if r.size != k:
        raise ValueError('Relevance score length < k')
    return np.mean(r)


def average_precision(r):
    """Score is average precision (area under PR curve)
    Relevance is binary (nonzero is relevant).
    >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1]
    >>> delta_r = 1. / sum(r)
    >>> sum([sum(r[:x + 1]) / (x + 1.) * delta_r for x, y in enumerate(r) if y])
    0.7833333333333333
    >>> average_precision(r)
    0.78333333333333333
    Args:
        r: Relevance scores (list or numpy) in rank order
            (first element is the first item)
    Returns:
        Average precision
    """
    r = np.asarray(r) != 0
    out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]]
    if not out:
        return 0.
    return np.mean(out)


def mean_average_precision(rs):
    """Score is mean average precision
    Relevance is binary (nonzero is relevant).
    >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1]]
    >>> mean_average_precision(rs)
    0.78333333333333333
    >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1], [0]]
    >>> mean_average_precision(rs)
    0.39166666666666666
    Args:
        rs: Iterator of relevance scores (list or numpy) in rank order
            (first element is the first item)
    Returns:
        Mean average precision
    """
    return np.mean([average_precision(r) for r in rs])


def dcg_at_k(r, k, method=0):
    """Score is discounted cumulative gain (dcg)
    Relevance is positive real values.  Can use binary
    as the previous methods.
    Example from
    http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf
    >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0]
    >>> dcg_at_k(r, 1)
    3.0
    >>> dcg_at_k(r, 1, method=1)
    3.0
    >>> dcg_at_k(r, 2)
    5.0
    >>> dcg_at_k(r, 2, method=1)
    4.2618595071429155
    >>> dcg_at_k(r, 10)
    9.6051177391888114
    >>> dcg_at_k(r, 11)
    9.6051177391888114
    Args:
        r: Relevance scores (list or numpy) in rank order
            (first element is the first item)
        k: Number of results to consider
        method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...]
                If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...]
    Returns:
        Discounted cumulative gain
    """
    r = np.asarray(r, dtype=float)[:k]
    if r.size:
        if method == 0:
            return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
        elif method == 1:
            return np.sum(r / np.log2(np.arange(2, r.size + 2)))
        else:
            raise ValueError('method must be 0 or 1.')
    return 0.


def ndcg_at_k(r, k, method=0):
    """Score is normalized discounted cumulative gain (ndcg)
    Relevance is positive real values.  Can use binary
    as the previous methods.
    Example from
    http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf
    >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0]
    >>> ndcg_at_k(r, 1)
    1.0
    >>> r = [2, 1, 2, 0]
    >>> ndcg_at_k(r, 4)
    0.9203032077642922
    >>> ndcg_at_k(r, 4, method=1)
    0.96519546960144276
    >>> ndcg_at_k([0], 1)
    0.0
    >>> ndcg_at_k([1], 2)
    1.0
    Args:
        r: Relevance scores (list or numpy) in rank order
            (first element is the first item)
        k: Number of results to consider
        method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...]
                If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...]
    Returns:
        Normalized discounted cumulative gain
    """
    dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
    if not dcg_max:
        return 0.
    return dcg_at_k(r, k, method) / dcg_max


"""
Wealth inequality
"""


def gini(arr):
    """Gini coefficient of a 1-D numpy array: 0 for a perfectly even
    distribution, approaching 1 when all mass sits on a single entry."""
    # Gini = \frac{2\sum_i^n i\times y_i}{n\sum_i^n y_i} - \frac{n+1}{n}, with y_i sorted ascending
    sorted_arr = arr.copy()
    sorted_arr.sort()
    n = arr.size
    coef_ = 2. / n
    const_ = (n + 1.) / n
    weighted_sum = sum([(i + 1) * yi for i, yi in enumerate(sorted_arr)])
    return coef_ * weighted_sum / (sorted_arr.sum()) - const_
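

# Illustrative sketch (added; not part of the original module): gini() expects a
# 1-D numpy array of non-negative exposures; 0.0 means perfectly even exposure,
# values approaching 1 mean exposure concentrated on very few items.
def _gini_example():
    even = np.full(4, 5.)                   # every item recommended equally often -> 0.0
    skewed = np.array([0., 0., 0., 20.])    # all exposure on a single item       -> 0.75
    return gini(even), gini(skewed)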


"""
Expected envy and inferiority under probabilistic recommendation as weighted sampling with replacement
"""


def expected_utility_u(Ru, ps, k):
    return Ru @ ps * k


def expected_utility(R, Pi, k):
    U = (R * Pi * k).sum(axis=1)
    # if not agg:
    return U
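

# Worked sketch (added; illustrative only): with the sampling-with-replacement
# model above, each user's expected utility is k times the dot product of their
# relevance row and their recommendation distribution.
def _expected_utility_example():
    R = np.array([[1.0, 0.0],    # user 0 only values item 0
                  [0.5, 0.5]])   # user 1 is indifferent
    Pi = np.array([[0.2, 0.8],   # but user 0 is mostly shown item 1
                   [0.5, 0.5]])
    return expected_utility(R, Pi, k=2)   # -> array([0.4, 1.0])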


def expected_envy_u_v(Ru, pus, pvs, k):
    return Ru @ (pvs - pus) * k


def prob_in(ps, k):
    return 1 - (1 - ps) ** k


def prob_in_approx(ps, k):
    return k * ps
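

# Added sketch: prob_in() is the exact probability that an item shows up at least
# once in k independent draws with replacement; prob_in_approx() is the union-bound
# style approximation k*p, only reasonable while k*p stays small.
def _prob_in_example():
    ps = np.array([0.01, 0.1, 0.5])
    return prob_in(ps, k=3), prob_in_approx(ps, k=3)
    # exact:  ~[0.0297, 0.271, 0.875]    approx: [0.03, 0.3, 1.5]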


def expected_inferiority_u_v(Ru, Rv, pus, pvs, k, compensate=False, approx=False):
    differ = Rv - Ru
    if not compensate:
        differ = np.clip(differ, a_min=0, a_max=None)
    if not approx:
        return differ @ (prob_in(pus, k) * prob_in(pvs, k))
    else:
        return differ @ (prob_in_approx(pus, k) * prob_in_approx(pvs, k))


def expected_envy(R, Pi, k):
    """
    Measure expected envy of a k-sized recommendation under policy Pi with respect to relevance scores R.
    :param R: m x n real-valued relevance matrix
    :param Pi: m x n row-stochastic (Markov) matrix, or a binary matrix for a discrete recommendation
    :param k: recommendation slate size
    :return: E: m x m envy matrix where E[u, v] is the envy of user u towards user v, clipped at 0
    """
    # Pi must be row-stochastic, or a binary matrix for a discrete recommendation
    assert np.all(np.isclose(Pi.sum(axis=1), 1.)) or np.array_equal(Pi, Pi.astype(bool))
    m, n = len(R), len(R[0])
    E = np.zeros((m, m))
    for u in range(m):
        for v in range(m):
            if v == u:
                continue
            E[u, v] = expected_envy_u_v(R[u], Pi[u], Pi[v], k=k)
    E = np.clip(E, a_min=0., a_max=None)
    # if not agg:
    return E


def expected_inferiority(R, Pi, k, compensate=True, approx=False):
    """
    Measure expected inferiority of a k-sized recommendation under policy Pi with respect to relevance scores R.
    :param R: m x n real-valued relevance matrix
    :param Pi: m x n row-stochastic (Markov) matrix, or a binary matrix for a discrete recommendation
    :param k: recommendation slate size
    :param compensate: if True, items where u beats v offset items where u is worse; otherwise per-item differences are clipped at 0
    :param approx: if True, use the first-order approximation k*p of the inclusion probability
    :return: I: m x m inferiority matrix where I[u, v] is the inferiority of user u with respect to user v, clipped at 0
    """
    # Pi must be row-stochastic, or a binary matrix for a discrete recommendation
    assert np.all(np.isclose(Pi.sum(axis=1), 1.)) or np.array_equal(Pi, Pi.astype(bool))
    m, n = len(R), len(R[0])
    I = np.zeros((m, m))
    for u in range(m):
        for v in range(m):
            if v == u:
                continue
            I[u, v] = expected_inferiority_u_v(R[u], R[v], Pi[u], Pi[v], k=k, approx=approx, compensate=compensate)

    I = np.clip(I, a_min=0., a_max=None)
    # if not agg:
    return I


def expected_envy_torch(R, Pi, k):
    m, n = len(R), len(R[0])
    E = torch.zeros(m, m)
    for u in range(m):
        for v in range(m):
            if v == u:
                continue
            E[u, v] = expected_envy_u_v(R[u], Pi[u], Pi[v], k=k)
    E = torch.clamp(E, min=0.)
    return E


def expected_envy_torch_vec(R, P, k):
    res = R @ P.transpose(0, 1)
    envy_mat = (res - torch.diagonal(res, 0).reshape(-1, 1))
    return k * (torch.clamp(envy_mat, min=0.))
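

# Self-check sketch (added): the vectorised envy computation should agree with the
# explicit double loop above, since envy[u, v] = k * max(0, R[u] @ (P[v] - P[u])).
def _check_envy_vectorisation(m=5, n=7, k=3, seed=0):
    torch.manual_seed(seed)
    R = torch.rand(m, n)
    Pi = F.softmax(torch.rand(m, n), dim=1)   # rows sum to 1
    return torch.allclose(expected_envy_torch(R, Pi, k),
                          expected_envy_torch_vec(R, Pi, k))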


def expected_inferiority_torch(R, Pi, k, compensate=False, approx=False):
    m, n = R.shape
    I = torch.zeros((m, m))
    for u in range(m):
        for v in range(m):
            if v == u:
                continue
            if not approx:
                joint_prob = prob_in(Pi[v], k) * prob_in(Pi[u], k)
            else:
                joint_prob = prob_in_approx(Pi[v], k) * prob_in_approx(Pi[u], k)

            if not compensate:
                I[u, v] = torch.clamp(R[v] - R[u], min=0., max=None) @ joint_prob
            else:
                I[u, v] = (R[v] - R[u]) @ joint_prob

    return torch.clamp(I, min=0.)


def expected_inferiority_torch_vec(R, P, k, compensate=False, approx=False):
    m, n = R.shape
    I = torch.zeros((m, m))
    P_pow_k = 1 - (1 - P).pow(k) if not approx else P * k
    for i in range(m):
        first_term = torch.clamp(R - R[i], min=0.) if not compensate else R - R[i]
        I[i] = (first_term * (P_pow_k[i] * P_pow_k)).sum(1)
    return I
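

# Self-check sketch (added): the row-at-a-time version should match the explicit
# double loop in expected_inferiority_torch when the same flags are used.
def _check_inferiority_vectorisation(m=4, n=6, k=2, seed=0):
    torch.manual_seed(seed)
    R = torch.rand(m, n)
    Pi = F.softmax(torch.rand(m, n), dim=1)
    loop_version = expected_inferiority_torch(R, Pi, k, compensate=False, approx=False)
    vec_version = expected_inferiority_torch_vec(R, Pi, k, compensate=False, approx=False)
    return torch.allclose(loop_version, vec_version)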


def slow_onehot(idx, P):
    m = P.shape[0]
    res = torch.zeros_like(P)
    for i in range(m):
        res[i, idx[i]] = 1.
    return res
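

# Added sketch: slow_onehot builds an m x n indicator matrix of the recommended
# items; for the unique indices returned by torch.topk it coincides with the
# F.one_hot(...).sum(1) construction that is commented out elsewhere in this file.
def _slow_onehot_example():
    P = torch.rand(3, 5)
    _, idx = torch.topk(P, k=2, dim=1)
    dense = slow_onehot(idx, P)
    dense_alt = F.one_hot(idx, num_classes=P.shape[1]).sum(1).float()
    return torch.equal(dense, dense_alt)   # True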


def eiu_cut_off(R, Pi, k, agg=True):
    """
    Evaluate envy, inferiority and utility of the deterministic top-k cut-off of Pi.
    :param R: m x n relevance matrix
    :param Pi: m x n score/policy matrix from which the top-k items per user are taken
    :param k: cut-off size
    :param agg: if True, return means over users instead of per-user values
    :return: envy, inferiority, utility
    """
    # print('Start evaluation!')
    m, n = R.shape
    # _, rec = torch.topk(Pi, k, dim=1)
    # rec_onehot = F.one_hot(rec, num_classes=n).sum(1).float()
    rec_onehot = slow_onehot(torch.topk(Pi, k, dim=1)[1], Pi)
    envy = expected_envy_torch_vec(R, rec_onehot, k=1)
    inferiority = expected_inferiority_torch_vec(R, rec_onehot, k=1, compensate=False, approx=False)
    utility = expected_utility(R, rec_onehot, k=1)
    if agg:
        envy = envy.sum(-1).mean()
        inferiority = inferiority.sum(-1).mean()
        utility = utility.mean()
    return envy, inferiority, utility
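

# Usage sketch (added): evaluate a random stochastic policy through its
# deterministic top-k cut-off; larger envy/inferiority means a less fair outcome.
def _eiu_cut_off_example(m=10, n=20, k=3, seed=0):
    torch.manual_seed(seed)
    R = torch.rand(m, n)
    Pi = F.softmax(torch.rand(m, n), dim=1)
    envy, inferiority, utility = eiu_cut_off(R, Pi, k, agg=True)
    return envy.item(), inferiority.item(), utility.item()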


def eiu_cut_off2(R, Pi, k, agg=True):
    """
    Evaluate envy, inferiority and utility of the deterministic top-k cut-off of Pi,
    using separate matrices for the two sides of the recommendation.
    :param R: tuple (S, U); S is used for inferiority, U for envy and utility
    :param Pi: m x n score/policy matrix from which the top-k items per user are taken
    :param k: cut-off size
    :param agg: if True, return means over users instead of per-user values
    :return: envy, inferiority, utility
    """
    # print('Start evaluation!')
    S, U = R
    if not isinstance(S, torch.Tensor):
        S = torch.tensor(S)
    if not isinstance(U, torch.Tensor):
        U = torch.tensor(U)
    if not isinstance(Pi, torch.Tensor):
        Pi = torch.tensor(Pi)
    m, n = U.shape
    # _, rec = torch.topk(Pi, k, dim=1)
    # rec_onehot = F.one_hot(rec, num_classes=n).sum(1).float()
    rec_onehot = slow_onehot(torch.topk(Pi, k, dim=1)[1], Pi)
    envy = expected_envy_torch_vec(U, rec_onehot, k=1)
    inferiority = expected_inferiority_torch_vec(S, rec_onehot, k=1, compensate=False, approx=False)
    utility = expected_utility(U, rec_onehot, k=1)
    if agg:
        envy = envy.sum(-1).mean()
        inferiority = inferiority.sum(-1).mean()
        utility = utility.mean()
    return envy, inferiority, utility


"""
Global congestion metrics
"""


def get_competitors(rec_per_job, rec):
    """For each user, the number of users that were recommended each of that user's recommended jobs."""
    m = rec.shape[0]
    competitors = []
    for i in range(m):
        if len(rec[i]) == 1:
            competitors.append([rec_per_job[rec[i]]])
        else:
            competitors.append(rec_per_job[rec[i]])
    return np.array(competitors)


def get_better_competitor_scores(rec, R):
    """For each user and recommended job, the mean score of the other users
    recommended that job minus the user's own score."""
    m, n = R.shape
    _, k = rec.shape
    user_ids_per_job = defaultdict(list)
    for i, r in enumerate(rec):
        for j in r:
            user_ids_per_job[j.item()].append(i)

    mean_competitor_scores_per_job = np.zeros((m, k))
    for i in range(m):
        my_rec_jobs = rec[i].numpy()

        my_mean_competitors = np.zeros(k)
        for j_, j in enumerate(my_rec_jobs):
            my_score = R[i, j]
            all_ids = user_ids_per_job[j].copy()
            all_ids.remove(i)
            other_scores = R[all_ids, j]
            if not all_ids:
                other_scores = np.zeros(1)  # TODO if no competition, then it is the negative of my own score
            my_mean_competitors[j_] = other_scores.mean() - my_score
            # my_mean_competitors[my_mean_competitors < 0] = 0. # TODO only keep the better competitors
        mean_competitor_scores_per_job[i] = my_mean_competitors
    return mean_competitor_scores_per_job


def get_num_better_competitors(rec, R):
    """For each user and recommended job, the number of co-recommended users with a strictly higher score."""
    m, n = R.shape
    _, k = rec.shape
    user_ids_per_job = defaultdict(list)
    for i, r in enumerate(rec):
        for j in r:
            user_ids_per_job[j.item()].append(i)

    num_better_competitors = np.zeros((m, k))
    for i in range(m):
        my_rec_jobs = rec[i].numpy()

        better_competitors = np.zeros(k)
        for j_, j in enumerate(my_rec_jobs):
            my_score = R[i, j]
            all_ids = user_ids_per_job[j].copy()
            all_ids.remove(i)
            other_scores = R[all_ids, j]
            better_competitors[j_] = ((other_scores - my_score) > 0).sum()
        num_better_competitors[i] = better_competitors
    return num_better_competitors


def get_scores_ids_per_job(rec, R):
    scores_per_job = defaultdict(list)
    ids_per_job = defaultdict(list)

    for i in range(len(rec)):
        u = rec[i]
        for jb in u:
            jb = jb.item()
            ids_per_job[jb].append(i)
            scores_per_job[jb].append(R[i, jb].item())
    return scores_per_job, ids_per_job


def get_rank(a, method='ordinal', axis=None, descending=False):
    if descending:
        a = np.array(a) * -1
    return rankdata(a, method=method, axis=axis)
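

# Added sketch: with descending=True the highest score gets rank 1, ties broken
# in order of appearance (scipy's 'ordinal' method).
def _get_rank_example():
    return get_rank([0.2, 0.9, 0.5], descending=True)   # -> array([3, 1, 2])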


def get_ranks_per_job(scores_rec):
    ranks_per_job = defaultdict(list)
    for jb in scores_rec:
        ranks_per_job[jb] = get_rank(scores_rec[jb], descending=True)
    return ranks_per_job


def get_ranks_per_user(ranks_per_job, ids_per_job):
    # convert 1-indexed ranks to 0-indexed (note: mutates ranks_per_job in place)
    for k, v in ranks_per_job.items():
        ranks_per_job[k] = [i - 1 for i in v]
    ranks_per_user = defaultdict(list)
    for k, v in ids_per_job.items():
        rks = ranks_per_job[k]
        for i, u in enumerate(v):
            ranks_per_user[u].append(rks[i])
    return ranks_per_user


def calculate_global_metrics(res, R, k=10):
    # get rec
    m, n = res.shape
    if not torch.is_tensor(res):
        res = torch.from_numpy(res)
    _, rec = torch.topk(res, k, dim=1)
    rec_onehot = slow_onehot(rec, res)
    # rec_onehot = F.one_hot(rec, num_classes=n).sum(1).float()
    try:
        rec_per_job = rec_onehot.sum(axis=0).numpy()
    except Exception:  # tensor may live on the GPU; move everything to CPU first
        rec_per_job = rec_onehot.sum(axis=0).cpu().numpy()
        rec = rec.cpu()
        R = R.cpu()
    opt_competitors = get_competitors(rec_per_job, rec)

    # mean competitors per person
    mean_competitors = opt_competitors.mean()

    # mean better competitors per person
    mean_better_competitors = get_num_better_competitors(rec, R).mean()

    # mean competitor scores - my score
    mean_diff_scores = get_better_competitor_scores(rec, R)
    mean_diff_scores[mean_diff_scores < 0] = 0.
    mean_diff_scores = mean_diff_scores.mean()

    # mean rank
    # scores_opt, ids_opt = get_scores_ids_per_job(rec, R)
    # ranks_opt = get_ranks_per_job(scores_opt)
    # ranks_per_user_opt = get_ranks_per_user(ranks_opt, ids_opt)
    # mean_rank = np.array(list(ranks_per_user_opt.values())).mean()

    # gini
    gini_index = gini(rec_per_job)

    # NOTE: the mean-rank computation above is commented out, so 'mean_rank' currently
    # repeats mean_better_competitors as a placeholder.
    return {'mean_competitors': mean_competitors, 'mean_better_competitors': mean_better_competitors,
            'mean_scores_diff': mean_diff_scores, 'mean_rank': mean_better_competitors, 'gini_index': gini_index}


def calculate_global_metrics2(res, R, k=10):
    # get rec
    S, U = R
    m, n = res.shape
    if not torch.is_tensor(res):
        res = torch.from_numpy(res)
    _, rec = torch.topk(res, k, dim=1)
    rec_onehot = slow_onehot(rec, res)
    # rec_onehot = F.one_hot(rec, num_classes=n).sum(1).float()
    try:
        rec_per_job = rec_onehot.sum(axis=0).numpy()
    except Exception:  # tensor may live on the GPU; move everything to CPU first
        rec_per_job = rec_onehot.sum(axis=0).cpu().numpy()
        rec = rec.cpu()
        S = S.cpu()
        U = U.cpu()
    opt_competitors = get_competitors(rec_per_job, rec)

    # mean competitors per person
    mean_competitors = opt_competitors.mean()

    # mean better competitors per person
    mean_better_competitors = get_num_better_competitors(rec, S).mean()

    # mean competitor scores - my score
    mean_diff_scores = get_better_competitor_scores(rec, S)
    mean_diff_scores[mean_diff_scores < 0] = 0.
    mean_diff_scores = mean_diff_scores.mean()

    # mean rank
    scores_opt, ids_opt = get_scores_ids_per_job(rec, S)
    ranks_opt = get_ranks_per_job(scores_opt)
    ranks_per_user_opt = get_ranks_per_user(ranks_opt, ids_opt)
    mean_rank = np.array(list(ranks_per_user_opt.values())).mean()

    # gini
    gini_index = gini(rec_per_job)

    return {'mean_competitors': mean_competitors, 'mean_better_competitors': mean_better_competitors,
            'mean_scores_diff': mean_diff_scores, 'mean_rank': mean_rank, 'gini_index': gini_index}
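

# Usage sketch (added; matrix contents are illustrative): `res` plays the role of
# the recommender's score matrix used to pick the top-k, R the relevance matrix
# the congestion metrics are computed against.
def _global_metrics_example(m=20, n=30, k=5, seed=0):
    torch.manual_seed(seed)
    scores = torch.rand(m, n)
    relevance = torch.rand(m, n)
    return calculate_global_metrics(scores, relevance, k=k)
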

def get_scores_per_job(rec, S):
    scores_per_job = defaultdict(list)
    for i in range(len(rec)):
        u = rec[i]
        for jb in u:
            jb = jb.item()
            scores_per_job[jb].append(S[i, jb].item())
    return scores_per_job