import json
import re

import torch
from nltk.tokenize import PunktSentenceTokenizer
from vllm import LLM, SamplingParams

class LongCiteModel(LLM):
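    """vllm.LLM subclass implementing the LongCite pipeline: chunk a long
    context into <C{i}>-tagged sentences, query the model, and parse the
    <statement>/<cite> markup in its response into sentence-level citations."""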

    @torch.inference_mode()
    def chat(self, tokenizer, query: str, history=None, role="user",
             max_new_tokens=None, top_p=0.7, temperature=0.95):
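        """Run one chat turn through vLLM, mirroring the ChatGLM chat API.

        Returns the decoded response and the updated conversation history.
        """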
        if history is None:
            history = []
        # build_chat_input and get_command come from the ChatGLM tokenizer
        # (loaded with trust_remote_code); the role tokens also act as stops
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        generation_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_new_tokens,
            stop_token_ids=eos_token_id,
        )
        input_ids = inputs.input_ids[0].tolist()
        outputs = self.generate(sampling_params=generation_params, prompt_token_ids=[input_ids])
        # Drop the trailing stop token before decoding
        response = tokenizer.decode(outputs[0].outputs[0].token_ids[:-1])
        # Record both turns so the returned history can be reused in later calls
        history.append({"role": role, "content": query})
        history.append({"role": "assistant", "content": response})
        return response, history

    def query_longcite(self, context, query, tokenizer, max_input_length=128000, max_new_tokens=1024, temperature=0.95):
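        """Answer `query` over `context` and return sentence-level citations.

        The context is split into <C{i}>-tagged chunks, the model is prompted
        to cite chunk ranges, and the response markup is parsed into
        statement/citation pairs.
        """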

        def text_split_by_punctuation(original_text, return_dict=False):
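            """Split text into sentences with NLTK Punkt plus CJK punctuation.

            With return_dict=True, also return each sentence's character
            offsets within the original text.
            """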
            # text = re.sub(r'([a-z])\.([A-Z])', r'\1. \2', original_text)  # separate period without space
            text = original_text
            custom_sent_tokenizer = PunktSentenceTokenizer(text)
            punctuations = r"([。;!?])"  # For Chinese support

            separated = custom_sent_tokenizer.tokenize(text)
            separated = sum([re.split(punctuations, s) for s in separated], [])
            # Re-attach each punctuation mark to the sentence that precedes it
            for i in range(1, len(separated)):
                if re.match(punctuations, separated[i]):
                    separated[i-1] += separated[i]
                    separated[i] = ''

            separated = [s for s in separated if s != ""]
            # Fall back to paragraph breaks if no sentence boundary was found
            if len(separated) == 1:
                separated = original_text.split('\n\n')
            separated = [s.strip() for s in separated if s.strip() != ""]
            if not return_dict:
                return separated
            else:
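                # Locate each sentence in the original text to record char offsets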
                pos = 0
                res = []
                for i, sent in enumerate(separated):
                    st = original_text.find(sent, pos)
                    assert st != -1, sent
                    ed = st + len(sent)
                    res.append(
                        {
                            'c_idx': i,
                            'content': sent,
                            'start_idx': st,
                            'end_idx': ed,
                        }
                    )
                    pos = ed
                return res 

        def get_prompt(context, question):
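            """Prefix each context chunk with a <C{i}> marker and embed the
            marked-up document in the citation-instruction prompt."""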
            sents = text_split_by_punctuation(context, return_dict=True)
            splited_context = ""
            for i, s in enumerate(sents):
                st, ed = s['start_idx'], s['end_idx']
                assert s['content'] == context[st:ed], s
                # Extend each chunk to the start of the next one so that
                # inter-sentence whitespace is preserved in the markup
                ed = sents[i+1]['start_idx'] if i < len(sents)-1 else len(context)
                sents[i] = {
                    'content': context[st:ed],
                    'start': st,
                    'end': ed,
                    'c_idx': s['c_idx'],
                }
                splited_context += f"<C{i}>"+context[st:ed]
            prompt = '''Please answer the user's question based on the following document. When a sentence S in your response uses information from some chunks in the document (i.e., <C{s1}>-<C{e1}>, <C{s2}>-<C{e2}>, ...), please append these chunk numbers to S in the format "<statement>{S}<cite>[{s1}-{e1}][{s2}-{e2}]...</cite></statement>". You must answer in the same language as the user's question.\n\n[Document Start]\n%s\n[Document End]\n\n%s''' % (splited_context, question)
            return prompt, sents, splited_context
        
        def get_citations(statement, sents):
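            """Parse <cite>[s-e]...</cite> spans out of a statement.

            Returns the statement with the cite tags stripped, plus up to
            three merged citation ranges with sentence and character offsets.
            """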
            c_texts = re.findall(r'<cite>(.*?)</cite>', statement, re.DOTALL)
            spans = sum([re.findall(r"\[([0-9]+\-[0-9]+)\]", c_text, re.DOTALL) for c_text in c_texts], [])
            statement = re.sub(r'<cite>(.*?)</cite>', '', statement, flags=re.DOTALL)
            merged_citations = []
            for i, s in enumerate(spans):
                try:
                    st, ed = [int(x) for x in s.split('-')]
                    # Skip malformed ranges, then clamp to valid sentence indices
                    if st > len(sents) - 1 or ed < st:
                        continue
                    st, ed = max(0, st), min(ed, len(sents)-1)
                    assert st <= ed, str(c_texts) + '\t' + str(len(sents))
                    # Merge with the previous citation when the ranges are contiguous
                    if len(merged_citations) > 0 and st == merged_citations[-1]['end_sentence_idx'] + 1:
                        merged_citations[-1].update({
                            "end_sentence_idx": ed,
                            'end_char_idx': sents[ed]['end'],
                            'cite': ''.join([x['content'] for x in sents[merged_citations[-1]['start_sentence_idx']:ed+1]]),
                        })
                    else:
                        merged_citations.append({
                            "start_sentence_idx": st,
                            "end_sentence_idx": ed,
                            "start_char_idx":  sents[st]['start'],
                            'end_char_idx': sents[ed]['end'],
                            'cite': ''.join([x['content'] for x in sents[st:ed+1]]),
                        })
                except Exception:
                    print(c_texts, len(sents), statement)
                    raise
            # Keep at most three citation spans per statement
            return statement, merged_citations[:3]
        
        def postprocess(answer, sents, splited_context):
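            """Split the raw answer into <statement> blocks and loose text,
            attaching parsed citations to each statement."""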
            res = []
            pos = 0
            new_answer = ""
            while True:
                st = answer.find("<statement>", pos)
                if st == -1:
                    st = len(answer)
                ed = answer.find("</statement>", st)
                # Loose text before the next <statement> tag: wrap substantial
                # fragments as uncited statements, pass short scraps through
                statement = answer[pos:st]
                if len(statement.strip()) > 5:
                    res.append({
                        "statement": statement,
                        "citation": []
                    })
                    new_answer += f"<statement>{statement}<cite></cite></statement>"
                else:
                    res.append({
                        "statement": statement,
                        "citation": None,
                    })
                    new_answer += statement
                
                if ed == -1:
                    break

                statement = answer[st+len("<statement>"):ed]
                if len(statement.strip()) > 0:
                    statement, citations = get_citations(statement, sents)
                    res.append({
                        "statement": statement,
                        "citation": citations
                    })
                    c_str = ''.join(['[{}-{}]'.format(c['start_sentence_idx'], c['end_sentence_idx']) for c in citations])
                    new_answer += f"<statement>{statement}<cite>{c_str}</cite></statement>"
                else:
                    res.append({
                        "statement": statement,
                        "citation": None,
                    })
                    new_answer += statement
                pos = ed + len("</statement>")
            return {
                "answer": new_answer.strip(),
                "statements_with_citations": [x for x in res if x['citation'] is not None],
                "splited_context": splited_context.strip(),
                "all_statements": res,
            }

        def truncate_from_middle(prompt, max_input_length=None, tokenizer=None):
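            """Token-truncate the prompt from the middle, keeping the first
            and last max_input_length/2 tokens."""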
            if max_input_length is None:
                return prompt
            else:
                assert tokenizer is not None
                tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
                if len(tokenized_prompt) > max_input_length:
                    half = int(max_input_length/2)
                    prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
                return prompt
        
        prompt, sents, splited_context = get_prompt(context, query)
        prompt = truncate_from_middle(prompt, max_input_length, tokenizer)
        output, _ = self.chat(tokenizer, prompt, history=[], max_new_tokens=max_new_tokens, temperature=temperature)
        result = postprocess(output, sents, splited_context)
        return result


if __name__ == "__main__":
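    # Demo: load LongCite-glm4-9b through vLLM and run one cited QA query.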
    model_path = "THUDM/LongCite-glm4-9b"
    model = LongCiteModel(
        model=model_path,
        dtype=torch.bfloat16,
        trust_remote_code=True,
        tensor_parallel_size=1,
        max_model_len=131072,
        gpu_memory_utilization=1,
    )
    tokenizer = model.get_tokenizer()

    context = '''
    W. Russell Todd, 94, United States Army general (b. 1928). February 13. Tim Aymar, 59, heavy metal singer (Pharaoh) (b. 1963). Marshall \"Eddie\" Conway, 76, Black Panther Party leader (b. 1946). Roger Bonk, 78, football player (North Dakota Fighting Sioux, Winnipeg Blue Bombers) (b. 1944). Conrad Dobler, 72, football player (St. Louis Cardinals, New Orleans Saints, Buffalo Bills) (b. 1950). Brian DuBois, 55, baseball player (Detroit Tigers) (b. 1967). Robert Geddes, 99, architect, dean of the Princeton University School of Architecture (1965–1982) (b. 1923). Tom Luddy, 79, film producer (Barfly, The Secret Garden), co-founder of the Telluride Film Festival (b. 1943). David Singmaster, 84, mathematician (b. 1938).
    '''
    query = "What was Robert Geddes' profession?"
    result = model.query_longcite(context, query, tokenizer=tokenizer, max_input_length=128000, max_new_tokens=1024)
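    # result keys (built in postprocess): 'answer' (text with <statement>/<cite>
    # markup), 'statements_with_citations', 'splited_context' (the <C{i}>-tagged
    # context), and 'all_statements'.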

    print("Answer:")
    print(result['answer'])
    print('\n')
    print("Statement with citations:" )
    print(json.dumps(result['statements_with_citations'], indent=2, ensure_ascii=False))
    print('\n')
    print("Context (divided into sentences):")
    print(result['splited_context'])