suayptalha commited on
Commit
0a771f1
·
verified ·
1 Parent(s): b1157c4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +18 -7
README.md CHANGED
@@ -34,19 +34,30 @@ import torch
34
  tokenizer = BertTokenizer.from_pretrained('suayptalha/medBERT-base')
35
  model = BertForMaskedLM.from_pretrained('suayptalha/medBERT-base').to("cuda")
36
 
37
- input_text = "The patient was diagnosed with gastric cancer after a thorough examination."
38
- masked_text = input_text.replace("gastric cancer", tokenizer.mask_token)
39
-
40
- inputs = tokenizer(masked_text, return_tensors='pt').to("cuda")
41
 
42
  outputs = model(**inputs)
43
 
44
- predicted_token_id = torch.argmax(outputs.logits, dim=-1)
 
 
 
 
 
45
 
46
- predicted_token = tokenizer.decode(predicted_token_id[0, inputs['input_ids'].shape[1] - 1])
47
- print(predicted_token)
 
48
  '''
49
 
 
 
 
 
 
 
 
50
  ### **Fine-tuning the Model**
51
 
52
  To fine-tune the **medBERT-base** model on your own medical dataset, follow these steps:
 
34
  tokenizer = BertTokenizer.from_pretrained('suayptalha/medBERT-base')
35
  model = BertForMaskedLM.from_pretrained('suayptalha/medBERT-base').to("cuda")
36
 
37
+ input_text = "Response to neoadjuvant chemotherapy best predicts survival [MASK] curative resection of gastric cancer."
38
+ inputs = tokenizer(input_text, return_tensors='pt').to("cuda")
 
 
39
 
40
  outputs = model(**inputs)
41
 
42
+ masked_index = (inputs['input_ids'][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].item()
43
+
44
+ top_k = 5
45
+ logits = outputs.logits[0, masked_index]
46
+ top_k_ids = torch.topk(logits, k=top_k).indices.tolist()
47
+ top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_ids)
48
 
49
+ print("Top 5 prediction:")
50
+ for i, token in enumerate(top_k_tokens):
51
+ print(f"{i + 1}: {token}")
52
  '''
53
 
54
+ _Top 5 prediction:_
55
+ _1: from_
56
+ _2: of_
57
+ _3: after_
58
+ _4: by_
59
+ _5: through_
60
+
61
  ### **Fine-tuning the Model**
62
 
63
  To fine-tune the **medBERT-base** model on your own medical dataset, follow these steps: