juierror commited on
Commit
69da55a
1 Parent(s): ba82509

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +26 -0
README.md CHANGED
@@ -1,3 +1,29 @@
1
  ---
2
  license: mit
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  ---
4
+
5
+ # How to use
6
+ ```python
7
+ from typing import List
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained("juierror/text-to-sql-with-table-schema")
11
+ model = AutoModelForSeq2SeqLM.from_pretrained("juierror/text-to-sql-with-table-schema")
12
+
13
+ def prepare_input(question: str, table: List[str]):
14
+ table_prefix = "table:"
15
+ question_prefix = "question:"
16
+ join_table = ",".join(table)
17
+ inputs = f"{question_prefix} {question} {table_prefix} {join_table}"
18
+ input_ids = tokenizer(inputs, max_length=700, return_tensors="pt").input_ids
19
+ return input_ids
20
+
21
+ def inference(question: str, table: List[str]) -> str:
22
+ input_data = prepare_input(question=question, table=table)
23
+ input_data = input_data.to(model.device)
24
+ outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=700)
25
+ result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
26
+ return result
27
+
28
+ print(inference(question="get people name with age equal 25", table=["id", "name", "age"]))
29
+ ```