Update README.md
Browse files
README.md
CHANGED
@@ -34,19 +34,30 @@ import torch
|
|
34 |
tokenizer = BertTokenizer.from_pretrained('suayptalha/medBERT-base')
|
35 |
model = BertForMaskedLM.from_pretrained('suayptalha/medBERT-base').to("cuda")
|
36 |
|
37 |
-
input_text = "
|
38 |
-
|
39 |
-
|
40 |
-
inputs = tokenizer(masked_text, return_tensors='pt').to("cuda")
|
41 |
|
42 |
outputs = model(**inputs)
|
43 |
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
-
|
47 |
-
|
|
|
48 |
'''
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
### **Fine-tuning the Model**
|
51 |
|
52 |
To fine-tune the **medBERT-base** model on your own medical dataset, follow these steps:
|
|
|
34 |
tokenizer = BertTokenizer.from_pretrained('suayptalha/medBERT-base')
|
35 |
model = BertForMaskedLM.from_pretrained('suayptalha/medBERT-base').to("cuda")
|
36 |
|
37 |
+
input_text = "Response to neoadjuvant chemotherapy best predicts survival [MASK] curative resection of gastric cancer."
|
38 |
+
# Tokenize the sentence into PyTorch tensors ('pt') and move them to the same CUDA device as the model.
inputs = tokenizer(input_text, return_tensors='pt').to("cuda")
|
|
|
|
|
39 |
|
40 |
# Inference only: disable autograd so no gradient graph is built for the forward pass.
with torch.no_grad():
    outputs = model(**inputs)
|
41 |
|
42 |
+
# Position of the [MASK] token within the first sequence of the batch.
# `.item()` only works on a one-element tensor, so this raises if the input
# contains zero or more than one [MASK] token.
masked_index = (inputs['input_ids'][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].item()
|
43 |
+
|
44 |
+
top_k = 5
|
45 |
+
# Prediction scores at the masked position of the first sequence
# (presumably one score per vocabulary entry — confirm against the model's output docs).
logits = outputs.logits[0, masked_index]
|
46 |
+
# Indices of the `top_k` largest scores (torch.topk returns them sorted, largest first, by default),
# converted to a plain Python list of ints.
top_k_ids = torch.topk(logits, k=top_k).indices.tolist()
|
47 |
+
# Map the selected ids back to their token strings so they can be printed.
top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_ids)
|
48 |
|
49 |
+
# Build the header from top_k so it stays accurate if the number of predictions is changed.
print(f"Top {top_k} prediction:")
|
50 |
+
for i, token in enumerate(top_k_tokens):
|
51 |
+
print(f"{i + 1}: {token}")
|
52 |
'''
|
53 |
|
54 |
+
_Top 5 prediction:_
|
55 |
+
_1: from_
|
56 |
+
_2: of_
|
57 |
+
_3: after_
|
58 |
+
_4: by_
|
59 |
+
_5: through_
|
60 |
+
|
61 |
### **Fine-tuning the Model**
|
62 |
|
63 |
To fine-tune the **medBERT-base** model on your own medical dataset, follow these steps:
|