File size: 2,275 Bytes
d7eff13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from collections import Counter
from itertools import chain
import math
import torch
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction


def ngrams(sequence, n):
    """Return every contiguous window of length ``n`` in ``sequence``.

    Each window is converted to a tuple and the windows are listed in the
    order they start. A sequence shorter than ``n`` yields an empty list.
    """
    windows = []
    last_start = len(sequence) - n
    for start in range(last_start + 1):
        windows.append(tuple(sequence[start:start + n]))
    return windows

def count_ngrams(sequence, max_n):
    """Tally all n-grams of ``sequence`` for every order from 1 to ``max_n``.

    Returns a single ``Counter`` keyed by n-gram tuple; n-grams of different
    orders share the same counter.
    """
    orders = range(1, max_n + 1)
    # Feed the n-grams of each order, lowest order first, into one Counter.
    return Counter(chain.from_iterable(ngrams(sequence, order) for order in orders))

def self_bleu(outputs):
    """Average BLEU of each output scored against all the other outputs.

    Every element of ``outputs`` is used once as the hypothesis, with the
    remaining elements as its references, using NLTK's ``method1``
    smoothing. Returns 0 when no element has any reference (fewer than two
    outputs).

    NOTE(review): presumably each output is a token list, as
    ``sentence_bleu`` expects -- confirm against the caller.
    """
    smooth = SmoothingFunction().method1
    running_total = 0
    n_scored = 0
    for idx, hypothesis in enumerate(outputs):
        others = outputs[:idx] + outputs[idx + 1:]
        # Nothing to compare against (only happens with a single output).
        if not others:
            continue
        running_total += sentence_bleu(others, hypothesis, smoothing_function=smooth)
        n_scored += 1
    # No hypothesis could be scored: fall back to the default value.
    if n_scored == 0:
        return 0
    return running_total / n_scored

def dist_n(outputs, n):
    """Return the ratio of distinct n-grams to total n-grams across ``outputs``.

    The n-grams of every output are pooled together before counting.
    Returns 0 when the pool is empty.
    """
    total = 0
    distinct = set()
    for output in outputs:
        grams = ngrams(output, n)
        total += len(grams)
        distinct.update(grams)
    if total == 0:
        return 0
    return len(distinct) / total

def perplexity(model, tokenizer, texts, stride=512):
    """Estimate the perplexity of ``texts`` under ``model`` with a sliding window.

    The texts are tokenized as one padded batch. The token axis is then
    walked in steps of ``stride``; each window may include earlier tokens as
    context (up to ``model.config.n_positions`` in total), but only the
    tokens from the current step onward are scored.

    NOTE(review): assumes a Hugging Face-style causal LM, i.e. labels equal
    to -100 are ignored, labels are shifted by one position inside the model,
    and ``outputs.loss`` is the mean negative log-likelihood over the
    remaining labels -- confirm for other model types.

    Args:
        model: Callable model exposing ``config.n_positions`` and ``device``.
        tokenizer: Callable returning an object with ``input_ids`` and,
            optionally, ``attention_mask`` tensors.
        texts: Text or list of texts to evaluate.
        stride: Number of new tokens scored per window. Capped at the model's
            context size so that no token is skipped.

    Returns:
        The perplexity as a float, or ``nan`` when there is no token to score
        (empty input, or sequences of a single token).

    Raises:
        ValueError: If ``stride`` is not positive.
    """
    if stride <= 0:
        raise ValueError("stride must be a positive integer")

    encodings = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
    all_ids = encodings.input_ids
    # Not every tokenizer returns a mask; without one, nothing is treated as padding.
    all_mask = getattr(encodings, 'attention_mask', None)
    seq_len = all_ids.size(1)
    max_length = model.config.n_positions
    # A stride wider than the context window would leave tokens unscored.
    stride = min(stride, max_length)

    weighted_nlls = []
    n_scored_total = 0
    for i in range(0, seq_len, stride):
        begin_loc = max(i + stride - max_length, 0)
        # Clamp to the real sequence length: the last window is usually short.
        end_loc = min(i + stride, seq_len)
        input_ids = all_ids[:, begin_loc:end_loc].to(model.device)
        target_ids = input_ids.clone()
        # Tokens before position i are context only and must not be scored.
        target_ids[:, :i - begin_loc] = -100

        model_kwargs = {'labels': target_ids}
        if all_mask is not None:
            window_mask = all_mask[:, begin_loc:end_loc].to(model.device)
            # Padding added by the tokenizer is not real text; exclude it.
            target_ids[window_mask == 0] = -100
            model_kwargs['attention_mask'] = window_mask

        # The first column has no preceding token to be predicted from, so
        # only columns 1.. can contribute to the loss.
        n_scored = int((target_ids[:, 1:] != -100).sum().item())
        if n_scored == 0:
            # A mean over zero tokens would be NaN and poison the total.
            continue

        with torch.no_grad():
            outputs = model(input_ids, **model_kwargs)
        # Undo the per-window mean so windows are weighted by token count.
        weighted_nlls.append(outputs.loss * n_scored)
        n_scored_total += n_scored

    if n_scored_total == 0:
        return float('nan')

    ppl = torch.exp(torch.stack(weighted_nlls).sum() / n_scored_total)
    return ppl.item()

def js_divergence(p, q):
    """Return the Jensen-Shannon divergence between two weight vectors.

    ``p`` and ``q`` are each normalised to sum to 1, and the divergence is
    the mean of the two KL divergences to their midpoint distribution,
    computed with natural logarithms.

    Args:
        p: Sequence of non-normalised weights (e.g. counts).
        q: Sequence of weights with the same length as ``p``.

    Returns:
        The divergence as a float (``0.0`` for two empty sequences).

    Raises:
        ValueError: If ``p`` and ``q`` differ in length.
        ZeroDivisionError: If the entries of a non-empty ``p`` or ``q`` sum
            to zero.
    """
    # Positions are compared pairwise; a length mismatch would otherwise be
    # silently truncated or surface as an unrelated IndexError.
    if len(p) != len(q):
        raise ValueError("p and q must have the same number of entries")

    def kl_divergence(a, b):
        # Terms with a zero factor contribute nothing to the sum.
        return sum(x * math.log(x / y) for x, y in zip(a, b) if x != 0 and y != 0)

    # Compute each total once; summing inside the comprehension was O(n^2).
    p_total = float(sum(p))
    q_total = float(sum(q))
    p_norm = [x / p_total for x in p]
    q_norm = [x / q_total for x in q]

    m = [(x + y) / 2 for x, y in zip(p_norm, q_norm)]

    return (kl_divergence(p_norm, m) + kl_divergence(q_norm, m)) / 2