EC2 Default User commited on
Commit
564cdc6
1 Parent(s): ec8533b

Add lora model and custom inference file

Browse files
handler.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import torch
4
+ from typing import List
5
+ from typing import Dict, Any
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria
7
+ import torch
8
+
9
+
10
+ class MyStoppingCriteria(StoppingCriteria):
11
+ def __init__(self, target_sequence, prompt, tokenizer):
12
+ self.target_sequence = target_sequence
13
+ self.prompt = prompt
14
+ self.tokenizer = tokenizer
15
+
16
+ def __call__(self, input_ids, scores, **kwargs):
17
+ # Get the generated text as a string
18
+ generated_text = self.tokenizer.decode(input_ids[0])
19
+ generated_text = generated_text.replace(self.prompt, '')
20
+ # Check if the target sequence appears in the generated text
21
+ if self.target_sequence in generated_text:
22
+ return True # Stop generation
23
+
24
+ return False # Continue generation
25
+
26
+ def __len__(self):
27
+ return 1
28
+
29
+ def __iter__(self):
30
+ yield self
31
+
32
+
33
+ class EndpointHandler:
34
+ def __init__(self, model_dir=""):
35
+ # load model and processor from path
36
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
37
+ self.model = AutoModelForCausalLM.from_pretrained(model_dir, load_in_4bit=True, device_map="auto")
38
+
39
+ self.template = {
40
+ "prompt_input": """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n""",
41
+ "prompt_no_input": """Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n""",
42
+ "response_split": """### Response:"""
43
+ }
44
+ self.instruction = """Extract the start and end sequences for the categories 'personal information', 'work experience', 'education' and 'skills' from the following text in dictionary form"""
45
+
46
+ if torch.cuda.is_available():
47
+ self.device = "cuda"
48
+ else:
49
+ self.device = "cpu"
50
+
51
+
52
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
53
+ """
54
+ Args:
55
+ data (:dict:):
56
+ The payload with the text prompt and generation parameters.
57
+ """
58
+ # process input
59
+ inputs = data.pop("inputs", data)
60
+ parameters = data.pop("parameters", None)
61
+
62
+ res = self.template["prompt_input"].format(
63
+ instruction=self.instruction, input=input
64
+ )
65
+ messages = [
66
+ {"role": "user", "content": res},
67
+ ]
68
+ input_ids = self.tokenizer.apply_chat_template(
69
+ messages, truncation=True, add_generation_prompt=True, return_tensors="pt"
70
+ ).input_ids
71
+ input_ids = input_ids.to(self.device)
72
+
73
+ # pass inputs with all kwargs in data
74
+ if parameters is not None:
75
+ outputs = self.model.generate(
76
+ input_ids=input_ids,
77
+ stopping_criteria=MyStoppingCriteria("</s>", inputs, self.tokenizer),
78
+ **parameters)
79
+ else:
80
+ outputs = self.model.generate(
81
+ input_ids=input_ids, max_new_tokens=32,
82
+ stopping_criteria=MyStoppingCriteria("</s>", inputs, self.tokenizer)
83
+ )
84
+
85
+ # postprocess the prediction
86
+ prediction = self.tokenizer.decode(outputs[0][input_ids.shape[1]:]) #, skip_special_tokens=True)
87
+ prediction = prediction.split("</s>")[0]
88
+
89
+ # TODO: add processing of the LLM output
90
+
91
+ return [{"generated_text": prediction}]
model/adapter_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2",
5
+ "bias": "none",
6
+ "fan_in_fan_out": false,
7
+ "inference_mode": true,
8
+ "init_lora_weights": true,
9
+ "layers_pattern": null,
10
+ "layers_to_transform": null,
11
+ "loftq_config": {},
12
+ "lora_alpha": 16,
13
+ "lora_dropout": 0.05,
14
+ "megatron_config": null,
15
+ "megatron_core": "megatron.core",
16
+ "modules_to_save": [
17
+ "lm_head",
18
+ "embed_tokens"
19
+ ],
20
+ "peft_type": "LORA",
21
+ "r": 8,
22
+ "rank_pattern": {},
23
+ "revision": null,
24
+ "target_modules": [
25
+ "k_proj",
26
+ "o_proj",
27
+ "v_proj",
28
+ "q_proj",
29
+ "lm_head",
30
+ "embed_tokens"
31
+ ],
32
+ "task_type": "CAUSAL_LM",
33
+ "use_rslora": false
34
+ }
model/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e8755b8e8b0e194e4db8d9cbcb6e7f81cbc205a286f7a99de201b60345371dc
3
+ size 3173223624
model/added_tokens.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ ", 'e':": 32001,
3
+ "{'s': '": 32000
4
+ }
model/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "</s>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
model/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
3
+ size 493443
model/tokenizer_config.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "32000": {
30
+ "content": "{'s': '",
31
+ "lstrip": false,
32
+ "normalized": true,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": false
36
+ },
37
+ "32001": {
38
+ "content": ", 'e':",
39
+ "lstrip": false,
40
+ "normalized": true,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": false
44
+ }
45
+ },
46
+ "additional_special_tokens": [],
47
+ "bos_token": "<s>",
48
+ "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
49
+ "clean_up_tokenization_spaces": false,
50
+ "eos_token": "</s>",
51
+ "legacy": true,
52
+ "model_max_length": 1000000000000000019884624838656,
53
+ "pad_token": "</s>",
54
+ "sp_model_kwargs": {},
55
+ "spaces_between_special_tokens": false,
56
+ "tokenizer_class": "LlamaTokenizer",
57
+ "unk_token": "<unk>",
58
+ "use_default_system_prompt": false
59
+ }