mjwong committed on
Commit
cb8c696
1 Parent(s): 9320d3c

Update README.md

Files changed (1)
1. README.md +28 -0
README.md CHANGED
@@ -39,6 +39,8 @@ Liang Wang, Nan Yang, Xiaolong Huang, Binxing Jiao, Linjun Yang, Daxin Jiang, Ra
 
 ## How to use the model
 
+ ### With the zero-shot classification pipeline
+
 The model can be loaded with the `zero-shot-classification` pipeline like so:
 
 ```python
@@ -62,6 +64,32 @@ candidate_labels = ["politics", "economy", "entertainment", "environment"]
 classifier(sequence_to_classify, candidate_labels, multi_label=True)
 ```
 
+ ### With manual PyTorch
+
+ The model can also be applied to NLI tasks like so:
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+ # use the GPU ("cuda") if one is available, otherwise fall back to the CPU
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+ model_name = "mjwong/multilingual-e5-large-xnli"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
+
+ premise = "But I thought you'd sworn off coffee."
+ hypothesis = "I thought that you vowed to drink more coffee."
+
+ inputs = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
+ output = model(inputs["input_ids"].to(device))
+ prediction = torch.softmax(output["logits"][0], -1).tolist()
+ label_names = ["entailment", "neutral", "contradiction"]
+ prediction = {name: round(float(pred) * 100, 2) for pred, name in zip(prediction, label_names)}
+ print(prediction)
+ ```
+
 ### Eval results
 The model was evaluated using the XNLI test sets in 15 languages: English (en), Arabic (ar), Bulgarian (bg), German (de), Greek (el), Spanish (es), French (fr), Hindi (hi), Russian (ru), Swahili (sw), Thai (th), Turkish (tr), Urdu (ur), Vietnamese (vi) and Chinese (zh). The metric used is accuracy.
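
Because the diff only shows context around the changed lines, the zero-shot pipeline example is split across the two hunks above. A minimal end-to-end version might look like the sketch below; the model name, the `candidate_labels`, and the final `classifier(...)` call are taken from the diff context, while the `sequence_to_classify` text is only an assumed placeholder.

```python
# Minimal sketch of the zero-shot pipeline example referenced in the diff above.
# The sequence text is an assumed placeholder; the labels and the classifier
# call come from the diff context.
from transformers import pipeline

classifier = pipeline(
    "zero-shot-classification",
    model="mjwong/multilingual-e5-large-xnli",
)

sequence_to_classify = "One day I will see the world."  # assumed example text
candidate_labels = ["politics", "economy", "entertainment", "environment"]

print(classifier(sequence_to_classify, candidate_labels, multi_label=True))
```

With `multi_label=True`, each candidate label is scored independently, so the returned scores need not sum to 1.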
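
The eval section names the XNLI test sets and the accuracy metric but not the evaluation code. A rough, illustrative way to reproduce a single-language score, assuming the `xnli` dataset on the Hugging Face Hub and the `entailment`/`neutral`/`contradiction` label order used in the snippet above (this is a sketch, not the author's evaluation script), is:

```python
# Illustrative sketch: accuracy on one XNLI test split.
# Assumes the "xnli" dataset on the Hugging Face Hub, whose labels are
# 0 = entailment, 1 = neutral, 2 = contradiction, matching the model head.
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = "mjwong/multilingual-e5-large-xnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()

dataset = load_dataset("xnli", "en", split="test")

correct = 0
for example in dataset:
    inputs = tokenizer(example["premise"], example["hypothesis"],
                       truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    if logits.argmax(dim=-1).item() == example["label"]:
        correct += 1

print(f"accuracy: {correct / len(dataset):.4f}")
```

Looping one example at a time keeps the sketch short; batching the tokenizer and model calls would be considerably faster.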