VibhuJawa commited on
Commit
10349c9
1 Parent(s): 74cc0f0

Add Readme.md to model card

Browse files
Files changed (1) hide show
  1. README.md +42 -3
README.md CHANGED
@@ -3,7 +3,46 @@ tags:
3
  - pytorch_model_hub_mixin
4
  - model_hub_mixin
5
  ---
 
6
 
7
- This model has been pushed to the Hub using ****:
8
- - Repo: [More Information Needed]
9
- - Docs: [More Information Needed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  - pytorch_model_hub_mixin
4
  - model_hub_mixin
5
  ---
6
+ # nvidia/domain-classifier
7
 
8
+ This repository contains the code for the domain classifier model.
9
+
10
+ # How to use in transformers
11
+ To use the Domain classifier, use the following code:
12
+
13
+ ```python3
14
+
15
+ import torch
16
+ from torch import nn
17
+ from transformers import AutoModel, AutoTokenizer, AutoConfig
18
+ from huggingface_hub import PyTorchModelHubMixin
19
+
20
+ class CustomModel(nn.Module, PyTorchModelHubMixin):
21
+ def __init__(self, config):
22
+ super(CustomModel, self).__init__()
23
+ self.model = AutoModel.from_pretrained(config['base_model'])
24
+ self.dropout = nn.Dropout(config['fc_dropout'])
25
+ self.fc = nn.Linear(self.model.config.hidden_size, len(config['id2label']))
26
+
27
+ def forward(self, input_ids, attention_mask):
28
+ features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
29
+ dropped = self.dropout(features)
30
+ outputs = self.fc(dropped)
31
+ return torch.softmax(outputs[:, 0, :], dim=1)
32
+
33
+ # Setup configuration and model
34
+ config = AutoConfig.from_pretrained("nvidia/domain-classifier")
35
+ tokenizer = AutoTokenizer.from_pretrained("nvidia/domain-classifier")
36
+ model = CustomModel.from_pretrained("nvidia/domain-classifier")
37
+
38
+ # Prepare and process inputs
39
+ text_samples = ["Sports is a popular domain", "Politics is a popular domain"]
40
+ inputs = tokenizer(text_samples, return_tensors="pt", padding="longest", truncation=True)
41
+ outputs = model(inputs['input_ids'], inputs['attention_mask'])
42
+
43
+ # Predict and display results
44
+ predicted_classes = torch.argmax(outputs, dim=1)
45
+ predicted_domains = [config.id2label[class_idx.item()] for class_idx in predicted_classes.cpu().numpy()]
46
+ print(predicted_domains)
47
+ # ['Sports', 'News']
48
+ ```