cnmoro commited on
Commit
d3ffb56
·
verified ·
1 Parent(s): 64167fd

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +116 -0
README.md ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ - pt
6
+ base_model:
7
+ - cnmoro/tangled-llama-33m-32k-instruct-v0.1-fix
8
+ ---
9
+
10
+ ```python
11
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
+ import torch
13
+
14
+ model_id = "cnmoro/BertMini-Reranker-EnPt"
15
+ model = AutoModelForSequenceClassification.from_pretrained(
16
+ model_id,
17
+ num_labels=2
18
+ )
19
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ model.to(device)
22
+
23
+ template = "Query: {query}\nSentence: {document}"
24
+
25
+ def rank(query, documents, normalize_scores=True):
26
+ texts = [template.format(query=query, document=document) for document in documents]
27
+
28
+ inputs = tokenizer(
29
+ texts,
30
+ add_special_tokens=True,
31
+ max_length=512,
32
+ truncation=True,
33
+ padding=True,
34
+ return_tensors="pt",
35
+ )
36
+
37
+ input_ids = inputs["input_ids"].to(device)
38
+ attention_mask = inputs["attention_mask"].to(device)
39
+
40
+ model.eval()
41
+ with torch.no_grad():
42
+ outputs = model(input_ids, attention_mask=attention_mask)
43
+ logits = outputs.logits
44
+ probabilities = torch.softmax(logits, dim=1)
45
+
46
+ # Get the predicted classes and confidence scores
47
+ predicted_classes = torch.argmax(probabilities, dim=1).tolist()
48
+ confidences = probabilities.max(dim=1).values.tolist()
49
+
50
+ # Construct the results
51
+ results = [
52
+ {"prediction": pred, "confidence": conf}
53
+ for pred, conf in zip(predicted_classes, confidences)
54
+ ]
55
+
56
+ final_results = []
57
+ for document, result in zip(documents, results):
58
+ # If the prediction is 0, then get the score as 1 - confidence
59
+ if result['prediction'] == 0:
60
+ result['confidence'] = 1 - result['confidence']
61
+ final_results.append((document, result['confidence']))
62
+
63
+ # Sort by the confidence score, descending
64
+ sorted_results = sorted(final_results, key=lambda x: x[1], reverse=True)
65
+
66
+ if normalize_scores:
67
+ total_score = sum([result[1] for result in sorted_results])
68
+ sorted_results = [(result[0], result[1] / total_score) for result in sorted_results]
69
+
70
+ return sorted_results
71
+
72
+ # Sample - 1
73
+ query = "O que é o Pantanal?"
74
+ documents = [
75
+ "É um dos ecossistemas mais ricos em biodiversidade do mundo, abrigando uma grande variedade de espécies animais e vegetais.",
76
+ "Sua beleza natural, com rios e lagos interligados, atrai turistas de todo o mundo.",
77
+ "O Pantanal sofre com impactos ambientais, como a exploração mineral e o desmatamento.",
78
+ "O Pantanal é uma extensa planície alagável localizada na América do Sul, principalmente no Brasil, mas também em partes da Bolívia e Paraguai.",
79
+ "É um local com importância histórica e cultural para as populações locais.",
80
+ "O Pantanal é um importante habitat para diversas espécies de animais, inclusive aves migratórias."
81
+ ]
82
+ rank(query, documents)
83
+ # [('O Pantanal é um importante habitat para diversas espécies de animais, inclusive aves migratórias.',
84
+ # 0.39881916829816605),
85
+ # ('O Pantanal é uma extensa planície alagável localizada na América do Sul, principalmente no Brasil, mas também em partes da Bolívia e Paraguai.',
86
+ # 0.37527216160662785),
87
+ # ('O Pantanal sofre com impactos ambientais, como a exploração mineral e o desmatamento.',
88
+ # 0.1491597234932686),
89
+ # ('É um local com importância histórica e cultural para as populações locais.',
90
+ # 0.03648153259324298),
91
+ # ('Sua beleza natural, com rios e lagos interligados, atrai turistas de todo o mundo.',
92
+ # 0.020711667666201344),
93
+ # ('É um dos ecossistemas mais ricos em biodiversidade do mundo, abrigando uma grande variedade de espécies animais e vegetais.',
94
+ # 0.019555746342493203)]
95
+
96
+ # Sample - 2
97
+ query = "What is the speed of light?"
98
+ documents = [
99
+ "Isaac Newton's laws of motion and gravity laid the groundwork for classical mechanics.",
100
+ "The theory of relativity, proposed by Albert Einstein, has revolutionized our understanding of space, time, and gravity.",
101
+ "The Earth orbits the Sun at an average distance of about 93 million miles, taking roughly 365.25 days to complete one revolution.",
102
+ "The speed of light in a vacuum is approximately 299,792 kilometers per second (km/s), or about 186,282 miles per second.",
103
+ "Light can be described as both a wave and a particle, a concept known as wave-particle duality."
104
+ ]
105
+ rank(query, documents)
106
+ # [('The speed of light in a vacuum is approximately 299,792 kilometers per second (km/s), or about 186,282 miles per second.',
107
+ # 0.23310426435074763),
108
+ # ('The Earth orbits the Sun at an average distance of about 93 million miles, taking roughly 365.25 days to complete one revolution.',
109
+ # 0.22329693015953184),
110
+ # ("Isaac Newton's laws of motion and gravity laid the groundwork for classical mechanics.",
111
+ # 0.20374707681001922),
112
+ # ('The theory of relativity, proposed by Albert Einstein, has revolutionized our understanding of space, time, and gravity.',
113
+ # 0.20284618746671068),
114
+ # ('Light can be described as both a wave and a particle, a concept known as wave-particle duality.',
115
+ # 0.13700554121299063)]
116
+ ```