liamcripwell
commited on
Commit
•
94dce2c
1
Parent(s):
3b501c3
Update README.md
Browse files
README.md
CHANGED
@@ -91,4 +91,76 @@ template = """{
|
|
91 |
prediction = predict_NuExtract(model, tokenizer, [text], template)[0]
|
92 |
print(prediction)
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
```
|
|
|
91 |
prediction = predict_NuExtract(model, tokenizer, [text], template)[0]
|
92 |
print(prediction)
|
93 |
|
94 |
+
```
|
95 |
+
|
96 |
+
Sliding window prompting:
|
97 |
+
|
98 |
+
```python
|
99 |
+
import json
|
100 |
+
|
101 |
+
MAX_INPUT_SIZE = 20_000
|
102 |
+
MAX_NEW_TOKENS = 6000
|
103 |
+
|
104 |
+
def clean_json_text(text):
|
105 |
+
text = text.strip()
|
106 |
+
text = text.replace("\#", "#").replace("\&", "&")
|
107 |
+
return text
|
108 |
+
|
109 |
+
def predict_chunk(text, template, current, model, tokenizer):
|
110 |
+
current = clean_json_text(current)
|
111 |
+
|
112 |
+
input_llm = f"<|input|>\n### Template:\n{template}\n### Current:\n{current}\n### Text:\n{text}\n\n<|output|>" + "{"
|
113 |
+
input_ids = tokenizer(input_llm, return_tensors="pt", truncation=True, max_length=MAX_INPUT_SIZE).to("cuda")
|
114 |
+
output = tokenizer.decode(model.generate(**input_ids, max_new_tokens=MAX_NEW_TOKENS)[0], skip_special_tokens=True)
|
115 |
+
|
116 |
+
return clean_json_text(output.split("<|output|>")[1])
|
117 |
+
|
118 |
+
def split_document(document, window_size, overlap):
|
119 |
+
tokens = tokenizer.tokenize(document)
|
120 |
+
print(f"\tLength of document: {len(tokens)} tokens")
|
121 |
+
|
122 |
+
chunks = []
|
123 |
+
if len(tokens) > window_size:
|
124 |
+
for i in range(0, len(tokens), window_size-overlap):
|
125 |
+
print(f"\t{i} to {i + len(tokens[i:i + window_size])}")
|
126 |
+
chunk = tokenizer.convert_tokens_to_string(tokens[i:i + window_size])
|
127 |
+
chunks.append(chunk)
|
128 |
+
|
129 |
+
if i + len(tokens[i:i + window_size]) >= len(tokens):
|
130 |
+
break
|
131 |
+
else:
|
132 |
+
chunks.append(document)
|
133 |
+
print(f"\tSplit into {len(chunks)} chunks")
|
134 |
+
|
135 |
+
return chunks
|
136 |
+
|
137 |
+
def handle_broken_output(pred, prev):
|
138 |
+
try:
|
139 |
+
if all([(v in ["", []]) for v in json.loads(pred).values()]):
|
140 |
+
# if empty json, return previous
|
141 |
+
pred = prev
|
142 |
+
except:
|
143 |
+
# if broken json, return previous
|
144 |
+
pred = prev
|
145 |
+
|
146 |
+
return pred
|
147 |
+
|
148 |
+
def sliding_window_prediction(text, template, model, tokenizer, window_size=4000, overlap=128):
|
149 |
+
# split text into chunks of n tokens
|
150 |
+
tokens = tokenizer.tokenize(text)
|
151 |
+
chunks = split_document(text, window_size, overlap)
|
152 |
+
|
153 |
+
# iterate over text chunks
|
154 |
+
prev = template
|
155 |
+
for i, chunk in enumerate(chunks):
|
156 |
+
print(f"Processing chunk {i}...")
|
157 |
+
pred = predict_chunk(chunk, template, prev, model, tokenizer)
|
158 |
+
|
159 |
+
# handle broken output
|
160 |
+
pred = handle_broken_output(pred, prev)
|
161 |
+
|
162 |
+
# iterate
|
163 |
+
prev = pred
|
164 |
+
|
165 |
+
return pred
|
166 |
```
|