Update modeling_codeshell.py
modeling_codeshell.py CHANGED (+121 −5)
@@ -32,14 +32,17 @@
 """PyTorch CodeShell model."""
 import os
 import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Callable
+from threading import Thread
+from queue import Queue
+
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from transformers import PreTrainedModel, PretrainedConfig
+from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel, PretrainedConfig
 from transformers.generation.utils import GenerationConfig
 
 from transformers.activations import ACT2FN
@@ -54,7 +57,6 @@ from transformers.utils import (
 )
 from .configuration_codeshell import CodeShellConfig
 
-
 # Fused kernels
 # Use separate functions for each case because conditionals prevent kernel fusion.
 # TODO: Could have better fused kernels depending on scaling, dropout and head mask.
@@ -743,6 +745,62 @@ class CodeShellModel(CodeShellPreTrainedModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
         )
+
+class EndOfFunctionCriteria(StoppingCriteria):
+    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""
+    def __init__(self, input_lengths, eof_strings, tokenizer):
+        self.input_lengths = input_lengths
+        self.eof_strings = eof_strings
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs):
+        """Returns true if all generated sequences contain any of the end-of-function strings."""
+        decoded_generations = []
+        for _input_ids, input_length in zip(input_ids, self.input_lengths):
+            decoded_generations.append(self.tokenizer.decode(_input_ids[input_length:]))
+        done = []
+        for decoded_generation in decoded_generations:
+            done.append(
+                any(
+                    [
+                        stop_string in decoded_generation
+                        for stop_string in self.eof_strings
+                    ]
+                )
+            )
+        return all(done)
+
+class TextIterStreamer:
+    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.skip_special_tokens = skip_special_tokens
+        self.tokens = []
+        self.text_queue = Queue()
+        self.next_tokens_are_prompt = True
+
+    def put(self, value):
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+        else:
+            if len(value.shape) > 1:
+                value = value[0]
+            self.tokens.extend(value.tolist())
+            self.text_queue.put(
+                self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
+
+    def end(self):
+        self.text_queue.put(None)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.text_queue.get()
+        if value is None:
+            raise StopIteration()
+        else:
+            return value
 
 
 @add_start_docstrings(
@@ -886,6 +944,65 @@ class CodeShellForCausalLM(CodeShellPreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
+
+    def build_chat_input(self, query, history, tokenizer, max_new_tokens=None):
+        user_name = "\n## human:"
+        ai_name = "\n## assistant: "
+        stop = '|<end>|'
+
+        prompt = ''
+        for q, r in history:
+            prompt += f"{user_name}{q}{stop}"
+            prompt += f"{ai_name}{r}{stop}"
+        prompt += f"{user_name}{query}{stop}"
+        prompt += ai_name.rstrip()
+
+        max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
+        max_input_tokens = self.config.n_positions - max_new_tokens
+
+        input_tokens = tokenizer.encode(prompt)
+        input_tokens = input_tokens[-max_input_tokens:]  # truncate left
+        return torch.LongTensor([input_tokens]).to(self.device)
+
+    def chat(self, query, history, tokenizer, stream=False,
+             generation_config: Optional[GenerationConfig]=None):
+        generation_config = generation_config or self.generation_config
+        input_ids = self.build_chat_input(query, history, tokenizer, generation_config.max_new_tokens)
+        stopping_criteria = StoppingCriteriaList(
+            [EndOfFunctionCriteria([len(input_ids[0])], ['|<end>|', '<|endoftext|>'], tokenizer)]
+        )
+
+        if stream:
+            streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+            Thread(target=self.generate, kwargs=dict(
+                inputs=input_ids, streamer=streamer,
+                stopping_criteria=stopping_criteria,
+                generation_config=generation_config,
+            )).start()
+            return streamer
+        else:
+            outputs = self.generate(input_ids, generation_config=generation_config, stopping_criteria=stopping_criteria)
+            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
+            return response
+
+    def generate_stream(self, prompt, tokenizer, generation_config=None, **kwargs):
+        generation_config = generation_config or self.generation_config
+        max_input_tokens = self.config.n_positions - self.generation_config.max_new_tokens
+
+        input_ids = tokenizer.encode(prompt)
+        input_ids = input_ids[-max_input_tokens:]  # truncate left
+
+        stopping_criteria = StoppingCriteriaList(
+            [EndOfFunctionCriteria([len(input_ids[0])], ['|<end>|', '<|endoftext|>'], tokenizer)]
+        )
+
+        streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        Thread(target=self.generate, kwargs=dict(
+            inputs=input_ids, stopping_criteria=stopping_criteria, **kwargs
+        )).start()
+        return streamer
+
 
 class CodeShell4bitForCausalLM(CodeShellForCausalLM):
     def __init__(self, config):
@@ -966,5 +1083,4 @@ class CodeShell4bitForCausalLM(CodeShellForCausalLM):
         if device_map is not None:
             model = model.to(torch.device(device_map))
 
-        return model
-
+        return model
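
For context, the chat API added in this commit is intended to be driven from user code roughly as sketched below. This is a minimal sketch, not part of the diff: the checkpoint id "WisdomShell/CodeShell-7B-Chat" and the trust_remote_code loading path are assumptions for illustration, and chat() does not update the history itself, so the caller maintains the list of (query, response) pairs.

# Minimal usage sketch (assumed checkpoint id and loading path, see note above).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("WisdomShell/CodeShell-7B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "WisdomShell/CodeShell-7B-Chat",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).cuda()

history = []  # list of (query, response) pairs, maintained by the caller

# Non-streaming: chat() returns the decoded assistant reply as a string.
query = "Write a Python function that reverses a string."
reply = model.chat(query, history, tokenizer)
history.append((query, reply))

# Streaming: with stream=True, chat() returns a TextIterStreamer; generation
# runs in a background thread and each iteration yields the accumulated text.
for text in model.chat("Now add type hints.", history, tokenizer, stream=True):
    print(text, end="\r", flush=True)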