ga89tiy commited on
Commit
6edd88e
1 Parent(s): 1db0e44
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LLAVA_Biovil/llava/eval/__init__.py +0 -0
  2. LLAVA_Biovil/llava/eval/eval_gpt_review.py +0 -113
  3. LLAVA_Biovil/llava/eval/eval_gpt_review_bench.py +0 -121
  4. LLAVA_Biovil/llava/eval/eval_gpt_review_visual.py +0 -118
  5. LLAVA_Biovil/llava/eval/eval_pope.py +0 -81
  6. LLAVA_Biovil/llava/eval/eval_science_qa.py +0 -114
  7. LLAVA_Biovil/llava/eval/eval_science_qa_gpt4.py +0 -104
  8. LLAVA_Biovil/llava/eval/eval_science_qa_gpt4_requery.py +0 -149
  9. LLAVA_Biovil/llava/eval/eval_textvqa.py +0 -65
  10. LLAVA_Biovil/llava/eval/generate_webpage_data_from_table.py +0 -111
  11. LLAVA_Biovil/llava/eval/m4c_evaluator.py +0 -334
  12. LLAVA_Biovil/llava/eval/model_qa.py +0 -85
  13. LLAVA_Biovil/llava/eval/model_vqa.py +0 -112
  14. LLAVA_Biovil/llava/eval/model_vqa_loader.py +0 -141
  15. LLAVA_Biovil/llava/eval/model_vqa_mmbench.py +0 -169
  16. LLAVA_Biovil/llava/eval/model_vqa_qbench.py +0 -120
  17. LLAVA_Biovil/llava/eval/model_vqa_science.py +0 -147
  18. LLAVA_Biovil/llava/eval/qa_baseline_gpt35.py +0 -74
  19. LLAVA_Biovil/llava/eval/run_llava.py +0 -155
  20. LLAVA_Biovil/llava/eval/summarize_gpt_review.py +0 -60
  21. LLAVA_Biovil/llava/eval/webpage/figures/alpaca.png +0 -0
  22. LLAVA_Biovil/llava/eval/webpage/figures/bard.jpg +0 -0
  23. LLAVA_Biovil/llava/eval/webpage/figures/chatgpt.svg +0 -1
  24. LLAVA_Biovil/llava/eval/webpage/figures/llama.jpg +0 -0
  25. LLAVA_Biovil/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg +0 -1
  26. LLAVA_Biovil/llava/eval/webpage/figures/vicuna.jpeg +0 -0
  27. LLAVA_Biovil/llava/eval/webpage/index.html +0 -162
  28. LLAVA_Biovil/llava/eval/webpage/script.js +0 -245
  29. LLAVA_Biovil/llava/eval/webpage/styles.css +0 -105
  30. LLAVA_Biovil/llava/mm_utils.py +1 -1
  31. LLAVA_Biovil/llava/model/apply_delta.py +1 -1
  32. LLAVA_Biovil/llava/model/builder.py +8 -8
  33. LLAVA_Biovil/llava/model/consolidate.py +1 -1
  34. LLAVA_Biovil/llava/model/language_model/llava_llama.py +1 -1
  35. LLAVA_Biovil/llava/model/language_model/llava_mpt.py +2 -2
  36. LLAVA_Biovil/llava/model/llava_arch.py +6 -6
  37. LLAVA_Biovil/llava/serve/__init__.py +0 -0
  38. LLAVA_Biovil/llava/serve/cli.py +0 -122
  39. LLAVA_Biovil/llava/serve/controller.py +0 -296
  40. LLAVA_Biovil/llava/serve/examples/extreme_ironing.jpg +0 -0
  41. LLAVA_Biovil/llava/serve/examples/waterview.jpg +0 -0
  42. LLAVA_Biovil/llava/serve/gradio_web_server.py +0 -470
  43. LLAVA_Biovil/llava/serve/model_worker.py +0 -310
  44. LLAVA_Biovil/llava/serve/register_worker.py +0 -26
  45. LLAVA_Biovil/llava/serve/test_message.py +0 -62
  46. LLAVA_Biovil/llava/train/__init__.py +0 -0
  47. LLAVA_Biovil/llava/train/llama_flash_attn_monkey_patch.py +0 -115
  48. LLAVA_Biovil/llava/train/llama_patch.py +0 -139
  49. LLAVA_Biovil/llava/train/llama_xformers_attn_monkey_patch.py +0 -129
  50. LLAVA_Biovil/llava/train/llava_trainer.py +0 -801
LLAVA_Biovil/llava/eval/__init__.py DELETED
File without changes
LLAVA_Biovil/llava/eval/eval_gpt_review.py DELETED
@@ -1,113 +0,0 @@
1
- import argparse
2
- import json
3
- import os
4
-
5
- import openai
6
- import tqdm
7
- import ray
8
- import time
9
-
10
- NUM_SECONDS_TO_SLEEP = 3
11
-
12
- @ray.remote(num_cpus=4)
13
- def get_eval(content: str, max_tokens: int):
14
- while True:
15
- try:
16
- response = openai.ChatCompletion.create(
17
- model='gpt-4',
18
- messages=[{
19
- 'role': 'system',
20
- 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
21
- }, {
22
- 'role': 'user',
23
- 'content': content,
24
- }],
25
- temperature=0.2, # TODO: figure out which temperature is best for evaluation
26
- max_tokens=max_tokens,
27
- )
28
- break
29
- except openai.error.RateLimitError:
30
- pass
31
- except Exception as e:
32
- print(e)
33
- time.sleep(NUM_SECONDS_TO_SLEEP)
34
-
35
- print('success!')
36
- return response['choices'][0]['message']['content']
37
-
38
-
39
- def parse_score(review):
40
- try:
41
- score_pair = review.split('\n')[0]
42
- score_pair = score_pair.replace(',', ' ')
43
- sp = score_pair.split(' ')
44
- if len(sp) == 2:
45
- return [float(sp[0]), float(sp[1])]
46
- else:
47
- print('error', review)
48
- return [-1, -1]
49
- except Exception as e:
50
- print(e)
51
- print('error', review)
52
- return [-1, -1]
53
-
54
-
55
- if __name__ == '__main__':
56
- parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
57
- parser.add_argument('-q', '--question')
58
- # parser.add_argument('-a', '--answer')
59
- parser.add_argument('-a', '--answer-list', nargs='+', default=[])
60
- parser.add_argument('-r', '--rule')
61
- parser.add_argument('-o', '--output')
62
- parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
63
- args = parser.parse_args()
64
-
65
- ray.init()
66
-
67
- f_q = open(os.path.expanduser(args.question))
68
- f_ans1 = open(os.path.expanduser(args.answer_list[0]))
69
- f_ans2 = open(os.path.expanduser(args.answer_list[1]))
70
- rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
71
-
72
- review_file = open(f'{args.output}', 'w')
73
-
74
- js_list = []
75
- handles = []
76
- idx = 0
77
- for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
78
- # if idx == 1:
79
- # break
80
-
81
- ques = json.loads(ques_js)
82
- ans1 = json.loads(ans1_js)
83
- ans2 = json.loads(ans2_js)
84
-
85
- category = json.loads(ques_js)['category']
86
- if category in rule_dict:
87
- rule = rule_dict[category]
88
- else:
89
- rule = rule_dict['default']
90
- prompt = rule['prompt']
91
- role = rule['role']
92
- content = (f'[Question]\n{ques["text"]}\n\n'
93
- f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
94
- f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
95
- f'[System]\n{prompt}\n\n')
96
- js_list.append({
97
- 'id': idx+1,
98
- 'question_id': ques['question_id'],
99
- 'answer1_id': ans1['answer_id'],
100
- 'answer2_id': ans2['answer_id'],
101
- 'category': category})
102
- idx += 1
103
- handles.append(get_eval.remote(content, args.max_tokens))
104
- # To avoid the rate limit set by OpenAI
105
- time.sleep(NUM_SECONDS_TO_SLEEP)
106
-
107
- reviews = ray.get(handles)
108
- for idx, review in enumerate(reviews):
109
- scores = parse_score(review)
110
- js_list[idx]['content'] = review
111
- js_list[idx]['tuple'] = scores
112
- review_file.write(json.dumps(js_list[idx]) + '\n')
113
- review_file.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LLAVA_Biovil/llava/eval/eval_gpt_review_bench.py DELETED
@@ -1,121 +0,0 @@
1
- import argparse
2
- import json
3
- import os
4
-
5
- import openai
6
- import time
7
-
8
- NUM_SECONDS_TO_SLEEP = 0.5
9
-
10
-
11
- def get_eval(content: str, max_tokens: int):
12
- while True:
13
- try:
14
- response = openai.ChatCompletion.create(
15
- model='gpt-4-0314',
16
- messages=[{
17
- 'role': 'system',
18
- 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19
- }, {
20
- 'role': 'user',
21
- 'content': content,
22
- }],
23
- temperature=0.2, # TODO: figure out which temperature is best for evaluation
24
- max_tokens=max_tokens,
25
- )
26
- break
27
- except openai.error.RateLimitError:
28
- pass
29
- except Exception as e:
30
- print(e)
31
- time.sleep(NUM_SECONDS_TO_SLEEP)
32
-
33
- return response['choices'][0]['message']['content']
34
-
35
-
36
- def parse_score(review):
37
- try:
38
- score_pair = review.split('\n')[0]
39
- score_pair = score_pair.replace(',', ' ')
40
- sp = score_pair.split(' ')
41
- if len(sp) == 2:
42
- return [float(sp[0]), float(sp[1])]
43
- else:
44
- print('error', review)
45
- return [-1, -1]
46
- except Exception as e:
47
- print(e)
48
- print('error', review)
49
- return [-1, -1]
50
-
51
-
52
- if __name__ == '__main__':
53
- parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54
- parser.add_argument('-q', '--question')
55
- parser.add_argument('-c', '--context')
56
- parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57
- parser.add_argument('-r', '--rule')
58
- parser.add_argument('-o', '--output')
59
- parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60
- args = parser.parse_args()
61
-
62
- f_q = open(os.path.expanduser(args.question))
63
- f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64
- f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65
- rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66
-
67
- if os.path.isfile(os.path.expanduser(args.output)):
68
- cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69
- else:
70
- cur_reviews = []
71
-
72
- review_file = open(f'{args.output}', 'a')
73
-
74
- context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75
- image_to_context = {context['image']: context for context in context_list}
76
-
77
- handles = []
78
- idx = 0
79
- for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80
- ques = json.loads(ques_js)
81
- ans1 = json.loads(ans1_js)
82
- ans2 = json.loads(ans2_js)
83
-
84
- inst = image_to_context[ques['image']]
85
-
86
- if isinstance(inst['caption'], list):
87
- cap_str = '\n'.join(inst['caption'])
88
- else:
89
- cap_str = inst['caption']
90
-
91
- category = 'llava_bench_' + json.loads(ques_js)['category']
92
- if category in rule_dict:
93
- rule = rule_dict[category]
94
- else:
95
- assert False, f"Visual QA category not found in rule file: {category}."
96
- prompt = rule['prompt']
97
- role = rule['role']
98
- content = (f'[Context]\n{cap_str}\n\n'
99
- f'[Question]\n{ques["text"]}\n\n'
100
- f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
101
- f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
102
- f'[System]\n{prompt}\n\n')
103
- cur_js = {
104
- 'id': idx+1,
105
- 'question_id': ques['question_id'],
106
- 'answer1_id': ans1.get('answer_id', ans1['question_id']),
107
- 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
108
- 'category': category
109
- }
110
- if idx >= len(cur_reviews):
111
- review = get_eval(content, args.max_tokens)
112
- scores = parse_score(review)
113
- cur_js['content'] = review
114
- cur_js['tuple'] = scores
115
- review_file.write(json.dumps(cur_js) + '\n')
116
- review_file.flush()
117
- else:
118
- print(f'Skipping {idx} as we already have it.')
119
- idx += 1
120
- print(idx)
121
- review_file.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LLAVA_Biovil/llava/eval/eval_gpt_review_visual.py DELETED
@@ -1,118 +0,0 @@
1
- import argparse
2
- import json
3
- import os
4
-
5
- import openai
6
- import time
7
-
8
- NUM_SECONDS_TO_SLEEP = 0.5
9
-
10
-
11
- def get_eval(content: str, max_tokens: int):
12
- while True:
13
- try:
14
- response = openai.ChatCompletion.create(
15
- model='gpt-4-0314',
16
- messages=[{
17
- 'role': 'system',
18
- 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19
- }, {
20
- 'role': 'user',
21
- 'content': content,
22
- }],
23
- temperature=0.2, # TODO: figure out which temperature is best for evaluation
24
- max_tokens=max_tokens,
25
- )
26
- break
27
- except openai.error.RateLimitError:
28
- pass
29
- except Exception as e:
30
- print(e)
31
- time.sleep(NUM_SECONDS_TO_SLEEP)
32
-
33
- return response['choices'][0]['message']['content']
34
-
35
-
36
- def parse_score(review):
37
- try:
38
- score_pair = review.split('\n')[0]
39
- score_pair = score_pair.replace(',', ' ')
40
- sp = score_pair.split(' ')
41
- if len(sp) == 2:
42
- return [float(sp[0]), float(sp[1])]
43
- else:
44
- print('error', review)
45
- return [-1, -1]
46
- except Exception as e:
47
- print(e)
48
- print('error', review)
49
- return [-1, -1]
50
-
51
-
52
- if __name__ == '__main__':
53
- parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54
- parser.add_argument('-q', '--question')
55
- parser.add_argument('-c', '--context')
56
- parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57
- parser.add_argument('-r', '--rule')
58
- parser.add_argument('-o', '--output')
59
- parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60
- args = parser.parse_args()
61
-
62
- f_q = open(os.path.expanduser(args.question))
63
- f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64
- f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65
- rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66
-
67
- if os.path.isfile(os.path.expanduser(args.output)):
68
- cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69
- else:
70
- cur_reviews = []
71
-
72
- review_file = open(f'{args.output}', 'a')
73
-
74
- context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75
- image_to_context = {context['image']: context for context in context_list}
76
-
77
- handles = []
78
- idx = 0
79
- for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80
- ques = json.loads(ques_js)
81
- ans1 = json.loads(ans1_js)
82
- ans2 = json.loads(ans2_js)
83
-
84
- inst = image_to_context[ques['image']]
85
- cap_str = '\n'.join(inst['captions'])
86
- box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
87
-
88
- category = json.loads(ques_js)['category']
89
- if category in rule_dict:
90
- rule = rule_dict[category]
91
- else:
92
- assert False, f"Visual QA category not found in rule file: {category}."
93
- prompt = rule['prompt']
94
- role = rule['role']
95
- content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
96
- f'[Question]\n{ques["text"]}\n\n'
97
- f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
98
- f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
99
- f'[System]\n{prompt}\n\n')
100
- cur_js = {
101
- 'id': idx+1,
102
- 'question_id': ques['question_id'],
103
- 'answer1_id': ans1.get('answer_id', ans1['question_id']),
104
- 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
105
- 'category': category
106
- }
107
- if idx >= len(cur_reviews):
108
- review = get_eval(content, args.max_tokens)
109
- scores = parse_score(review)
110
- cur_js['content'] = review
111
- cur_js['tuple'] = scores
112
- review_file.write(json.dumps(cur_js) + '\n')
113
- review_file.flush()
114
- else:
115
- print(f'Skipping {idx} as we already have it.')
116
- idx += 1
117
- print(idx)
118
- review_file.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LLAVA_Biovil/llava/eval/eval_pope.py DELETED
@@ -1,81 +0,0 @@
1
- import os
2
- import json
3
- import argparse
4
-
5
- def eval_pope(answers, label_file):
6
- label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]
7
-
8
- for answer in answers:
9
- text = answer['text']
10
-
11
- # Only keep the first sentence
12
- if text.find('.') != -1:
13
- text = text.split('.')[0]
14
-
15
- text = text.replace(',', '')
16
- words = text.split(' ')
17
- if 'No' in words or 'not' in words or 'no' in words:
18
- answer['text'] = 'no'
19
- else:
20
- answer['text'] = 'yes'
21
-
22
- for i in range(len(label_list)):
23
- if label_list[i] == 'no':
24
- label_list[i] = 0
25
- else:
26
- label_list[i] = 1
27
-
28
- pred_list = []
29
- for answer in answers:
30
- if answer['text'] == 'no':
31
- pred_list.append(0)
32
- else:
33
- pred_list.append(1)
34
-
35
- pos = 1
36
- neg = 0
37
- yes_ratio = pred_list.count(1) / len(pred_list)
38
-
39
- TP, TN, FP, FN = 0, 0, 0, 0
40
- for pred, label in zip(pred_list, label_list):
41
- if pred == pos and label == pos:
42
- TP += 1
43
- elif pred == pos and label == neg:
44
- FP += 1
45
- elif pred == neg and label == neg:
46
- TN += 1
47
- elif pred == neg and label == pos:
48
- FN += 1
49
-
50
- print('TP\tFP\tTN\tFN\t')
51
- print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
52
-
53
- precision = float(TP) / float(TP + FP)
54
- recall = float(TP) / float(TP + FN)
55
- f1 = 2*precision*recall / (precision + recall)
56
- acc = (TP + TN) / (TP + TN + FP + FN)
57
- print('Accuracy: {}'.format(acc))
58
- print('Precision: {}'.format(precision))
59
- print('Recall: {}'.format(recall))
60
- print('F1 score: {}'.format(f1))
61
- print('Yes ratio: {}'.format(yes_ratio))
62
- print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
63
-
64
- if __name__ == "__main__":
65
- parser = argparse.ArgumentParser()
66
- parser.add_argument("--annotation-dir", type=str)
67
- parser.add_argument("--question-file", type=str)
68
- parser.add_argument("--result-file", type=str)
69
- args = parser.parse_args()
70
-
71
- questions = [json.loads(line) for line in open(args.question_file)]
72
- questions = {question['question_id']: question for question in questions}
73
- answers = [json.loads(q) for q in open(args.result_file)]
74
- for file in os.listdir(args.annotation_dir):
75
- assert file.startswith('coco_pope_')
76
- assert file.endswith('.json')
77
- category = file[10:-5]
78
- cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
79
- print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
80
- eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
81
- print("====================================")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LLAVA_Biovil/llava/eval/eval_science_qa.py DELETED
@@ -1,114 +0,0 @@
1
- import argparse
2
- import json
3
- import os
4
- import re
5
- import random
6
-
7
-
8
- def get_args():
9
- parser = argparse.ArgumentParser()
10
- parser.add_argument('--base-dir', type=str)
11
- parser.add_argument('--result-file', type=str)
12
- parser.add_argument('--output-file', type=str)
13
- parser.add_argument('--output-result', type=str)
14
- parser.add_argument('--split', type=str, default='test')
15
- parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16
- return parser.parse_args()
17
-
18
-
19
- def convert_caps(results):
20
- fakecaps = []
21
- for result in results:
22
- image_id = result['question_id']
23
- caption = result['text']
24
- fakecaps.append({"image_id": int(image_id), "caption": caption})
25
- return fakecaps
26
-
27
-
28
- def get_pred_idx(prediction, choices, options):
29
- """
30
- Get the index (e.g. 2) from the prediction (e.g. 'C')
31
- """
32
- if prediction in options[:len(choices)]:
33
- return options.index(prediction)
34
- else:
35
- return -1
36
- return random.choice(range(len(choices)))
37
-
38
-
39
- if __name__ == "__main__":
40
- args = get_args()
41
-
42
- base_dir = args.base_dir
43
- split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
44
- problems = json.load(open(os.path.join(base_dir, "problems.json")))
45
- predictions = [json.loads(line) for line in open(args.result_file)]
46
- predictions = {pred['question_id']: pred for pred in predictions}
47
- split_problems = {idx: problems[idx] for idx in split_indices}
48
-
49
- results = {'correct': [], 'incorrect': []}
50
- sqa_results = {}
51
- sqa_results['acc'] = None
52
- sqa_results['correct'] = None
53
- sqa_results['count'] = None
54
- sqa_results['results'] = {}
55
- sqa_results['outputs'] = {}
56
-
57
- for prob_id, prob in split_problems.items():
58
- if prob_id not in predictions:
59
- pred = {'text': 'FAILED', 'prompt': 'Unknown'}
60
- pred_text = 'FAILED'
61
- else:
62
- pred = predictions[prob_id]
63
- pred_text = pred['text']
64
-
65
- if pred_text in args.options:
66
- answer = pred_text
67
- elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
68
- answer = pred_text[0]
69
- else:
70
- pattern = re.compile(r'The answer is ([A-Z]).')
71
- res = pattern.findall(pred_text)
72
- if len(res) == 1:
73
- answer = res[0] # 'A', 'B', ...
74
- else:
75
- answer = "FAILED"
76
-
77
- pred_idx = get_pred_idx(answer, prob['choices'], args.options)
78
-
79
- analysis = {
80
- 'question_id': prob_id,
81
- 'parsed_ans': answer,
82
- 'ground_truth': args.options[prob['answer']],
83
- 'question': pred['prompt'],
84
- 'pred': pred_text,
85
- 'is_multimodal': '<image>' in pred['prompt'],
86
- }
87
-
88
- sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
89
- sqa_results['outputs'][prob_id] = pred_text
90
-
91
- if pred_idx == prob['answer']:
92
- results['correct'].append(analysis)
93
- else:
94
- results['incorrect'].append(analysis)
95
-
96
- correct = len(results['correct'])
97
- total = len(results['correct']) + len(results['incorrect'])
98
-
99
- ###### IMG ######
100
- multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
101
- multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
102
- multimodal_total = multimodal_correct + multimodal_incorrect
103
- ###### IMG ######
104
-
105
- print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
106
-
107
- sqa_results['acc'] = correct / total * 100
108
- sqa_results['correct'] = correct
109
- sqa_results['count'] = total
110
-
111
- with open(args.output_file, 'w') as f:
112
- json.dump(results, f, indent=2)
113
- with open(args.output_result, 'w') as f:
114
- json.dump(sqa_results, f, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LLAVA_Biovil/llava/eval/eval_science_qa_gpt4.py DELETED
@@ -1,104 +0,0 @@
1
- import argparse
2
- import json
3
- import os
4
- import re
5
- import random
6
- from collections import defaultdict
7
-
8
-
9
- def get_args():
10
- parser = argparse.ArgumentParser()
11
- parser.add_argument('--base-dir', type=str)
12
- parser.add_argument('--gpt4-result', type=str)
13
- parser.add_argument('--our-result', type=str)
14
- parser.add_argument('--split', type=str, default='test')
15
- parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16
- return parser.parse_args()
17
-
18
-
19
- def convert_caps(results):
20
- fakecaps = []
21
- for result in results:
22
- image_id = result['question_id']
23
- caption = result['text']
24
- fakecaps.append({"image_id": int(image_id), "caption": caption})
25
- return fakecaps
26
-
27
-
28
- def get_pred_idx(prediction, choices, options):
29
- """
30
- Get the index (e.g. 2) from the prediction (e.g. 'C')
31
- """
32
- if prediction in options[:len(choices)]:
33
- return options.index(prediction)
34
- else:
35
- return random.choice(range(len(choices)))
36
-
37
-
38
- if __name__ == "__main__":
39
- args = get_args()
40
-
41
- base_dir = args.base_dir
42
- split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
43
- problems = json.load(open(os.path.join(base_dir, "problems.json")))
44
- our_predictions = [json.loads(line) for line in open(args.our_result)]
45
- our_predictions = {pred['question_id']: pred for pred in our_predictions}
46
- split_problems = {idx: problems[idx] for idx in split_indices}
47
-
48
- gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
49
-
50
- results = defaultdict(lambda: 0)
51
-
52
- for prob_id, prob in split_problems.items():
53
- if prob_id not in our_predictions:
54
- continue
55
- if prob_id not in gpt4_predictions:
56
- continue
57
- our_pred = our_predictions[prob_id]['text']
58
- gpt4_pred = gpt4_predictions[prob_id]
59
-
60
- pattern = re.compile(r'The answer is ([A-Z]).')
61
- our_res = pattern.findall(our_pred)
62
- if len(our_res) == 1:
63
- our_answer = our_res[0] # 'A', 'B', ...
64
- else:
65
- our_answer = "FAILED"
66
- gpt4_res = pattern.findall(gpt4_pred)
67
- if len(gpt4_res) == 1:
68
- gpt4_answer = gpt4_res[0] # 'A', 'B', ...
69
- else:
70
- gpt4_answer = "FAILED"
71
-
72
- our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
73
- gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
74
-
75
- if gpt4_answer == 'FAILED':
76
- results['gpt4_failed'] += 1
77
- # continue
78
- gpt4_pred_idx = our_pred_idx
79
- # if our_pred_idx != prob['answer']:
80
- # print(our_predictions[prob_id]['prompt'])
81
- # print('-----------------')
82
- # print(f'LECTURE: {prob["lecture"]}')
83
- # print(f'SOLUTION: {prob["solution"]}')
84
- # print('=====================')
85
- else:
86
- # continue
87
- pass
88
- # gpt4_pred_idx = our_pred_idx
89
-
90
- if gpt4_pred_idx == prob['answer']:
91
- results['correct'] += 1
92
- else:
93
- results['incorrect'] += 1
94
-
95
-
96
- if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
97
- results['correct_upperbound'] += 1
98
-
99
- correct = results['correct']
100
- total = results['correct'] + results['incorrect']
101
- print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
102
- print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
103
- print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
104
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LLAVA_Biovil/llava/eval/eval_science_qa_gpt4_requery.py DELETED
@@ -1,149 +0,0 @@
1
- import argparse
2
- import json
3
- import os
4
- import re
5
- import random
6
- from collections import defaultdict
7
-
8
-
9
- def get_args():
10
- parser = argparse.ArgumentParser()
11
- parser.add_argument('--base-dir', type=str)
12
- parser.add_argument('--gpt4-result', type=str)
13
- parser.add_argument('--requery-result', type=str)
14
- parser.add_argument('--our-result', type=str)
15
- parser.add_argument('--output-result', type=str)
16
- parser.add_argument('--split', type=str, default='test')
17
- parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
18
- return parser.parse_args()
19
-
20
-
21
- def convert_caps(results):
22
- fakecaps = []
23
- for result in results:
24
- image_id = result['question_id']
25
- caption = result['text']
26
- fakecaps.append({"image_id": int(image_id), "caption": caption})
27
- return fakecaps
28
-
29
-
30
- def get_pred_idx(prediction, choices, options):
31
- """
32
- Get the index (e.g. 2) from the prediction (e.g. 'C')
33
- """
34
- if prediction in options[:len(choices)]:
35
- return options.index(prediction)
36
- else:
37
- return random.choice(range(len(choices)))
38
-
39
-
40
- if __name__ == "__main__":
41
- args = get_args()
42
-
43
- base_dir = args.base_dir
44
- split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
45
- problems = json.load(open(os.path.join(base_dir, "problems.json")))
46
- our_predictions = [json.loads(line) for line in open(args.our_result)]
47
- our_predictions = {pred['question_id']: pred for pred in our_predictions}
48
- split_problems = {idx: problems[idx] for idx in split_indices}
49
-
50
- requery_predictions = [json.loads(line) for line in open(args.requery_result)]
51
- requery_predictions = {pred['question_id']: pred for pred in requery_predictions}
52
-
53
- gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
54
-
55
- results = defaultdict(lambda: 0)
56
-
57
- sqa_results = {}
58
- sqa_results['acc'] = None
59
- sqa_results['correct'] = None
60
- sqa_results['count'] = None
61
- sqa_results['results'] = {}
62
- sqa_results['outputs'] = {}
63
-
64
- for prob_id, prob in split_problems.items():
65
- if prob_id not in our_predictions:
66
- assert False
67
- if prob_id not in gpt4_predictions:
68
- assert False
69
- our_pred = our_predictions[prob_id]['text']
70
- gpt4_pred = gpt4_predictions[prob_id]
71
- if prob_id not in requery_predictions:
72
- results['missing_requery'] += 1
73
- requery_pred = "MISSING"
74
- else:
75
- requery_pred = requery_predictions[prob_id]['text']
76
-
77
- pattern = re.compile(r'The answer is ([A-Z]).')
78
- our_res = pattern.findall(our_pred)
79
- if len(our_res) == 1:
80
- our_answer = our_res[0] # 'A', 'B', ...
81
- else:
82
- our_answer = "FAILED"
83
-
84
- requery_res = pattern.findall(requery_pred)
85
- if len(requery_res) == 1:
86
- requery_answer = requery_res[0] # 'A', 'B', ...
87
- else:
88
- requery_answer = "FAILED"
89
-
90
- gpt4_res = pattern.findall(gpt4_pred)
91
- if len(gpt4_res) == 1:
92
- gpt4_answer = gpt4_res[0] # 'A', 'B', ...
93
- else:
94
- gpt4_answer = "FAILED"
95
-
96
- our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
97
- gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
98
- requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options)
99
-
100
- results['total'] += 1
101
-
102
- if gpt4_answer == 'FAILED':
103
- results['gpt4_failed'] += 1
104
- if gpt4_pred_idx == prob['answer']:
105
- results['gpt4_correct'] += 1
106
- if our_pred_idx == prob['answer']:
107
- results['gpt4_ourvisual_correct'] += 1
108
- elif gpt4_pred_idx == prob['answer']:
109
- results['gpt4_correct'] += 1
110
- results['gpt4_ourvisual_correct'] += 1
111
-
112
- if our_pred_idx == prob['answer']:
113
- results['our_correct'] += 1
114
-
115
- if requery_answer == 'FAILED':
116
- sqa_results['results'][prob_id] = our_pred_idx
117
- if our_pred_idx == prob['answer']:
118
- results['requery_correct'] += 1
119
- else:
120
- sqa_results['results'][prob_id] = requery_pred_idx
121
- if requery_pred_idx == prob['answer']:
122
- results['requery_correct'] += 1
123
- else:
124
- print(f"""
125
- Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']}
126
- Our ({our_answer}): {our_pred}
127
- GPT-4 ({gpt4_answer}): {gpt4_pred}
128
- Requery ({requery_answer}): {requery_pred}
129
- print("=====================================")
130
- """)
131
-
132
- if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
133
- results['correct_upperbound'] += 1
134
-
135
- total = results['total']
136
- print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%')
137
- print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%')
138
- print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
139
- print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%')
140
- print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%')
141
- print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
142
-
143
- sqa_results['acc'] = results["requery_correct"] / total * 100
144
- sqa_results['correct'] = results["requery_correct"]
145
- sqa_results['count'] = total
146
-
147
- with open(args.output_result, 'w') as f:
148
- json.dump(sqa_results, f, indent=2)
149
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LLAVA_Biovil/llava/eval/eval_textvqa.py DELETED
@@ -1,65 +0,0 @@
1
- import os
2
- import argparse
3
- import json
4
- import re
5
-
6
- from LLAV.llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator
7
-
8
-
9
- def get_args():
10
- parser = argparse.ArgumentParser()
11
- parser.add_argument('--annotation-file', type=str)
12
- parser.add_argument('--result-file', type=str)
13
- parser.add_argument('--result-dir', type=str)
14
- return parser.parse_args()
15
-
16
-
17
- def prompt_processor(prompt):
18
- if prompt.startswith('OCR tokens: '):
19
- pattern = r"Question: (.*?) Short answer:"
20
- match = re.search(pattern, prompt, re.DOTALL)
21
- question = match.group(1)
22
- elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
23
- if prompt.startswith('Reference OCR token:'):
24
- question = prompt.split('\n')[1]
25
- else:
26
- question = prompt.split('\n')[0]
27
- elif len(prompt.split('\n')) == 2:
28
- question = prompt.split('\n')[0]
29
- else:
30
- assert False
31
-
32
- return question.lower()
33
-
34
-
35
- def eval_single(annotation_file, result_file):
36
- experiment_name = os.path.splitext(os.path.basename(result_file))[0]
37
- print(experiment_name)
38
- annotations = json.load(open(annotation_file))['data']
39
- annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations}
40
- results = [json.loads(line) for line in open(result_file)]
41
-
42
- pred_list = []
43
- for result in results:
44
- annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
45
- pred_list.append({
46
- "pred_answer": result['text'],
47
- "gt_answers": annotation['answers'],
48
- })
49
-
50
- evaluator = TextVQAAccuracyEvaluator()
51
- print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
52
-
53
-
54
- if __name__ == "__main__":
55
- args = get_args()
56
-
57
- if args.result_file is not None:
58
- eval_single(args.annotation_file, args.result_file)
59
-
60
- if args.result_dir is not None:
61
- for result_file in sorted(os.listdir(args.result_dir)):
62
- if not result_file.endswith('.jsonl'):
63
- print(f'Skipping {result_file}')
64
- continue
65
- eval_single(args.annotation_file, os.path.join(args.result_dir, result_file))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LLAVA_Biovil/llava/eval/generate_webpage_data_from_table.py DELETED
@@ -1,111 +0,0 @@
1
- """Generate json file for webpage."""
2
- import json
3
- import os
4
- import re
5
-
6
- # models = ['llama', 'alpaca', 'gpt35', 'bard']
7
- models = ['vicuna']
8
-
9
-
10
- def read_jsonl(path: str, key: str=None):
11
- data = []
12
- with open(os.path.expanduser(path)) as f:
13
- for line in f:
14
- if not line:
15
- continue
16
- data.append(json.loads(line))
17
- if key is not None:
18
- data.sort(key=lambda x: x[key])
19
- data = {item[key]: item for item in data}
20
- return data
21
-
22
-
23
- def trim_hanging_lines(s: str, n: int) -> str:
24
- s = s.strip()
25
- for _ in range(n):
26
- s = s.split('\n', 1)[1].strip()
27
- return s
28
-
29
-
30
- if __name__ == '__main__':
31
- questions = read_jsonl('table/question.jsonl', key='question_id')
32
-
33
- # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
34
- # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
35
- # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
36
- # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
37
- vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
38
- ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
39
-
40
- review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
41
- # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
42
- # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
43
- # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
44
- # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')
45
-
46
- records = []
47
- for qid in questions.keys():
48
- r = {
49
- 'id': qid,
50
- 'category': questions[qid]['category'],
51
- 'question': questions[qid]['text'],
52
- 'answers': {
53
- # 'alpaca': alpaca_answers[qid]['text'],
54
- # 'llama': llama_answers[qid]['text'],
55
- # 'bard': bard_answers[qid]['text'],
56
- # 'gpt35': gpt35_answers[qid]['text'],
57
- 'vicuna': vicuna_answers[qid]['text'],
58
- 'ours': ours_answers[qid]['text'],
59
- },
60
- 'evaluations': {
61
- # 'alpaca': review_alpaca[qid]['text'],
62
- # 'llama': review_llama[qid]['text'],
63
- # 'bard': review_bard[qid]['text'],
64
- 'vicuna': review_vicuna[qid]['content'],
65
- # 'gpt35': review_gpt35[qid]['text'],
66
- },
67
- 'scores': {
68
- 'vicuna': review_vicuna[qid]['tuple'],
69
- # 'alpaca': review_alpaca[qid]['score'],
70
- # 'llama': review_llama[qid]['score'],
71
- # 'bard': review_bard[qid]['score'],
72
- # 'gpt35': review_gpt35[qid]['score'],
73
- },
74
- }
75
-
76
- # cleanup data
77
- cleaned_evals = {}
78
- for k, v in r['evaluations'].items():
79
- v = v.strip()
80
- lines = v.split('\n')
81
- # trim the first line if it's a pair of numbers
82
- if re.match(r'\d+[, ]+\d+', lines[0]):
83
- lines = lines[1:]
84
- v = '\n'.join(lines)
85
- cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
86
-
87
- r['evaluations'] = cleaned_evals
88
- records.append(r)
89
-
90
- # Reorder the records, this is optional
91
- for r in records:
92
- if r['id'] <= 20:
93
- r['id'] += 60
94
- else:
95
- r['id'] -= 20
96
- for r in records:
97
- if r['id'] <= 50:
98
- r['id'] += 10
99
- elif 50 < r['id'] <= 60:
100
- r['id'] -= 50
101
- for r in records:
102
- if r['id'] == 7:
103
- r['id'] = 1
104
- elif r['id'] < 7:
105
- r['id'] += 1
106
-
107
- records.sort(key=lambda x: x['id'])
108
-
109
- # Write to file
110
- with open('webpage/data.json', 'w') as f:
111
- json.dump({'questions': records, 'models': models}, f, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LLAVA_Biovil/llava/eval/m4c_evaluator.py DELETED
@@ -1,334 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- import re
3
-
4
- from tqdm import tqdm
5
-
6
-
7
- class EvalAIAnswerProcessor:
8
- """
9
- Processes an answer similar to Eval AI
10
- copied from
11
- https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
12
- """
13
-
14
- CONTRACTIONS = {
15
- "aint": "ain't",
16
- "arent": "aren't",
17
- "cant": "can't",
18
- "couldve": "could've",
19
- "couldnt": "couldn't",
20
- "couldn'tve": "couldn't've",
21
- "couldnt've": "couldn't've",
22
- "didnt": "didn't",
23
- "doesnt": "doesn't",
24
- "dont": "don't",
25
- "hadnt": "hadn't",
26
- "hadnt've": "hadn't've",
27
- "hadn'tve": "hadn't've",
28
- "hasnt": "hasn't",
29
- "havent": "haven't",
30
- "hed": "he'd",
31
- "hed've": "he'd've",
32
- "he'dve": "he'd've",
33
- "hes": "he's",
34
- "howd": "how'd",
35
- "howll": "how'll",
36
- "hows": "how's",
37
- "Id've": "I'd've",
38
- "I'dve": "I'd've",
39
- "Im": "I'm",
40
- "Ive": "I've",
41
- "isnt": "isn't",
42
- "itd": "it'd",
43
- "itd've": "it'd've",
44
- "it'dve": "it'd've",
45
- "itll": "it'll",
46
- "let's": "let's",
47
- "maam": "ma'am",
48
- "mightnt": "mightn't",
49
- "mightnt've": "mightn't've",
50
- "mightn'tve": "mightn't've",
51
- "mightve": "might've",
52
- "mustnt": "mustn't",
53
- "mustve": "must've",
54
- "neednt": "needn't",
55
- "notve": "not've",
56
- "oclock": "o'clock",
57
- "oughtnt": "oughtn't",
58
- "ow's'at": "'ow's'at",
59
- "'ows'at": "'ow's'at",
60
- "'ow'sat": "'ow's'at",
61
- "shant": "shan't",
62
- "shed've": "she'd've",
63
- "she'dve": "she'd've",
64
- "she's": "she's",
65
- "shouldve": "should've",
66
- "shouldnt": "shouldn't",
67
- "shouldnt've": "shouldn't've",
68
- "shouldn'tve": "shouldn't've",
69
- "somebody'd": "somebodyd",
70
- "somebodyd've": "somebody'd've",
71
- "somebody'dve": "somebody'd've",
72
- "somebodyll": "somebody'll",
73
- "somebodys": "somebody's",
74
- "someoned": "someone'd",
75
- "someoned've": "someone'd've",
76
- "someone'dve": "someone'd've",
77
- "someonell": "someone'll",
78
- "someones": "someone's",
79
- "somethingd": "something'd",
80
- "somethingd've": "something'd've",
81
- "something'dve": "something'd've",
82
- "somethingll": "something'll",
83
- "thats": "that's",
84
- "thered": "there'd",
85
- "thered've": "there'd've",
86
- "there'dve": "there'd've",
87
- "therere": "there're",
88
- "theres": "there's",
89
- "theyd": "they'd",
90
- "theyd've": "they'd've",
91
- "they'dve": "they'd've",
92
- "theyll": "they'll",
93
- "theyre": "they're",
94
- "theyve": "they've",
95
- "twas": "'twas",
96
- "wasnt": "wasn't",
97
- "wed've": "we'd've",
98
- "we'dve": "we'd've",
99
- "weve": "we've",
100
- "werent": "weren't",
101
- "whatll": "what'll",
102
- "whatre": "what're",
103
- "whats": "what's",
104
- "whatve": "what've",
105
- "whens": "when's",
106
- "whered": "where'd",
107
- "wheres": "where's",
108
- "whereve": "where've",
109
- "whod": "who'd",
110
- "whod've": "who'd've",
111
- "who'dve": "who'd've",
112
- "wholl": "who'll",
113
- "whos": "who's",
114
- "whove": "who've",
115
- "whyll": "why'll",
116
- "whyre": "why're",
117
- "whys": "why's",
118
- "wont": "won't",
119
- "wouldve": "would've",
120
- "wouldnt": "wouldn't",
121
- "wouldnt've": "wouldn't've",
122
- "wouldn'tve": "wouldn't've",
123
- "yall": "y'all",
124
- "yall'll": "y'all'll",
125
- "y'allll": "y'all'll",
126
- "yall'd've": "y'all'd've",
127
- "y'alld've": "y'all'd've",
128
- "y'all'dve": "y'all'd've",
129
- "youd": "you'd",
130
- "youd've": "you'd've",
131
- "you'dve": "you'd've",
132
- "youll": "you'll",
133
- "youre": "you're",
134
- "youve": "you've",
135
- }
136
-
137
- NUMBER_MAP = {
138
- "none": "0",
139
- "zero": "0",
140
- "one": "1",
141
- "two": "2",
142
- "three": "3",
143
- "four": "4",
144
- "five": "5",
145
- "six": "6",
146
- "seven": "7",
147
- "eight": "8",
148
- "nine": "9",
149
- "ten": "10",
150
- }
151
- ARTICLES = ["a", "an", "the"]
152
- PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
153
- COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
154
- PUNCTUATIONS = [
155
- ";",
156
- r"/",
157
- "[",
158
- "]",
159
- '"',
160
- "{",
161
- "}",
162
- "(",
163
- ")",
164
- "=",
165
- "+",
166
- "\\",
167
- "_",
168
- "-",
169
- ">",
170
- "<",
171
- "@",
172
- "`",
173
- ",",
174
- "?",
175
- "!",
176
- ]
177
-
178
- def __init__(self, *args, **kwargs):
179
- pass
180
-
181
- def word_tokenize(self, word):
182
- word = word.lower()
183
- word = word.replace(",", "").replace("?", "").replace("'s", " 's")
184
- return word.strip()
185
-
186
- def process_punctuation(self, in_text):
187
- out_text = in_text
188
- for p in self.PUNCTUATIONS:
189
- if (p + " " in in_text or " " + p in in_text) or (
190
- re.search(self.COMMA_STRIP, in_text) is not None
191
- ):
192
- out_text = out_text.replace(p, "")
193
- else:
194
- out_text = out_text.replace(p, " ")
195
- out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
196
- return out_text
197
-
198
- def process_digit_article(self, in_text):
199
- out_text = []
200
- temp_text = in_text.lower().split()
201
- for word in temp_text:
202
- word = self.NUMBER_MAP.setdefault(word, word)
203
- if word not in self.ARTICLES:
204
- out_text.append(word)
205
- else:
206
- pass
207
- for word_id, word in enumerate(out_text):
208
- if word in self.CONTRACTIONS:
209
- out_text[word_id] = self.CONTRACTIONS[word]
210
- out_text = " ".join(out_text)
211
- return out_text
212
-
213
- def __call__(self, item):
214
- item = self.word_tokenize(item)
215
- item = item.replace("\n", " ").replace("\t", " ").strip()
216
- item = self.process_punctuation(item)
217
- item = self.process_digit_article(item)
218
- return item
219
-
220
-
221
- class TextVQAAccuracyEvaluator:
222
- def __init__(self):
223
- self.answer_processor = EvalAIAnswerProcessor()
224
-
225
- def _compute_answer_scores(self, raw_answers):
226
- """
227
- compute the accuracy (soft score) of human answers
228
- """
229
- answers = [self.answer_processor(a) for a in raw_answers]
230
- assert len(answers) == 10
231
- gt_answers = list(enumerate(answers))
232
- unique_answers = set(answers)
233
- unique_answer_scores = {}
234
-
235
- for unique_answer in unique_answers:
236
- accs = []
237
- for gt_answer in gt_answers:
238
- other_answers = [item for item in gt_answers if item != gt_answer]
239
- matching_answers = [
240
- item for item in other_answers if item[1] == unique_answer
241
- ]
242
- acc = min(1, float(len(matching_answers)) / 3)
243
- accs.append(acc)
244
- unique_answer_scores[unique_answer] = sum(accs) / len(accs)
245
-
246
- return unique_answer_scores
247
-
248
- def eval_pred_list(self, pred_list):
249
- pred_scores = []
250
- for entry in tqdm(pred_list):
251
- pred_answer = self.answer_processor(entry["pred_answer"])
252
- unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
253
- score = unique_answer_scores.get(pred_answer, 0.0)
254
- pred_scores.append(score)
255
-
256
- accuracy = sum(pred_scores) / len(pred_scores)
257
- return accuracy
258
-
259
-
260
- class STVQAAccuracyEvaluator:
261
- def __init__(self):
262
- self.answer_processor = EvalAIAnswerProcessor()
263
-
264
- def eval_pred_list(self, pred_list):
265
- pred_scores = []
266
- for entry in pred_list:
267
- pred_answer = self.answer_processor(entry["pred_answer"])
268
- gts = [self.answer_processor(a) for a in entry["gt_answers"]]
269
- score = 1.0 if pred_answer in gts else 0.0
270
- pred_scores.append(score)
271
-
272
- accuracy = sum(pred_scores) / len(pred_scores)
273
- return accuracy
274
-
275
-
276
- class STVQAANLSEvaluator:
277
- def __init__(self):
278
- import editdistance # install with `pip install editdistance`
279
-
280
- self.get_edit_distance = editdistance.eval
281
-
282
- def get_anls(self, s1, s2):
283
- s1 = s1.lower().strip()
284
- s2 = s2.lower().strip()
285
- iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
286
- anls = iou if iou >= 0.5 else 0.0
287
- return anls
288
-
289
- def eval_pred_list(self, pred_list):
290
- pred_scores = []
291
- for entry in pred_list:
292
- anls = max(
293
- self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
294
- )
295
- pred_scores.append(anls)
296
-
297
- accuracy = sum(pred_scores) / len(pred_scores)
298
- return accuracy
299
-
300
-
301
- class TextCapsBleu4Evaluator:
302
- def __init__(self):
303
- # The following script requires Java 1.8.0 and pycocotools installed.
304
- # The pycocoevalcap can be installed with pip as
305
- # pip install git+https://github.com/ronghanghu/coco-caption.git@python23
306
- # Original pycocoevalcap code is at https://github.com/tylin/coco-caption
307
- # but has no python3 support yet.
308
- try:
309
- from pycocoevalcap.bleu.bleu import Bleu
310
- from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
311
- except ModuleNotFoundError:
312
- print(
313
- "Please install pycocoevalcap module using "
314
- "pip install git+https://github.com/ronghanghu/coco-caption.git@python23" # noqa
315
- )
316
- raise
317
-
318
- self.tokenizer = PTBTokenizer()
319
- self.scorer = Bleu(4)
320
-
321
- def eval_pred_list(self, pred_list):
322
- # Create reference and hypotheses captions.
323
- gts = {}
324
- res = {}
325
- for idx, entry in enumerate(pred_list):
326
- gts[idx] = [{"caption": a} for a in entry["gt_answers"]]
327
- res[idx] = [{"caption": entry["pred_answer"]}]
328
-
329
- gts = self.tokenizer.tokenize(gts)
330
- res = self.tokenizer.tokenize(res)
331
- score, _ = self.scorer.compute_score(gts, res)
332
-
333
- bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
334
- return bleu4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LLAVA_Biovil/llava/eval/model_qa.py DELETED
@@ -1,85 +0,0 @@
1
- import argparse
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
3
- import torch
4
- import os
5
- import json
6
- from tqdm import tqdm
7
- import shortuuid
8
-
9
- from LLAV.llava.conversation import default_conversation
10
- from LLAV.llava.utils import disable_torch_init
11
-
12
-
13
- # new stopping implementation
14
- class KeywordsStoppingCriteria(StoppingCriteria):
15
- def __init__(self, keywords, tokenizer, input_ids):
16
- self.keywords = keywords
17
- self.tokenizer = tokenizer
18
- self.start_len = None
19
- self.input_ids = input_ids
20
-
21
- def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
22
- if self.start_len is None:
23
- self.start_len = self.input_ids.shape[1]
24
- else:
25
- outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
26
- for keyword in self.keywords:
27
- if keyword in outputs:
28
- return True
29
- return False
30
-
31
-
32
- @torch.inference_mode()
33
- def eval_model(model_name, questions_file, answers_file):
34
- # Model
35
- disable_torch_init()
36
- model_name = os.path.expanduser(model_name)
37
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
38
- model = AutoModelForCausalLM.from_pretrained(model_name,
39
- torch_dtype=torch.float16).cuda()
40
-
41
-
42
- ques_file = open(os.path.expanduser(questions_file), "r")
43
- ans_file = open(os.path.expanduser(answers_file), "w")
44
- for i, line in enumerate(tqdm(ques_file)):
45
- idx = json.loads(line)["question_id"]
46
- qs = json.loads(line)["text"]
47
- cat = json.loads(line)["category"]
48
- conv = default_conversation.copy()
49
- conv.append_message(conv.roles[0], qs)
50
- prompt = conv.get_prompt()
51
- inputs = tokenizer([prompt])
52
- input_ids = torch.as_tensor(inputs.input_ids).cuda()
53
- stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids)
54
- output_ids = model.generate(
55
- input_ids,
56
- do_sample=True,
57
- use_cache=True,
58
- temperature=0.7,
59
- max_new_tokens=1024,
60
- stopping_criteria=[stopping_criteria])
61
- outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
62
- try:
63
- index = outputs.index(conv.sep, len(prompt))
64
- except ValueError:
65
- outputs += conv.sep
66
- index = outputs.index(conv.sep, len(prompt))
67
-
68
- outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
69
- ans_id = shortuuid.uuid()
70
- ans_file.write(json.dumps({"question_id": idx,
71
- "text": outputs,
72
- "answer_id": ans_id,
73
- "model_id": model_name,
74
- "metadata": {}}) + "\n")
75
- ans_file.flush()
76
- ans_file.close()
77
-
78
- if __name__ == "__main__":
79
- parser = argparse.ArgumentParser()
80
- parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
81
- parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
82
- parser.add_argument("--answers-file", type=str, default="answer.jsonl")
83
- args = parser.parse_args()
84
-
85
- eval_model(args.model_name, args.question_file, args.answers_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LLAVA_Biovil/llava/eval/model_vqa.py DELETED
@@ -1,112 +0,0 @@
1
- import argparse
2
- import torch
3
- import os
4
- import json
5
- from tqdm import tqdm
6
- import shortuuid
7
-
8
- from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
- from LLAV.llava.conversation import conv_templates, SeparatorStyle
10
- from LLAV.llava.model.builder import load_pretrained_model
11
- from LLAV.llava.utils import disable_torch_init
12
- from LLAV.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
-
14
- from PIL import Image
15
- import math
16
-
17
-
18
- def split_list(lst, n):
19
- """Split a list into n (roughly) equal-sized chunks"""
20
- chunk_size = math.ceil(len(lst) / n) # integer division
21
- return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22
-
23
-
24
- def get_chunk(lst, n, k):
25
- chunks = split_list(lst, n)
26
- return chunks[k]
27
-
28
-
29
- def eval_model(args):
30
- # Model
31
- disable_torch_init()
32
- model_path = os.path.expanduser(args.model_path)
33
- model_name = get_model_name_from_path(model_path)
34
- tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35
-
36
- questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
37
- questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
38
- answers_file = os.path.expanduser(args.answers_file)
39
- os.makedirs(os.path.dirname(answers_file), exist_ok=True)
40
- ans_file = open(answers_file, "w")
41
- for line in tqdm(questions):
42
- idx = line["question_id"]
43
- image_file = line["image"]
44
- qs = line["text"]
45
- cur_prompt = qs
46
- if model.config.mm_use_im_start_end:
47
- qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
48
- else:
49
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
50
-
51
- conv = conv_templates[args.conv_mode].copy()
52
- conv.append_message(conv.roles[0], qs)
53
- conv.append_message(conv.roles[1], None)
54
- prompt = conv.get_prompt()
55
-
56
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
57
-
58
- image = Image.open(os.path.join(args.image_folder, image_file))
59
- image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
60
-
61
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
62
- keywords = [stop_str]
63
- stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
64
-
65
- with torch.inference_mode():
66
- output_ids = model.generate(
67
- input_ids,
68
- images=image_tensor.unsqueeze(0).half().cuda(),
69
- do_sample=True if args.temperature > 0 else False,
70
- temperature=args.temperature,
71
- top_p=args.top_p,
72
- num_beams=args.num_beams,
73
- # no_repeat_ngram_size=3,
74
- max_new_tokens=1024,
75
- use_cache=True)
76
-
77
- input_token_len = input_ids.shape[1]
78
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
79
- if n_diff_input_output > 0:
80
- print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
81
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
82
- outputs = outputs.strip()
83
- if outputs.endswith(stop_str):
84
- outputs = outputs[:-len(stop_str)]
85
- outputs = outputs.strip()
86
-
87
- ans_id = shortuuid.uuid()
88
- ans_file.write(json.dumps({"question_id": idx,
89
- "prompt": cur_prompt,
90
- "text": outputs,
91
- "answer_id": ans_id,
92
- "model_id": model_name,
93
- "metadata": {}}) + "\n")
94
- ans_file.flush()
95
- ans_file.close()
96
-
97
- if __name__ == "__main__":
98
- parser = argparse.ArgumentParser()
99
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
100
- parser.add_argument("--model-base", type=str, default=None)
101
- parser.add_argument("--image-folder", type=str, default="")
102
- parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
103
- parser.add_argument("--answers-file", type=str, default="answer.jsonl")
104
- parser.add_argument("--conv-mode", type=str, default="llava_v1")
105
- parser.add_argument("--num-chunks", type=int, default=1)
106
- parser.add_argument("--chunk-idx", type=int, default=0)
107
- parser.add_argument("--temperature", type=float, default=0.2)
108
- parser.add_argument("--top_p", type=float, default=None)
109
- parser.add_argument("--num_beams", type=int, default=1)
110
- args = parser.parse_args()
111
-
112
- eval_model(args)
 
LLAVA_Biovil/llava/eval/model_vqa_loader.py DELETED
@@ -1,141 +0,0 @@
1
- import argparse
2
- import torch
3
- import os
4
- import json
5
- from tqdm import tqdm
6
- import shortuuid
7
-
8
- from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
- from LLAV.llava.conversation import conv_templates
10
- from LLAV.llava.model.builder import load_pretrained_model
11
- from LLAV.llava.utils import disable_torch_init
12
- from LLAV.llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13
- from torch.utils.data import Dataset, DataLoader
14
-
15
- from PIL import Image
16
- import math
17
-
18
-
19
- def split_list(lst, n):
20
- """Split a list into n (roughly) equal-sized chunks"""
21
- chunk_size = math.ceil(len(lst) / n)  # ceiling division so the whole list is covered
22
- return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
23
-
24
-
25
- def get_chunk(lst, n, k):
26
- chunks = split_list(lst, n)
27
- return chunks[k]
28
-
29
-
30
- # Custom dataset class
31
- class CustomDataset(Dataset):
32
- def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
33
- self.questions = questions
34
- self.image_folder = image_folder
35
- self.tokenizer = tokenizer
36
- self.image_processor = image_processor
37
- self.model_config = model_config
38
-
39
- def __getitem__(self, index):
40
- line = self.questions[index]
41
- image_file = line["image"]
42
- qs = line["text"]
43
- if self.model_config.mm_use_im_start_end:
44
- qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
45
- else:
46
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
47
-
48
- conv = conv_templates[args.conv_mode].copy()
49
- conv.append_message(conv.roles[0], qs)
50
- conv.append_message(conv.roles[1], None)
51
- prompt = conv.get_prompt()
52
-
53
- image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
54
- image_tensor = process_images([image], self.image_processor, self.model_config)[0]
55
-
56
- input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
57
-
58
- return input_ids, image_tensor
59
-
60
- def __len__(self):
61
- return len(self.questions)
62
-
63
-
64
- # DataLoader
65
- def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
66
- assert batch_size == 1, "batch_size must be 1"
67
- dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
68
- data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
69
- return data_loader
70
-
71
-
72
- def eval_model(args):
73
- # Model
74
- disable_torch_init()
75
- model_path = os.path.expanduser(args.model_path)
76
- model_name = get_model_name_from_path(model_path)
77
- tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
78
-
79
- questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
80
- questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
81
- answers_file = os.path.expanduser(args.answers_file)
82
- os.makedirs(os.path.dirname(answers_file), exist_ok=True)
83
- ans_file = open(answers_file, "w")
84
-
85
- if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
86
- args.conv_mode = args.conv_mode + '_mmtag'
87
- print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
88
-
89
- data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
90
-
91
- for (input_ids, image_tensor), line in tqdm(zip(data_loader, questions), total=len(questions)):
92
- idx = line["question_id"]
93
- cur_prompt = line["text"]
94
-
95
- input_ids = input_ids.to(device='cuda', non_blocking=True)
96
-
97
- with torch.inference_mode():
98
- output_ids = model.generate(
99
- input_ids,
100
- images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
101
- do_sample=True if args.temperature > 0 else False,
102
- temperature=args.temperature,
103
- top_p=args.top_p,
104
- num_beams=args.num_beams,
105
- max_new_tokens=args.max_new_tokens,
106
- use_cache=True)
107
-
108
- input_token_len = input_ids.shape[1]
109
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
110
- if n_diff_input_output > 0:
111
- print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
112
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
113
- outputs = outputs.strip()
114
-
115
- ans_id = shortuuid.uuid()
116
- ans_file.write(json.dumps({"question_id": idx,
117
- "prompt": cur_prompt,
118
- "text": outputs,
119
- "answer_id": ans_id,
120
- "model_id": model_name,
121
- "metadata": {}}) + "\n")
122
- # ans_file.flush()
123
- ans_file.close()
124
-
125
- if __name__ == "__main__":
126
- parser = argparse.ArgumentParser()
127
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
128
- parser.add_argument("--model-base", type=str, default=None)
129
- parser.add_argument("--image-folder", type=str, default="")
130
- parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
131
- parser.add_argument("--answers-file", type=str, default="answer.jsonl")
132
- parser.add_argument("--conv-mode", type=str, default="llava_v1")
133
- parser.add_argument("--num-chunks", type=int, default=1)
134
- parser.add_argument("--chunk-idx", type=int, default=0)
135
- parser.add_argument("--temperature", type=float, default=0.2)
136
- parser.add_argument("--top_p", type=float, default=None)
137
- parser.add_argument("--num_beams", type=int, default=1)
138
- parser.add_argument("--max_new_tokens", type=int, default=128)
139
- args = parser.parse_args()
140
-
141
- eval_model(args)
 
LLAVA_Biovil/llava/eval/model_vqa_mmbench.py DELETED
@@ -1,169 +0,0 @@
1
- import argparse
2
- import torch
3
- import os
4
- import json
5
- import pandas as pd
6
- from tqdm import tqdm
7
- import shortuuid
8
-
9
- from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
10
- from LLAV.llava.conversation import conv_templates, SeparatorStyle
11
- from LLAV.llava.model.builder import load_pretrained_model
12
- from LLAV.llava.utils import disable_torch_init
13
- from LLAV.llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path
14
-
15
- import math
16
-
17
-
18
- all_options = ['A', 'B', 'C', 'D']
19
-
20
-
21
- def split_list(lst, n):
22
- """Split a list into n (roughly) equal-sized chunks"""
23
- chunk_size = math.ceil(len(lst) / n)  # ceiling division so the whole list is covered
24
- return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
25
-
26
-
27
- def get_chunk(lst, n, k):
28
- chunks = split_list(lst, n)
29
- return chunks[k]
30
-
31
-
32
- def is_none(value):
33
- if value is None:
34
- return True
35
- if type(value) is float and math.isnan(value):
36
- return True
37
- if type(value) is str and value.lower() == 'nan':
38
- return True
39
- if type(value) is str and value.lower() == 'none':
40
- return True
41
- return False
42
-
43
- def get_options(row, options):
44
- parsed_options = []
45
- for option in options:
46
- option_value = row[option]
47
- if is_none(option_value):
48
- break
49
- parsed_options.append(option_value)
50
- return parsed_options
51
-
52
-
53
- def eval_model(args):
54
- # Model
55
- disable_torch_init()
56
- model_path = os.path.expanduser(args.model_path)
57
- model_name = get_model_name_from_path(model_path)
58
- tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
59
-
60
- questions = pd.read_table(os.path.expanduser(args.question_file))
61
- questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
62
- answers_file = os.path.expanduser(args.answers_file)
63
- os.makedirs(os.path.dirname(answers_file), exist_ok=True)
64
- ans_file = open(answers_file, "w")
65
-
66
- if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
67
- args.conv_mode = args.conv_mode + '_mmtag'
68
- print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
69
-
70
- for index, row in tqdm(questions.iterrows(), total=len(questions)):
71
- options = get_options(row, all_options)
72
- cur_option_char = all_options[:len(options)]
73
-
74
- if args.all_rounds:
75
- num_rounds = len(options)
76
- else:
77
- num_rounds = 1
78
-
79
- for round_idx in range(num_rounds):
80
- idx = row['index']
81
- question = row['question']
82
- hint = row['hint']
83
- image = load_image_from_base64(row['image'])
84
- if not is_none(hint):
85
- question = hint + '\n' + question
86
- for option_char, option in zip(all_options[:len(options)], options):
87
- question = question + '\n' + option_char + '. ' + option
88
- qs = cur_prompt = question
89
- if model.config.mm_use_im_start_end:
90
- qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
91
- else:
92
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
93
-
94
- if args.single_pred_prompt:
95
- if args.lang == 'cn':
96
- qs = qs + '\n' + "请直接回答选项字母。"
97
- else:
98
- qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
99
-
100
- conv = conv_templates[args.conv_mode].copy()
101
- conv.append_message(conv.roles[0], qs)
102
- conv.append_message(conv.roles[1], None)
103
- prompt = conv.get_prompt()
104
-
105
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
106
-
107
- image_tensor = process_images([image], image_processor, model.config)[0]
108
- # image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
109
-
110
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
111
-
112
- with torch.inference_mode():
113
- output_ids = model.generate(
114
- input_ids,
115
- images=image_tensor.unsqueeze(0).half().cuda(),
116
- do_sample=True if args.temperature > 0 else False,
117
- temperature=args.temperature,
118
- top_p=args.top_p,
119
- num_beams=args.num_beams,
120
- # no_repeat_ngram_size=3,
121
- max_new_tokens=1024,
122
- use_cache=True)
123
-
124
- input_token_len = input_ids.shape[1]
125
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
126
- if n_diff_input_output > 0:
127
- print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
128
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
129
- outputs = outputs.strip()
130
- if outputs.endswith(stop_str):
131
- outputs = outputs[:-len(stop_str)]
132
- outputs = outputs.strip()
133
-
134
- ans_id = shortuuid.uuid()
135
- ans_file.write(json.dumps({"question_id": idx,
136
- "round_id": round_idx,
137
- "prompt": cur_prompt,
138
- "text": outputs,
139
- "options": options,
140
- "option_char": cur_option_char,
141
- "answer_id": ans_id,
142
- "model_id": model_name,
143
- "metadata": {}}) + "\n")
144
- ans_file.flush()
145
-
146
- # rotate options
147
- options = options[1:] + options[:1]
148
- cur_option_char = cur_option_char[1:] + cur_option_char[:1]
149
- ans_file.close()
150
-
151
- if __name__ == "__main__":
152
- parser = argparse.ArgumentParser()
153
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
154
- parser.add_argument("--model-base", type=str, default=None)
155
- parser.add_argument("--image-folder", type=str, default="")
156
- parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
157
- parser.add_argument("--answers-file", type=str, default="answer.jsonl")
158
- parser.add_argument("--conv-mode", type=str, default="llava_v1")
159
- parser.add_argument("--num-chunks", type=int, default=1)
160
- parser.add_argument("--chunk-idx", type=int, default=0)
161
- parser.add_argument("--temperature", type=float, default=0.2)
162
- parser.add_argument("--top_p", type=float, default=None)
163
- parser.add_argument("--num_beams", type=int, default=1)
164
- parser.add_argument("--all-rounds", action="store_true")
165
- parser.add_argument("--single-pred-prompt", action="store_true")
166
- parser.add_argument("--lang", type=str, default="en")
167
- args = parser.parse_args()
168
-
169
- eval_model(args)
 
LLAVA_Biovil/llava/eval/model_vqa_qbench.py DELETED
@@ -1,120 +0,0 @@
1
- import argparse
2
- import torch
3
- from tqdm import tqdm
4
- import json
5
-
6
- from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
7
- from LLAV.llava.conversation import conv_templates, SeparatorStyle
8
- from LLAV.llava.model.builder import load_pretrained_model
9
- from LLAV.llava.utils import disable_torch_init
10
- from LLAV.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
11
-
12
- import requests
13
- from PIL import Image
14
- from io import BytesIO
15
-
16
-
17
- def load_image(image_file):
18
- if image_file.startswith('http') or image_file.startswith('https'):
19
- response = requests.get(image_file)
20
- image = Image.open(BytesIO(response.content)).convert('RGB')
21
- else:
22
- image = Image.open(image_file).convert('RGB')
23
- return image
24
-
25
-
26
- def eval_model(args):
27
- # Model
28
- disable_torch_init()
29
-
30
- model_name = get_model_name_from_path(args.model_path)
31
- tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, True)
32
-
33
-
34
-
35
-
36
- with open(args.questions_file) as f:
37
- llvqa_data = json.load(f)
38
-
39
- for i, llddata in enumerate(tqdm(llvqa_data)):
40
- filename = llddata["img_path"]
41
- if args.lang == "en":
42
- message = llddata["question"] + "\nChoose between one of the options as follows:\n"
43
- elif args.lang == "zh":
44
- message = llddata["question"] + "\n在下列选项中选择一个:\n"
45
- else:
46
- raise NotImplementedError("Q-Bench does not support languages other than English (en) and Chinese (zh) yet. Contact us (https://github.com/VQAssessment/Q-Bench/) to convert Q-Bench into more languages.")
47
- for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]):
48
- message += f"{choice} {ans}\n"
49
- qs = message
50
-
51
- if model.config.mm_use_im_start_end:
52
- qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
53
- else:
54
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
55
-
56
- if 'llama-2' in model_name.lower():
57
- conv_mode = "llava_llama_2"
58
- elif "v1" in model_name.lower():
59
- conv_mode = "llava_v1"
60
- elif "mpt" in model_name.lower():
61
- conv_mode = "mpt"
62
- else:
63
- conv_mode = "llava_v0"
64
-
65
- if args.conv_mode is not None and conv_mode != args.conv_mode:
66
- print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
67
- else:
68
- args.conv_mode = conv_mode
69
-
70
- conv = conv_templates[args.conv_mode].copy()
71
- conv.append_message(conv.roles[0], qs)
72
- conv.append_message(conv.roles[1], None)
73
- prompt = conv.get_prompt()
74
-
75
- image = load_image(args.image_folder + filename)
76
- image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
77
-
78
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
79
-
80
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
81
- keywords = [stop_str]
82
- stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
83
-
84
-
85
- with torch.inference_mode():
86
- output_ids = model.generate(
87
- input_ids,
88
- images=image_tensor,
89
- num_beams=1,
90
- do_sample=False,
91
- temperature=0,
92
- max_new_tokens=1024,
93
- use_cache=True,
94
- stopping_criteria=[stopping_criteria])
95
-
96
- input_token_len = input_ids.shape[1]
97
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
98
- if n_diff_input_output > 0:
99
- print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
100
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
101
- outputs = outputs.strip()
102
- if outputs.endswith(stop_str):
103
- outputs = outputs[:-len(stop_str)]
104
- outputs = outputs.strip()
105
- llddata["response"] = outputs
106
- with open(args.answers_file, "a") as wf:
107
- json.dump(llddata, wf)
108
-
109
- if __name__ == "__main__":
110
- parser = argparse.ArgumentParser()
111
- parser.add_argument("--model-path", type=str, default="llava-v1.5")
112
- parser.add_argument("--model-base", type=str, default=None)
113
- parser.add_argument("--image-folder", type=str, default="./playground/data/qbench/images_llvisionqa")
114
- parser.add_argument("--questions-file", type=str, default="./playground/data/qbench/llvisionqa_dev.json")
115
- parser.add_argument("--answers-file", type=str, default="answer.jsonl")
116
- parser.add_argument("--conv-mode", type=str, default="llava_v1")
117
- parser.add_argument("--lang", type=str, default="en")
118
- args = parser.parse_args()
119
-
120
- eval_model(args)
 
LLAVA_Biovil/llava/eval/model_vqa_science.py DELETED
@@ -1,147 +0,0 @@
1
- import argparse
2
- import torch
3
- import os
4
- import json
5
- from tqdm import tqdm
6
- import shortuuid
7
-
8
- from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
- from LLAV.llava.conversation import conv_templates, SeparatorStyle
10
- from LLAV.llava.model.builder import load_pretrained_model
11
- from LLAV.llava.utils import disable_torch_init
12
- from LLAV.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
-
14
- from PIL import Image
15
- import math
16
-
17
-
18
- def split_list(lst, n):
19
- """Split a list into n (roughly) equal-sized chunks"""
20
- chunk_size = math.ceil(len(lst) / n)  # ceiling division so the whole list is covered
21
- return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22
-
23
-
24
- def get_chunk(lst, n, k):
25
- chunks = split_list(lst, n)
26
- return chunks[k]
27
-
28
-
29
- def eval_model(args):
30
- # Model
31
- disable_torch_init()
32
- model_path = os.path.expanduser(args.model_path)
33
- model_name = get_model_name_from_path(model_path)
34
- tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35
-
36
- questions = json.load(open(os.path.expanduser(args.question_file), "r"))
37
- questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
38
- answers_file = os.path.expanduser(args.answers_file)
39
- os.makedirs(os.path.dirname(answers_file), exist_ok=True)
40
- ans_file = open(answers_file, "w")
41
- for i, line in enumerate(tqdm(questions)):
42
- idx = line["id"]
43
- question = line['conversations'][0]
44
- qs = question['value'].replace('<image>', '').strip()
45
- cur_prompt = qs
46
-
47
- if 'image' in line:
48
- image_file = line["image"]
49
- image = Image.open(os.path.join(args.image_folder, image_file))
50
- image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
51
- images = image_tensor.unsqueeze(0).half().cuda()
52
- if getattr(model.config, 'mm_use_im_start_end', False):
53
- qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
54
- else:
55
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
56
- cur_prompt = '<image>' + '\n' + cur_prompt
57
- else:
58
- images = None
59
-
60
- if args.single_pred_prompt:
61
- qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
62
- cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
63
-
64
- conv = conv_templates[args.conv_mode].copy()
65
- conv.append_message(conv.roles[0], qs)
66
- conv.append_message(conv.roles[1], None)
67
- prompt = conv.get_prompt()
68
-
69
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
70
-
71
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
72
- keywords = [stop_str]
73
- stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)] if conv.version == "v0" else None
74
-
75
- with torch.inference_mode():
76
- output_ids = model.generate(
77
- input_ids,
78
- images=images,
79
- do_sample=True if args.temperature > 0 else False,
80
- temperature=args.temperature,
81
- max_new_tokens=1024,
82
- use_cache=True,
83
- stopping_criteria=stopping_criteria,
84
- )
85
-
86
- input_token_len = input_ids.shape[1]
87
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
88
- if n_diff_input_output > 0:
89
- print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
90
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
91
- outputs = outputs.strip()
92
- if outputs.endswith(stop_str):
93
- outputs = outputs[:-len(stop_str)]
94
- outputs = outputs.strip()
95
-
96
- # prompt for answer
97
- if args.answer_prompter:
98
- outputs_reasoning = outputs
99
- input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
100
-
101
- with torch.inference_mode():
102
- output_ids = model.generate(
103
- input_ids,
104
- images=images,
105
- do_sample=True if args.temperature > 0 else False,
106
- temperature=args.temperature,
107
- max_new_tokens=64,
108
- use_cache=True,
109
- stopping_criteria=stopping_criteria)  # already a list (or None); do not wrap it again
110
-
111
- input_token_len = input_ids.shape[1]
112
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
113
- if n_diff_input_output > 0:
114
- print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
115
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
116
- outputs = outputs.strip()
117
- if outputs.endswith(stop_str):
118
- outputs = outputs[:-len(stop_str)]
119
- outputs = outputs.strip()
120
- outputs = outputs_reasoning + '\n The answer is ' + outputs
121
-
122
- ans_id = shortuuid.uuid()
123
- ans_file.write(json.dumps({"question_id": idx,
124
- "prompt": cur_prompt,
125
- "text": outputs,
126
- "answer_id": ans_id,
127
- "model_id": model_name,
128
- "metadata": {}}) + "\n")
129
- ans_file.flush()
130
- ans_file.close()
131
-
132
- if __name__ == "__main__":
133
- parser = argparse.ArgumentParser()
134
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
135
- parser.add_argument("--model-base", type=str, default=None)
136
- parser.add_argument("--image-folder", type=str, default="")
137
- parser.add_argument("--question-file", type=str, default="tables/question.json")
138
- parser.add_argument("--answers-file", type=str, default="answer.jsonl")
139
- parser.add_argument("--conv-mode", type=str, default="llava_v0")
140
- parser.add_argument("--num-chunks", type=int, default=1)
141
- parser.add_argument("--chunk-idx", type=int, default=0)
142
- parser.add_argument("--temperature", type=float, default=0.2)
143
- parser.add_argument("--answer-prompter", action="store_true")
144
- parser.add_argument("--single-pred-prompt", action="store_true")
145
- args = parser.parse_args()
146
-
147
- eval_model(args)
 
LLAVA_Biovil/llava/eval/qa_baseline_gpt35.py DELETED
@@ -1,74 +0,0 @@
1
- """Generate answers with GPT-3.5"""
2
- # Note: you need to be using OpenAI Python v0.27.0 for the code below to work
3
- import argparse
4
- import json
5
- import os
6
- import time
7
- import concurrent.futures
8
-
9
- import openai
10
- import tqdm
11
- import shortuuid
12
-
13
- MODEL = 'gpt-3.5-turbo'
14
- MODEL_ID = 'gpt-3.5-turbo:20230327'
15
-
16
- def get_answer(question_id: int, question: str, max_tokens: int):
17
- ans = {
18
- 'answer_id': shortuuid.uuid(),
19
- 'question_id': question_id,
20
- 'model_id': MODEL_ID,
21
- }
22
- for _ in range(3):
23
- try:
24
- response = openai.ChatCompletion.create(
25
- model=MODEL,
26
- messages=[{
27
- 'role': 'system',
28
- 'content': 'You are a helpful assistant.'
29
- }, {
30
- 'role': 'user',
31
- 'content': question,
32
- }],
33
- max_tokens=max_tokens,
34
- )
35
- ans['text'] = response['choices'][0]['message']['content']
36
- return ans
37
- except Exception as e:
38
- print('[ERROR]', e)
39
- ans['text'] = '#ERROR#'
40
- time.sleep(1)
41
- return ans
42
-
43
-
44
- if __name__ == '__main__':
45
- parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
46
- parser.add_argument('-q', '--question')
47
- parser.add_argument('-o', '--output')
48
- parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
49
- args = parser.parse_args()
50
-
51
- questions_dict = {}
52
- with open(os.path.expanduser(args.question)) as f:
53
- for line in f:
54
- if not line:
55
- continue
56
- q = json.loads(line)
57
- questions_dict[q['question_id']] = q['text']
58
-
59
- answers = []
60
-
61
- with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
62
- futures = []
63
- for qid, question in questions_dict.items():
64
- future = executor.submit(get_answer, qid, question, args.max_tokens)
65
- futures.append(future)
66
-
67
- for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
68
- answers.append(future.result())
69
-
70
- answers.sort(key=lambda x: x['question_id'])
71
-
72
- with open(os.path.expanduser(args.output), 'w') as f:
73
- table = [json.dumps(ans) for ans in answers]
74
- f.write('\n'.join(table))
 
LLAVA_Biovil/llava/eval/run_llava.py DELETED
@@ -1,155 +0,0 @@
1
- import argparse
2
- import torch
3
-
4
- from LLAV.llava.constants import (
5
- IMAGE_TOKEN_INDEX,
6
- DEFAULT_IMAGE_TOKEN,
7
- DEFAULT_IM_START_TOKEN,
8
- DEFAULT_IM_END_TOKEN,
9
- IMAGE_PLACEHOLDER,
10
- )
11
- from LLAV.llava.conversation import conv_templates, SeparatorStyle
12
- from LLAV.llava.model.builder import load_pretrained_model
13
- from LLAV.llava.utils import disable_torch_init
14
- from LLAV.llava.mm_utils import (
15
- process_images,
16
- tokenizer_image_token,
17
- get_model_name_from_path,
18
- KeywordsStoppingCriteria,
19
- )
20
-
21
- import requests
22
- from PIL import Image
23
- from io import BytesIO
24
- import re
25
-
26
-
27
- def image_parser(args):
28
- out = args.image_file.split(args.sep)
29
- return out
30
-
31
-
32
- def load_image(image_file):
33
- if image_file.startswith("http") or image_file.startswith("https"):
34
- response = requests.get(image_file)
35
- image = Image.open(BytesIO(response.content)).convert("RGB")
36
- else:
37
- image = Image.open(image_file).convert("RGB")
38
- return image
39
-
40
-
41
- def load_images(image_files):
42
- out = []
43
- for image_file in image_files:
44
- image = load_image(image_file)
45
- out.append(image)
46
- return out
47
-
48
-
49
- def eval_model(args):
50
- # Model
51
- disable_torch_init()
52
-
53
- model_name = get_model_name_from_path(args.model_path)
54
- tokenizer, model, image_processor, context_len = load_pretrained_model(
55
- args.model_path, args.model_base, model_name
56
- )
57
-
58
- qs = args.query
59
- image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
60
- if IMAGE_PLACEHOLDER in qs:
61
- if model.config.mm_use_im_start_end:
62
- qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
63
- else:
64
- qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
65
- else:
66
- if model.config.mm_use_im_start_end:
67
- qs = image_token_se + "\n" + qs
68
- else:
69
- qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
70
-
71
- if "llama-2" in model_name.lower():
72
- conv_mode = "llava_llama_2"
73
- elif "v1" in model_name.lower():
74
- conv_mode = "llava_v1"
75
- elif "mpt" in model_name.lower():
76
- conv_mode = "mpt"
77
- else:
78
- conv_mode = "llava_v0"
79
-
80
- if args.conv_mode is not None and conv_mode != args.conv_mode:
81
- print(
82
- "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
83
- conv_mode, args.conv_mode, args.conv_mode
84
- )
85
- )
86
- else:
87
- args.conv_mode = conv_mode
88
-
89
- conv = conv_templates[args.conv_mode].copy()
90
- conv.append_message(conv.roles[0], qs)
91
- conv.append_message(conv.roles[1], None)
92
- prompt = conv.get_prompt()
93
-
94
- image_files = image_parser(args)
95
- images = load_images(image_files)
96
- images_tensor = process_images(
97
- images,
98
- image_processor,
99
- model.config
100
- ).to(model.device, dtype=torch.float16)
101
-
102
- input_ids = (
103
- tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
104
- .unsqueeze(0)
105
- .cuda()
106
- )
107
-
108
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
109
- keywords = [stop_str]
110
- stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
111
-
112
- with torch.inference_mode():
113
- output_ids = model.generate(
114
- input_ids,
115
- images=images_tensor,
116
- do_sample=True if args.temperature > 0 else False,
117
- temperature=args.temperature,
118
- top_p=args.top_p,
119
- num_beams=args.num_beams,
120
- max_new_tokens=args.max_new_tokens,
121
- use_cache=True,
122
- stopping_criteria=[stopping_criteria],
123
- )
124
-
125
- input_token_len = input_ids.shape[1]
126
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
127
- if n_diff_input_output > 0:
128
- print(
129
- f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
130
- )
131
- outputs = tokenizer.batch_decode(
132
- output_ids[:, input_token_len:], skip_special_tokens=True
133
- )[0]
134
- outputs = outputs.strip()
135
- if outputs.endswith(stop_str):
136
- outputs = outputs[: -len(stop_str)]
137
- outputs = outputs.strip()
138
- print(outputs)
139
-
140
-
141
- if __name__ == "__main__":
142
- parser = argparse.ArgumentParser()
143
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
144
- parser.add_argument("--model-base", type=str, default=None)
145
- parser.add_argument("--image-file", type=str, required=True)
146
- parser.add_argument("--query", type=str, required=True)
147
- parser.add_argument("--conv-mode", type=str, default=None)
148
- parser.add_argument("--sep", type=str, default=",")
149
- parser.add_argument("--temperature", type=float, default=0.2)
150
- parser.add_argument("--top_p", type=float, default=None)
151
- parser.add_argument("--num_beams", type=int, default=1)
152
- parser.add_argument("--max_new_tokens", type=int, default=512)
153
- args = parser.parse_args()
154
-
155
- eval_model(args)
 
LLAVA_Biovil/llava/eval/summarize_gpt_review.py DELETED
@@ -1,60 +0,0 @@
1
- import json
2
- import os
3
- from collections import defaultdict
4
-
5
- import numpy as np
6
-
7
- import argparse
8
-
9
- def parse_args():
10
- parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
11
- parser.add_argument('-d', '--dir', default=None)
12
- parser.add_argument('-v', '--version', default=None)
13
- parser.add_argument('-s', '--select', nargs='*', default=None)
14
- parser.add_argument('-f', '--files', nargs='*', default=[])
15
- parser.add_argument('-i', '--ignore', nargs='*', default=[])
16
- return parser.parse_args()
17
-
18
-
19
- if __name__ == '__main__':
20
- args = parse_args()
21
-
22
- if args.ignore is not None:
23
- args.ignore = [int(x) for x in args.ignore]
24
-
25
- if len(args.files) > 0:
26
- review_files = args.files
27
- else:
28
- review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]
29
-
30
- for review_file in sorted(review_files):
31
- config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
32
- if args.select is not None and any(x not in config for x in args.select):
33
- continue
34
- if '0613' in config:
35
- version = '0613'
36
- else:
37
- version = '0314'
38
- if args.version is not None and args.version != version:
39
- continue
40
- scores = defaultdict(list)
41
- print(config)
42
- with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
43
- for review_str in f:
44
- review = json.loads(review_str)
45
- if review['question_id'] in args.ignore:
46
- continue
47
- if 'category' in review:
48
- scores[review['category']].append(review['tuple'])
49
- scores['all'].append(review['tuple'])
50
- else:
51
- if 'tuple' in review:
52
- scores['all'].append(review['tuple'])
53
- else:
54
- scores['all'].append(review['score'])
55
- for k, v in sorted(scores.items()):
56
- stats = np.asarray(v).mean(0).tolist()
57
- stats = [round(x, 3) for x in stats]
58
- # print(k, stats, round(stats[1]/stats[0]*100, 1))
59
- print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
60
- print('=================================')
 
LLAVA_Biovil/llava/eval/webpage/figures/alpaca.png DELETED
Binary file (96.1 kB)
 
LLAVA_Biovil/llava/eval/webpage/figures/bard.jpg DELETED
Binary file (15.3 kB)
 
LLAVA_Biovil/llava/eval/webpage/figures/chatgpt.svg DELETED
LLAVA_Biovil/llava/eval/webpage/figures/llama.jpg DELETED
Binary file (56.5 kB)
 
LLAVA_Biovil/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg DELETED
LLAVA_Biovil/llava/eval/webpage/figures/vicuna.jpeg DELETED
Binary file (54 kB)
 
LLAVA_Biovil/llava/eval/webpage/index.html DELETED
@@ -1,162 +0,0 @@
1
- <!DOCTYPE html>
2
- <html lang="en">
3
- <head>
4
- <meta charset="UTF-8">
5
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
- <title>Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots</title>
7
- <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css">
8
- <link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
9
- <link rel="stylesheet" href="styles.css">
10
- </head>
11
-
12
- <body>
13
- <nav class="navbar navbar-expand-lg navbar-dark bg-dark">
14
- <a class="navbar-brand" href="#">🏔️ Vicuna Evaluation Examples</a>
15
- <button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbarNav" aria-controls="navbarNav" aria-expanded="false" aria-label="Toggle navigation">
16
- <span class="navbar-toggler-icon"></span>
17
- </button>
18
- <div class="collapse navbar-collapse" id="navbarNav">
19
- <ul class="navbar-nav mr-auto">
20
- <li class="nav-item">
21
- <a class="nav-link" href="https://chat.lmsys.org/">Demo</a>
22
- </li>
23
- <li class="nav-item">
24
- <a class="nav-link" href="https://vicuna.lmsys.org">Blog</a>
25
- </li>
26
- <li class="nav-item">
27
- <a class="nav-link" href="https://github.com/lm-sys/FastChat">Github</a>
28
- </li>
29
- </ul>
30
- </div>
31
- </nav>
32
-
33
- <div class="container mt-5">
34
- <h2 class="text-center mb-5">Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots</h2>
35
-
36
- <!-- Selection -->
37
- <div class="form-row">
38
- <div class="form-group col-md-2">
39
- <label for="category-select">Category</label>
40
- <select class="form-control" id="category-select"></select>
41
- </div>
42
- <div class="form-group col-md-8">
43
- <label for="question-select">Question</label>
44
- <select class="form-control" id="question-select"></select>
45
- </div>
46
- <div class="form-group col-md-2">
47
- <div class="col-md-2"><label>&nbsp;</label></div>
48
- <div class="btn-group" role="group" aria-label="Left and Right Controller">
49
- <button type="button" class="form-control btn btn-primary" id="prev-question"><i class="material-icons">keyboard_arrow_left</i></button>
50
- <button type="button" class="form-control btn btn-primary" id="next-question"><i class="material-icons">keyboard_arrow_right</i></button>
51
- </div>
52
- </div>
53
- </div>
54
-
55
- <!-- "Battle" -->
56
- <div class="row mb-4" style="justify-content: center;">
57
- <div class="col" style="display: flex; justify-content: center; align-items: center;">
58
- <label class="adjustable-font-size" id="other-score-label">*/10</label>
59
- </div>
60
- <div class="col">
61
- <div class="vertical-flex-layout">
62
- <img class="shadow figure-img img-fluid" src="" alt="other logo" width="150" id="other-model-figure">
63
- </div>
64
- </div>
65
- <div class="col">
66
- <div class="vertical-flex-layout">
67
- <!-- from: https://fonts.google.com/icons?icon.query=battle&selected=Material+Symbols+Outlined:swords:FILL@0;wght@300;GRAD@0;opsz@48&icon.style=Outlined -->
68
- <img class="figure-img img-fluid" src="figures/swords_FILL0_wght300_GRAD0_opsz48.svg" width="60" height="60">
69
- </div>
70
- </div>
71
- <div class="col">
72
- <div class="vertical-flex-layout">
73
- <img class="shadow figure-img img-fluid" src="figures/vicuna.jpeg" alt="vicuna logo" width="150" id="our-model-figure">
74
- </div>
75
- </div>
76
- <div class="col" style="display: flex; justify-content: center; align-items: center;">
77
- <label class="adjustable-font-size" id="our-score-label">*/10</label>
78
- </div>
79
- </div>
80
-
81
- <!-- Question Card -->
82
- <div class="card mb-4">
83
- <div class="card-body" id="selected-question"></div>
84
- </div>
85
-
86
- <!-- Answer Cards -->
87
- <div class="row">
88
- <div class="col-md-6">
89
- <div class="card mb-4 expandable-card">
90
- <div class="card-header" style="padding-bottom: 0.2rem" id="other-model-header-bg">
91
- <div class="row">
92
- <div class="col-md-5" style="align-items: center; display: flex;">
93
- <label id="other-model-header">Assistant #1</label>
94
- </div>
95
- <div class="col-md-7">
96
- <select class="form-control" id="model-select" style="height: fit-content; margin-top: -0.3rem;"></select>
97
- </div>
98
- </div>
99
- </div>
100
- <div class="card-body">
101
- <div class="card-text-container">
102
- <div class="card-text" id="other-model-answer"></div>
103
- </div>
104
- <div class="btn btn-primary expand-btn" style="display:flex;"></div>
105
- </div>
106
- </div>
107
- </div>
108
- <div class="col-md-6">
109
- <div class="card mb-4 expandable-card">
110
- <div class="card-header" id="our-model-header">
111
- Assistant #2 (Vicuna, our model)
112
- </div>
113
- <div class="card-body">
114
- <div class="card-text-container">
115
- <div class="card-text" id="our-model-answer"></div>
116
- </div>
117
- <div class="btn btn-primary expand-btn" style="display:flex;"></div>
118
- </div>
119
- </div>
120
- </div>
121
- </div>
122
-
123
- <!-- Evaluation -->
124
- <div class="card expandable-card">
125
- <div class="card-header" style="background-color: #c9c9f2;" id="evaluation-header">GPT-4 Evaluation</div>
126
- <div class="card-body">
127
- <div class="card-text-container">
128
- <div class="card-text" id="evaluation-result"></div>
129
- </div>
130
- <div class="btn btn-primary expand-btn" style="display:flex;"></div>
131
- </div>
132
- </div>
133
- </div>
134
-
135
- <div class="container-fluid bg-light py-2">
136
- <div class="text-center">
137
- <small class="text-muted">This website is co-authored with <a href="https://openai.com" target="_blank">GPT-4</a>.</small>
138
- </div>
139
- </div>
140
-
141
- <!-- Marked.js -->
142
- <script src="https://cdn.jsdelivr.net/npm/[email protected]/lib/marked.umd.min.js"></script>
143
- <!-- Bootstrap and Popper.js JavaScript dependencies -->
144
- <script src="https://code.jquery.com/jquery-3.5.1.slim.min.js"></script>
145
- <script src="https://cdn.jsdelivr.net/npm/@popperjs/[email protected]/dist/umd/popper.min.js"></script>
146
- <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.5.2/js/bootstrap.min.js"></script>
147
-
148
- <script src="script.js"></script>
149
- <script>
150
- // Fetch the JSON file
151
- fetch('data.json')
152
- .then(response => response.json())
153
- .then(json_data => {
154
- // Populate the models and questions.
155
- populateModels(json_data.models);
156
- populateQuestions(json_data.questions);
157
- displayQuestion(currentQuestionIndex);
158
- }).catch(error => console.error(error));
159
- </script>
160
- </body>
161
-
162
- </html>
 
LLAVA_Biovil/llava/eval/webpage/script.js DELETED
@@ -1,245 +0,0 @@
1
- // Description: Script for the evaluation webpage.
2
-
3
- let currentQuestionIndex = 1;
4
-
5
- // Store the model name mapping for later use.
6
- modelNameMapping = {
7
- "gpt35": "ChatGPT-3.5",
8
- "gpt4": "GPT-4",
9
- "alpaca": "Alpaca-13b",
10
- "vicuna": "Vicuna-13b",
11
- "llama": "LLaMA-13b",
12
- "bard": "Bard",
13
- };
14
-
15
- modelFigureMapping = {
16
- "vicuna": "figures/vicuna.jpeg",
17
- // Image from: https://commons.wikimedia.org/wiki/File:ChatGPT_logo.svg
18
- "gpt35": "figures/chatgpt.svg",
19
- // Image from: https://www.reddit.com/r/logodesign/comments/1128aat/google_ai_bard_logo_design/
20
- "bard": "figures/bard.jpg",
21
- // Image from: https://crfm.stanford.edu/2023/03/13/alpaca.html
22
- "alpaca": "figures/alpaca.png",
23
- // Image adapted from https://commons.wikimedia.org/wiki/File:Llama_on_Machu_Picchu.jpg
24
- "llama": "figures/llama.jpg",
25
- }
26
-
27
- // Store the question data in a mapping for later use.
28
- questionMapping = {};
29
- // Store the question ids in a mapping for later use.
30
- categoryMapping = {};
31
- // Store the number of questions for later use.
32
- questionsCount = 0;
33
-
34
-
35
- function text2Markdown(text) {
36
- // Normalize the text for markdown rendering.
37
- text = text.trim().replaceAll('\n\n', '\n').replaceAll('\n', '\n\n');
38
- return marked.parse(text);
39
- }
40
-
41
- function capitalizeFirstChar(str) {
42
- if (!str || str.length === 0) {
43
- return str;
44
- }
45
- return str.charAt(0).toUpperCase() + str.slice(1);
46
- }
47
-
48
- function updateQuestionSelect(question_id) {
49
- const select = document.getElementById('question-select');
50
- // Clear the question select.
51
- select.innerHTML = '';
52
- // Populate the question select.
53
- category = questionMapping[question_id].category;
54
- categoryMapping[category].forEach(question_id => {
55
- const question = questionMapping[question_id];
56
- const option = document.createElement('option');
57
- option.value = question_id;
58
- option.textContent = 'Q' + question_id.toString() + ': ' + question.question;
59
- select.appendChild(option);
60
- });
61
- select.value = question_id;
62
- }
63
-
64
- function updateModelSelect() {
65
- const select = document.getElementById('model-select');
66
- img_path = modelFigureMapping[select.value];
67
- document.getElementById('other-model-figure').src = img_path;
68
- }
69
-
70
- function populateModels(models) {
71
- const select = document.getElementById('model-select');
72
- models.forEach(model => {
73
- const option = document.createElement('option');
74
- option.value = model;
75
- option.textContent = modelNameMapping[model];
76
- select.appendChild(option);
77
- });
78
- updateModelSelect();
79
- }
80
-
81
- function populateQuestions(questions) {
82
- const category_select = document.getElementById('category-select');
83
-
84
- questionsCount = questions.length;
85
- questions.forEach(question => {
86
- const option = document.createElement('option');
87
- // Store the question data in a mapping for later use.
88
- questionMapping[question.id] = {
89
- category: question.category,
90
- question: question.question,
91
- answers: question.answers,
92
- evaluations: question.evaluations,
93
- scores: question.scores,
94
- };
95
- // Store the question id in the category mapping.
96
- if (question.category in categoryMapping) {
97
- categoryMapping[question.category].push(question.id);
98
- } else {
99
- categoryMapping[question.category] = [question.id];
100
- const category_option = document.createElement('option');
101
- category_option.value = question.category;
102
- category_option.textContent = capitalizeFirstChar(question.category);
103
- category_select.appendChild(category_option);
104
- }
105
- });
106
- // Set the default category.
107
- updateQuestionSelect(currentQuestionIndex);
108
- }
109
-
110
- function displayQuestion(index) {
111
- const question = questionMapping[index].question;
112
- document.getElementById('selected-question').innerHTML = text2Markdown('**Question:** ' + question);
113
- displayAnswers(index);
114
- }
115
-
116
- function displayAnswers(index) {
117
- const question = questionMapping[index];
118
- const otherModel = document.getElementById('model-select').value;
119
- // render the answers with markdown
120
- document.getElementById('other-model-answer').innerHTML = text2Markdown(question.answers[otherModel]);
121
- document.getElementById('our-model-answer').innerHTML = text2Markdown(question.answers.vicuna);
122
-
123
- // Display evaluation
124
- score = question.scores[otherModel];
125
- score_text = modelNameMapping[otherModel] + " " + score[0] + "/10, Vicuna-13b " + score[1] + "/10";
126
- document.getElementById('evaluation-header').textContent = "GPT-4 Evaluation" + " (Score: " + score_text + ")";
127
- document.getElementById('evaluation-result').innerHTML = text2Markdown(question.evaluations[otherModel]);
128
-
129
- // Update model names
130
- let assistant1_title = "Assistant #1"; // (" + modelNameMapping[otherModel] + ")";
131
- let assistant2_title = "Assistant #2 (Vicuna-13b, our model)";
132
- // Update scores/labels.
133
- let assistant1_score_label = score[0].toString() + '/10';
134
- let assistant2_score_label = score[1].toString() + '/10';
135
-
136
- const colorRed ='#fa9'; // '#eb978d';
137
- // const colorGreen = '#c9f2c9';
138
- const colorBlue = '#8ef'; // '#71dbf9';
139
- const colorYellow = '#fe7'; // '#fada57';
140
- let otherModelHeaderColor = '';
141
- let ourModelHeaderColor = '';
142
- // Update the winner.
143
- if (score[0] == score[1]) {
144
- assistant1_title = '🏆 ' + assistant1_title;
145
- assistant1_score_label = '🏆 ' + assistant1_score_label;
146
- assistant2_title = '🏆 ' + assistant2_title;
147
- assistant2_score_label = '🏆 ' + assistant2_score_label;
148
- otherModelHeaderColor = colorYellow;
149
- ourModelHeaderColor = colorYellow;
150
- } else if (score[0] > score[1]) {
151
- assistant1_title = '🏆 ' + assistant1_title;
152
- assistant1_score_label = '🏆 ' + assistant1_score_label;
153
- otherModelHeaderColor = colorBlue;
154
- ourModelHeaderColor = colorRed;
155
- } else if (score[0] < score[1]) {
156
- assistant2_title = '🏆 ' + assistant2_title;
157
- assistant2_score_label = '🏆 ' + assistant2_score_label;
158
- otherModelHeaderColor = colorRed;
159
- ourModelHeaderColor = colorBlue;
160
- }
161
-
162
- document.getElementById('other-model-header-bg').style.backgroundColor = otherModelHeaderColor;
163
- document.getElementById('our-model-header').style.backgroundColor = ourModelHeaderColor;
164
-
165
- document.getElementById('other-model-header').textContent = assistant1_title;
166
- document.getElementById('our-model-header').textContent = assistant2_title;
167
-
168
- document.getElementById('other-score-label').textContent = assistant1_score_label;
169
- document.getElementById('our-score-label').textContent = assistant2_score_label;
170
-
171
- // Update expand buttons visibility for both cards after displaying answers
172
- // Reset the expanded state and update expand buttons visibility for both cards after displaying answers
173
- document.querySelectorAll('.expandable-card').forEach(card => {
174
- card.classList.remove('expanded');
175
- updateExpandButtonVisibility(card);
176
- const expandBtn = card.querySelector('.expand-btn');
177
- expandBtn.innerHTML = '<i class="material-icons" style="pointer-events: none">keyboard_arrow_down</i> Show more'; // .textContent = 'Show more';
178
- });
179
- }
180
-
181
- document.getElementById('question-select').addEventListener('change', e => {
182
- currentQuestionIndex = parseInt(e.target.value);
183
- displayQuestion(currentQuestionIndex);
184
- });
185
-
186
- document.getElementById('category-select').addEventListener('change', e => {
187
- let currentCategory = e.target.value;
188
- const questionIds = categoryMapping[currentCategory];
189
- currentQuestionIndex = questionIds[0];
190
- updateQuestionSelect(currentQuestionIndex);
191
- displayQuestion(currentQuestionIndex);
192
- });
193
-
194
- // Update expand buttons whenever the model is changed
195
- document.getElementById('model-select').addEventListener('change', () => {
196
- displayAnswers(currentQuestionIndex);
197
- document.querySelectorAll('.expandable-card').forEach(card => {
198
- updateExpandButtonVisibility(card);
199
- });
200
- updateModelSelect();
201
- });
202
-
203
- function switchQuestionAndCategory() {
204
- document.getElementById('question-select').value = currentQuestionIndex;
205
- old_category = document.getElementById('category-select').value;
206
- new_category = questionMapping[currentQuestionIndex].category;
207
- if (old_category != new_category) {
208
- document.getElementById('category-select').value = new_category;
209
- updateQuestionSelect(currentQuestionIndex);
210
- }
211
- displayQuestion(currentQuestionIndex);
212
- }
213
-
214
- document.getElementById('prev-question').addEventListener('click', () => {
215
- // Question index starts from 1.
216
- currentQuestionIndex = Math.max(1, currentQuestionIndex - 1);
217
- switchQuestionAndCategory();
218
- });
219
-
220
- document.getElementById('next-question').addEventListener('click', () => {
221
- // Question index starts from 1.
222
- currentQuestionIndex = Math.min(questionsCount, currentQuestionIndex + 1);
223
- switchQuestionAndCategory();
224
- });
225
-
226
- function updateExpandButtonVisibility(card) {
227
- const cardTextContainer = card.querySelector('.card-text-container');
228
- const expandBtn = card.querySelector('.expand-btn');
229
- if (cardTextContainer.scrollHeight > cardTextContainer.offsetHeight) {
230
- expandBtn.style.display = 'flex';
231
- } else {
232
- expandBtn.style.display = 'none';
233
- card.classList.add('expanded');
234
- }
235
- }
236
-
237
- document.querySelectorAll('.expand-btn').forEach(btn => {
238
- btn.addEventListener('click', e => {
239
- const card = e.target.closest('.expandable-card');
240
- card.classList.toggle('expanded');
241
- const more = '<i class="material-icons" style="pointer-events: none">keyboard_arrow_down</i> Show more';
242
- const less = '<i class="material-icons" style="pointer-events: none">keyboard_arrow_up</i> Show less';
243
- e.target.innerHTML = card.classList.contains('expanded') ? less : more;
244
- });
245
- });
 
LLAVA_Biovil/llava/eval/webpage/styles.css DELETED
@@ -1,105 +0,0 @@
1
- body {
2
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
3
- background-color: #f8f9fa;
4
- }
5
-
6
- .navbar-dark .navbar-nav .nav-link {
7
- color: #f1cf68;
8
- font-size: 1.1rem;
9
- padding: 0.5rem 0.6rem;
10
- }
11
-
12
- .card-header {
13
- font-weight: bold;
14
- }
15
-
16
- .card {
17
- box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
18
- transition: 0.3s;
19
- }
20
-
21
- .card:hover {
22
- box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
23
- }
24
-
25
- button {
26
- transition: background-color 0.3s;
27
- }
28
-
29
- button:hover {
30
- background-color: #007bff;
31
- }
32
-
33
- @media (max-width: 767px) {
34
- .form-row .form-group {
35
- margin-bottom: 10px;
36
- }
37
- }
38
-
39
- /* Extra styles */
40
-
41
- .expandable-card .card-text-container {
42
- max-height: 200px;
43
- overflow-y: hidden;
44
- position: relative;
45
- }
46
-
47
- .expandable-card.expanded .card-text-container {
48
- max-height: none;
49
- }
50
-
51
- .expand-btn {
52
- position: relative;
53
- display: none;
54
- background-color: rgba(255, 255, 255, 0.8);
55
- color: #510c75;
56
- border-color: transparent;
57
- }
58
-
59
- .expand-btn:hover {
60
- background-color: rgba(200, 200, 200, 0.8);
61
- text-decoration: none;
62
- border-color: transparent;
63
- color: #510c75;
64
- }
65
-
66
- .expand-btn:focus {
67
- outline: none;
68
- text-decoration: none;
69
- }
70
-
71
- .expandable-card:not(.expanded) .card-text-container:after {
72
- content: "";
73
- position: absolute;
74
- bottom: 0;
75
- left: 0;
76
- width: 100%;
77
- height: 90px;
78
- background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1));
79
- }
80
-
81
- .expandable-card:not(.expanded) .expand-btn {
82
- margin-top: -40px;
83
- }
84
-
85
- .card-body {
86
- padding-bottom: 5px;
87
- }
88
-
89
- .vertical-flex-layout {
90
- justify-content: center;
91
- align-items: center;
92
- height: 100%;
93
- display: flex;
94
- flex-direction: column;
95
- gap: 5px;
96
- }
97
-
98
- .figure-img {
99
- max-width: 100%;
100
- height: auto;
101
- }
102
-
103
- .adjustable-font-size {
104
- font-size: calc(0.5rem + 2vw);
105
- }
 
LLAVA_Biovil/llava/mm_utils.py CHANGED
@@ -5,7 +5,7 @@
5
 
6
  import torch
7
  from transformers import StoppingCriteria
8
- from llava.constants import IMAGE_TOKEN_INDEX
9
 
10
 
11
  def load_image_from_base64(image):
 
5
 
6
  import torch
7
  from transformers import StoppingCriteria
8
+ from LLAVA_Biovil.llava.constants import IMAGE_TOKEN_INDEX
9
 
10
 
11
  def load_image_from_base64(image):
LLAVA_Biovil/llava/model/apply_delta.py CHANGED
@@ -7,7 +7,7 @@
 import torch
 from tqdm import tqdm
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from llava import LlavaLlamaForCausalLM
+from LLAVA_Biovil.llava import LlavaLlamaForCausalLM
 
 
 def apply_delta(base_model_path, target_model_path, delta_path):
LLAVA_Biovil/llava/model/builder.py CHANGED
@@ -20,17 +20,17 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
 import torch
 
-from LLAVA.biovil_t.model import ImageModel
-from LLAVA.biovil_t.pretrained import _download_biovil_t_image_model_weights
-from LLAVA.biovil_t.types import ImageEncoderType
-from LLAVA.llava.model.multimodal_projector.builder import build_vision_projector
+from LLAVA_Biovil.biovil_t.model import ImageModel
+from LLAVA_Biovil.biovil_t.pretrained import _download_biovil_t_image_model_weights
+from LLAVA_Biovil.biovil_t.types import ImageEncoderType
+from LLAVA_Biovil.llava.model.multimodal_projector.builder import build_vision_projector
 
 try:
-    from LLAVA.llava.model import *
-    from LLAVA.llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+    from LLAVA_Biovil.llava.model import *
+    from LLAVA_Biovil.llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 except:
-    from llava.model import *
-    from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+    from LLAVA_Biovil.llava.model import *
+    from LLAVA_Biovil.llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 
 
 def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", **kwargs):
LLAVA_Biovil/llava/model/consolidate.py CHANGED
@@ -6,7 +6,7 @@
 
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from LLAV.llava.model.utils import auto_upgrade
+from LLAVA_Biovil.llava.model.utils import auto_upgrade
 
 
 def consolidate_ckpt(src_path, dst_path):
LLAVA_Biovil/llava/model/language_model/llava_llama.py CHANGED
@@ -25,7 +25,7 @@
 
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
-from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+from LLAVA_Biovil.llava.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
 
 
 class LlavaConfig(LlamaConfig):
LLAVA_Biovil/llava/model/language_model/llava_mpt.py CHANGED
@@ -23,8 +23,8 @@
 from transformers import AutoConfig, AutoModelForCausalLM
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
-from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
-from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+from LLAVA_Biovil.llava.model.language_model.mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
+from LLAVA_Biovil.llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
 
 
 class LlavaMPTConfig(MPTConfig):
LLAVA_Biovil/llava/model/llava_arch.py CHANGED
@@ -15,13 +15,13 @@
 
 import torch
 
-from biovil_t.model import ImageModel
-from biovil_t.pretrained import _download_biovil_t_image_model_weights
-from biovil_t.types import ImageEncoderType
-from .multimodal_encoder.builder import build_vision_tower
-from .multimodal_projector.builder import build_vision_projector, build_image_pooler
+from LLAVA_Biovil.biovil_t.model import ImageModel
+from LLAVA_Biovil.biovil_t.pretrained import _download_biovil_t_image_model_weights
+from LLAVA_Biovil.biovil_t.types import ImageEncoderType
+from LLAVA_Biovil.llava.multimodal_encoder.builder import build_vision_tower
+from LLAVA_Biovil.llava.multimodal_projector.builder import build_vision_projector, build_image_pooler
 
-from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from LLAVA_Biovil.llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 
 
 
LLAVA_Biovil/llava/serve/__init__.py DELETED
File without changes
LLAVA_Biovil/llava/serve/cli.py DELETED
@@ -1,122 +0,0 @@
1
- import argparse
2
- import torch
3
-
4
- from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5
- from LLAV.llava.conversation import conv_templates, SeparatorStyle
6
- from LLAV.llava.model.builder import load_pretrained_model
7
- from LLAV.llava.utils import disable_torch_init
8
- from LLAV.llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
9
-
10
- import requests
11
- from PIL import Image
12
- from io import BytesIO
13
- from transformers import TextStreamer
14
-
15
-
16
- def load_image(image_file):
17
- if image_file.startswith('http://') or image_file.startswith('https://'):
18
- response = requests.get(image_file)
19
- image = Image.open(BytesIO(response.content)).convert('RGB')
20
- else:
21
- image = Image.open(image_file).convert('RGB')
22
- return image
23
-
24
-
25
- def main(args):
26
- # Model
27
- disable_torch_init()
28
-
29
- model_name = get_model_name_from_path(args.model_path)
30
- tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
31
-
32
- if 'llama-2' in model_name.lower():
33
- conv_mode = "llava_llama_2"
34
- elif "v1" in model_name.lower():
35
- conv_mode = "llava_v1"
36
- elif "mpt" in model_name.lower():
37
- conv_mode = "mpt"
38
- else:
39
- conv_mode = "llava_v0"
40
-
41
- if args.conv_mode is not None and conv_mode != args.conv_mode:
42
- print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
43
- else:
44
- args.conv_mode = conv_mode
45
-
46
- conv = conv_templates[args.conv_mode].copy()
47
- if "mpt" in model_name.lower():
48
- roles = ('user', 'assistant')
49
- else:
50
- roles = conv.roles
51
-
52
- image = load_image(args.image_file)
53
- # Similar operation in model_worker.py
54
- image_tensor = process_images([image], image_processor, model.config)
55
- if type(image_tensor) is list:
56
- image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
57
- else:
58
- image_tensor = image_tensor.to(model.device, dtype=torch.float16)
59
-
60
- while True:
61
- try:
62
- inp = input(f"{roles[0]}: ")
63
- except EOFError:
64
- inp = ""
65
- if not inp:
66
- print("exit...")
67
- break
68
-
69
- print(f"{roles[1]}: ", end="")
70
-
71
- if image is not None:
72
- # first message
73
- if model.config.mm_use_im_start_end:
74
- inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
75
- else:
76
- inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
77
- conv.append_message(conv.roles[0], inp)
78
- image = None
79
- else:
80
- # later messages
81
- conv.append_message(conv.roles[0], inp)
82
- conv.append_message(conv.roles[1], None)
83
- prompt = conv.get_prompt()
84
-
85
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
86
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
87
- keywords = [stop_str]
88
- stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
89
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
90
-
91
- with torch.inference_mode():
92
- output_ids = model.generate(
93
- input_ids,
94
- images=image_tensor,
95
- do_sample=True if args.temperature > 0 else False,
96
- temperature=args.temperature,
97
- max_new_tokens=args.max_new_tokens,
98
- streamer=streamer,
99
- use_cache=True,
100
- stopping_criteria=[stopping_criteria])
101
-
102
- outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
103
- conv.messages[-1][-1] = outputs
104
-
105
- if args.debug:
106
- print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
107
-
108
-
109
- if __name__ == "__main__":
110
- parser = argparse.ArgumentParser()
111
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
112
- parser.add_argument("--model-base", type=str, default=None)
113
- parser.add_argument("--image-file", type=str, required=True)
114
- parser.add_argument("--device", type=str, default="cuda")
115
- parser.add_argument("--conv-mode", type=str, default=None)
116
- parser.add_argument("--temperature", type=float, default=0.2)
117
- parser.add_argument("--max-new-tokens", type=int, default=512)
118
- parser.add_argument("--load-8bit", action="store_true")
119
- parser.add_argument("--load-4bit", action="store_true")
120
- parser.add_argument("--debug", action="store_true")
121
- args = parser.parse_args()
122
- main(args)
 
LLAVA_Biovil/llava/serve/controller.py DELETED
@@ -1,296 +0,0 @@
1
- """
2
- A controller manages distributed workers.
3
- It sends worker addresses to clients.
4
- """
5
- import argparse
6
- import dataclasses
7
- from enum import Enum, auto
8
- import json
9
- import time
10
- from typing import List
11
- import threading
12
-
13
- from fastapi import FastAPI, Request
14
- from fastapi.responses import StreamingResponse
15
- import numpy as np
16
- import requests
17
- import uvicorn
18
-
19
- from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
20
- from llava.utils import build_logger, server_error_msg
21
-
22
-
23
- logger = build_logger("controller", "controller.log")
24
-
25
-
26
- class DispatchMethod(Enum):
27
- LOTTERY = auto()
28
- SHORTEST_QUEUE = auto()
29
-
30
- @classmethod
31
- def from_str(cls, name):
32
- if name == "lottery":
33
- return cls.LOTTERY
34
- elif name == "shortest_queue":
35
- return cls.SHORTEST_QUEUE
36
- else:
37
- raise ValueError(f"Invalid dispatch method")
38
-
39
-
40
- @dataclasses.dataclass
41
- class WorkerInfo:
42
- model_names: List[str]
43
- speed: int
44
- queue_length: int
45
- check_heart_beat: bool
46
- last_heart_beat: str
47
-
48
-
49
- def heart_beat_controller(controller):
50
- while True:
51
- time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
52
- controller.remove_stable_workers_by_expiration()
53
-
54
-
55
- class Controller:
56
- def __init__(self, dispatch_method: str):
57
- # Dict[str -> WorkerInfo]
58
- self.worker_info = {}
59
- self.dispatch_method = DispatchMethod.from_str(dispatch_method)
60
-
61
- self.heart_beat_thread = threading.Thread(
62
- target=heart_beat_controller, args=(self,))
63
- self.heart_beat_thread.start()
64
-
65
- logger.info("Init controller")
66
-
67
- def register_worker(self, worker_name: str, check_heart_beat: bool,
68
- worker_status: dict):
69
- if worker_name not in self.worker_info:
70
- logger.info(f"Register a new worker: {worker_name}")
71
- else:
72
- logger.info(f"Register an existing worker: {worker_name}")
73
-
74
- if not worker_status:
75
- worker_status = self.get_worker_status(worker_name)
76
- if not worker_status:
77
- return False
78
-
79
- self.worker_info[worker_name] = WorkerInfo(
80
- worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
81
- check_heart_beat, time.time())
82
-
83
- logger.info(f"Register done: {worker_name}, {worker_status}")
84
- return True
85
-
86
- def get_worker_status(self, worker_name: str):
87
- try:
88
- r = requests.post(worker_name + "/worker_get_status", timeout=5)
89
- except requests.exceptions.RequestException as e:
90
- logger.error(f"Get status fails: {worker_name}, {e}")
91
- return None
92
-
93
- if r.status_code != 200:
94
- logger.error(f"Get status fails: {worker_name}, {r}")
95
- return None
96
-
97
- return r.json()
98
-
99
- def remove_worker(self, worker_name: str):
100
- del self.worker_info[worker_name]
101
-
102
- def refresh_all_workers(self):
103
- old_info = dict(self.worker_info)
104
- self.worker_info = {}
105
-
106
- for w_name, w_info in old_info.items():
107
- if not self.register_worker(w_name, w_info.check_heart_beat, None):
108
- logger.info(f"Remove stale worker: {w_name}")
109
-
110
- def list_models(self):
111
- model_names = set()
112
-
113
- for w_name, w_info in self.worker_info.items():
114
- model_names.update(w_info.model_names)
115
-
116
- return list(model_names)
117
-
118
- def get_worker_address(self, model_name: str):
119
- if self.dispatch_method == DispatchMethod.LOTTERY:
120
- worker_names = []
121
- worker_speeds = []
122
- for w_name, w_info in self.worker_info.items():
123
- if model_name in w_info.model_names:
124
- worker_names.append(w_name)
125
- worker_speeds.append(w_info.speed)
126
- worker_speeds = np.array(worker_speeds, dtype=np.float32)
127
- norm = np.sum(worker_speeds)
128
- if norm < 1e-4:
129
- return ""
130
- worker_speeds = worker_speeds / norm
131
- if True: # Directly return address
132
- pt = np.random.choice(np.arange(len(worker_names)),
133
- p=worker_speeds)
134
- worker_name = worker_names[pt]
135
- return worker_name
136
-
137
- # Check status before returning
138
- while True:
139
- pt = np.random.choice(np.arange(len(worker_names)),
140
- p=worker_speeds)
141
- worker_name = worker_names[pt]
142
-
143
- if self.get_worker_status(worker_name):
144
- break
145
- else:
146
- self.remove_worker(worker_name)
147
- worker_speeds[pt] = 0
148
- norm = np.sum(worker_speeds)
149
- if norm < 1e-4:
150
- return ""
151
- worker_speeds = worker_speeds / norm
152
- continue
153
- return worker_name
154
- elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
155
- worker_names = []
156
- worker_qlen = []
157
- for w_name, w_info in self.worker_info.items():
158
- if model_name in w_info.model_names:
159
- worker_names.append(w_name)
160
- worker_qlen.append(w_info.queue_length / w_info.speed)
161
- if len(worker_names) == 0:
162
- return ""
163
- min_index = np.argmin(worker_qlen)
164
- w_name = worker_names[min_index]
165
- self.worker_info[w_name].queue_length += 1
166
- logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
167
- return w_name
168
- else:
169
- raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
170
-
171
- def receive_heart_beat(self, worker_name: str, queue_length: int):
172
- if worker_name not in self.worker_info:
173
- logger.info(f"Receive unknown heart beat. {worker_name}")
174
- return False
175
-
176
- self.worker_info[worker_name].queue_length = queue_length
177
- self.worker_info[worker_name].last_heart_beat = time.time()
178
- logger.info(f"Receive heart beat. {worker_name}")
179
- return True
180
-
181
- def remove_stable_workers_by_expiration(self):
182
- expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
183
- to_delete = []
184
- for worker_name, w_info in self.worker_info.items():
185
- if w_info.check_heart_beat and w_info.last_heart_beat < expire:
186
- to_delete.append(worker_name)
187
-
188
- for worker_name in to_delete:
189
- self.remove_worker(worker_name)
190
-
191
- def worker_api_generate_stream(self, params):
192
- worker_addr = self.get_worker_address(params["model"])
193
- if not worker_addr:
194
- logger.info(f"no worker: {params['model']}")
195
- ret = {
196
- "text": server_error_msg,
197
- "error_code": 2,
198
- }
199
- yield json.dumps(ret).encode() + b"\0"
200
-
201
- try:
202
- response = requests.post(worker_addr + "/worker_generate_stream",
203
- json=params, stream=True, timeout=5)
204
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
205
- if chunk:
206
- yield chunk + b"\0"
207
- except requests.exceptions.RequestException as e:
208
- logger.info(f"worker timeout: {worker_addr}")
209
- ret = {
210
- "text": server_error_msg,
211
- "error_code": 3,
212
- }
213
- yield json.dumps(ret).encode() + b"\0"
214
-
215
-
216
- # Let the controller act as a worker to achieve hierarchical
217
- # management. This can be used to connect isolated sub networks.
218
- def worker_api_get_status(self):
219
- model_names = set()
220
- speed = 0
221
- queue_length = 0
222
-
223
- for w_name in self.worker_info:
224
- worker_status = self.get_worker_status(w_name)
225
- if worker_status is not None:
226
- model_names.update(worker_status["model_names"])
227
- speed += worker_status["speed"]
228
- queue_length += worker_status["queue_length"]
229
-
230
- return {
231
- "model_names": list(model_names),
232
- "speed": speed,
233
- "queue_length": queue_length,
234
- }
235
-
236
-
237
- app = FastAPI()
238
-
239
-
240
- @app.post("/register_worker")
241
- async def register_worker(request: Request):
242
- data = await request.json()
243
- controller.register_worker(
244
- data["worker_name"], data["check_heart_beat"],
245
- data.get("worker_status", None))
246
-
247
-
248
- @app.post("/refresh_all_workers")
249
- async def refresh_all_workers():
250
- models = controller.refresh_all_workers()
251
-
252
-
253
- @app.post("/list_models")
254
- async def list_models():
255
- models = controller.list_models()
256
- return {"models": models}
257
-
258
-
259
- @app.post("/get_worker_address")
260
- async def get_worker_address(request: Request):
261
- data = await request.json()
262
- addr = controller.get_worker_address(data["model"])
263
- return {"address": addr}
264
-
265
-
266
- @app.post("/receive_heart_beat")
267
- async def receive_heart_beat(request: Request):
268
- data = await request.json()
269
- exist = controller.receive_heart_beat(
270
- data["worker_name"], data["queue_length"])
271
- return {"exist": exist}
272
-
273
-
274
- @app.post("/worker_generate_stream")
275
- async def worker_api_generate_stream(request: Request):
276
- params = await request.json()
277
- generator = controller.worker_api_generate_stream(params)
278
- return StreamingResponse(generator)
279
-
280
-
281
- @app.post("/worker_get_status")
282
- async def worker_api_get_status(request: Request):
283
- return controller.worker_api_get_status()
284
-
285
-
286
- if __name__ == "__main__":
287
- parser = argparse.ArgumentParser()
288
- parser.add_argument("--host", type=str, default="localhost")
289
- parser.add_argument("--port", type=int, default=21001)
290
- parser.add_argument("--dispatch-method", type=str, choices=[
291
- "lottery", "shortest_queue"], default="shortest_queue")
292
- args = parser.parse_args()
293
- logger.info(f"args: {args}")
294
-
295
- controller = Controller(args.dispatch_method)
296
- uvicorn.run(app, host=args.host, port=args.port, log_level="info")
 
LLAVA_Biovil/llava/serve/examples/extreme_ironing.jpg DELETED
Binary file (62.6 kB)
 
LLAVA_Biovil/llava/serve/examples/waterview.jpg DELETED
Binary file (95.5 kB)
 
LLAVA_Biovil/llava/serve/gradio_web_server.py DELETED
@@ -1,470 +0,0 @@
1
- import argparse
2
- import datetime
3
- import json
4
- import os
5
- import time
6
-
7
- import gradio as gr
8
- import requests
9
-
10
- from llava.conversation import (default_conversation, conv_templates,
11
- SeparatorStyle)
12
- from llava.constants import LOGDIR
13
- from llava.utils import (build_logger, server_error_msg,
14
- violates_moderation, moderation_msg)
15
- import hashlib
16
-
17
-
18
- logger = build_logger("gradio_web_server", "gradio_web_server.log")
19
-
20
- headers = {"User-Agent": "LLaVA Client"}
21
-
22
- no_change_btn = gr.Button.update()
23
- enable_btn = gr.Button.update(interactive=True)
24
- disable_btn = gr.Button.update(interactive=False)
25
-
26
- priority = {
27
- "vicuna-13b": "aaaaaaa",
28
- "koala-13b": "aaaaaab",
29
- }
30
-
31
-
32
- def get_conv_log_filename():
33
- t = datetime.datetime.now()
34
- name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
35
- return name
36
-
37
-
38
- def get_model_list():
39
- ret = requests.post(args.controller_url + "/refresh_all_workers")
40
- assert ret.status_code == 200
41
- ret = requests.post(args.controller_url + "/list_models")
42
- models = ret.json()["models"]
43
- models.sort(key=lambda x: priority.get(x, x))
44
- logger.info(f"Models: {models}")
45
- return models
46
-
47
-
48
- get_window_url_params = """
49
- function() {
50
- const params = new URLSearchParams(window.location.search);
51
- url_params = Object.fromEntries(params);
52
- console.log(url_params);
53
- return url_params;
54
- }
55
- """
56
-
57
-
58
- def load_demo(url_params, request: gr.Request):
59
- logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
60
-
61
- dropdown_update = gr.Dropdown.update(visible=True)
62
- if "model" in url_params:
63
- model = url_params["model"]
64
- if model in models:
65
- dropdown_update = gr.Dropdown.update(
66
- value=model, visible=True)
67
-
68
- state = default_conversation.copy()
69
- return state, dropdown_update
70
-
71
-
72
- def load_demo_refresh_model_list(request: gr.Request):
73
- logger.info(f"load_demo. ip: {request.client.host}")
74
- models = get_model_list()
75
- state = default_conversation.copy()
76
- dropdown_update = gr.Dropdown.update(
77
- choices=models,
78
- value=models[0] if len(models) > 0 else ""
79
- )
80
- return state, dropdown_update
81
-
82
-
83
- def vote_last_response(state, vote_type, model_selector, request: gr.Request):
84
- with open(get_conv_log_filename(), "a") as fout:
85
- data = {
86
- "tstamp": round(time.time(), 4),
87
- "type": vote_type,
88
- "model": model_selector,
89
- "state": state.dict(),
90
- "ip": request.client.host,
91
- }
92
- fout.write(json.dumps(data) + "\n")
93
-
94
-
95
- def upvote_last_response(state, model_selector, request: gr.Request):
96
- logger.info(f"upvote. ip: {request.client.host}")
97
- vote_last_response(state, "upvote", model_selector, request)
98
- return ("",) + (disable_btn,) * 3
99
-
100
-
101
- def downvote_last_response(state, model_selector, request: gr.Request):
102
- logger.info(f"downvote. ip: {request.client.host}")
103
- vote_last_response(state, "downvote", model_selector, request)
104
- return ("",) + (disable_btn,) * 3
105
-
106
-
107
- def flag_last_response(state, model_selector, request: gr.Request):
108
- logger.info(f"flag. ip: {request.client.host}")
109
- vote_last_response(state, "flag", model_selector, request)
110
- return ("",) + (disable_btn,) * 3
111
-
112
-
113
- def regenerate(state, image_process_mode, request: gr.Request):
114
- logger.info(f"regenerate. ip: {request.client.host}")
115
- state.messages[-1][-1] = None
116
- prev_human_msg = state.messages[-2]
117
- if type(prev_human_msg[1]) in (tuple, list):
118
- prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
119
- state.skip_next = False
120
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
121
-
122
-
123
- def clear_history(request: gr.Request):
124
- logger.info(f"clear_history. ip: {request.client.host}")
125
- state = default_conversation.copy()
126
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
127
-
128
-
129
- def add_text(state, text, image, image_process_mode, request: gr.Request):
130
- logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
131
- if len(text) <= 0 and image is None:
132
- state.skip_next = True
133
- return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
134
- if args.moderate:
135
- flagged = violates_moderation(text)
136
- if flagged:
137
- state.skip_next = True
138
- return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
139
- no_change_btn,) * 5
140
-
141
- text = text[:1536] # Hard cut-off
142
- if image is not None:
143
- text = text[:1200] # Hard cut-off for images
144
- if '<image>' not in text:
145
- # text = '<Image><image></Image>' + text
146
- text = text + '\n<image>'
147
- text = (text, image, image_process_mode)
148
- if len(state.get_images(return_pil=True)) > 0:
149
- state = default_conversation.copy()
150
- state.append_message(state.roles[0], text)
151
- state.append_message(state.roles[1], None)
152
- state.skip_next = False
153
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
154
-
155
-
156
- def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
157
- logger.info(f"http_bot. ip: {request.client.host}")
158
- start_tstamp = time.time()
159
- model_name = model_selector
160
-
161
- if state.skip_next:
162
- # This generate call is skipped due to invalid inputs
163
- yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
164
- return
165
-
166
- if len(state.messages) == state.offset + 2:
167
- # First round of conversation
168
- if "llava" in model_name.lower():
169
- if 'llama-2' in model_name.lower():
170
- template_name = "llava_llama_2"
171
- elif "v1" in model_name.lower():
172
- if 'mmtag' in model_name.lower():
173
- template_name = "v1_mmtag"
174
- elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
175
- template_name = "v1_mmtag"
176
- else:
177
- template_name = "llava_v1"
178
- elif "mpt" in model_name.lower():
179
- template_name = "mpt"
180
- else:
181
- if 'mmtag' in model_name.lower():
182
- template_name = "v0_mmtag"
183
- elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
184
- template_name = "v0_mmtag"
185
- else:
186
- template_name = "llava_v0"
187
- elif "mpt" in model_name:
188
- template_name = "mpt_text"
189
- elif "llama-2" in model_name:
190
- template_name = "llama_2"
191
- else:
192
- template_name = "vicuna_v1"
193
- new_state = conv_templates[template_name].copy()
194
- new_state.append_message(new_state.roles[0], state.messages[-2][1])
195
- new_state.append_message(new_state.roles[1], None)
196
- state = new_state
197
-
198
- # Query worker address
199
- controller_url = args.controller_url
200
- ret = requests.post(controller_url + "/get_worker_address",
201
- json={"model": model_name})
202
- worker_addr = ret.json()["address"]
203
- logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
204
-
205
- # No available worker
206
- if worker_addr == "":
207
- state.messages[-1][-1] = server_error_msg
208
- yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
209
- return
210
-
211
- # Construct prompt
212
- prompt = state.get_prompt()
213
-
214
- all_images = state.get_images(return_pil=True)
215
- all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
216
- for image, hash in zip(all_images, all_image_hash):
217
- t = datetime.datetime.now()
218
- filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
219
- if not os.path.isfile(filename):
220
- os.makedirs(os.path.dirname(filename), exist_ok=True)
221
- image.save(filename)
222
-
223
- # Make requests
224
- pload = {
225
- "model": model_name,
226
- "prompt": prompt,
227
- "temperature": float(temperature),
228
- "top_p": float(top_p),
229
- "max_new_tokens": min(int(max_new_tokens), 1536),
230
- "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
231
- "images": f'List of {len(state.get_images())} images: {all_image_hash}',
232
- }
233
- logger.info(f"==== request ====\n{pload}")
234
-
235
- pload['images'] = state.get_images()
236
-
237
- state.messages[-1][-1] = "▌"
238
- yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
239
-
240
- try:
241
- # Stream output
242
- response = requests.post(worker_addr + "/worker_generate_stream",
243
- headers=headers, json=pload, stream=True, timeout=10)
244
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
245
- if chunk:
246
- data = json.loads(chunk.decode())
247
- if data["error_code"] == 0:
248
- output = data["text"][len(prompt):].strip()
249
- state.messages[-1][-1] = output + "▌"
250
- yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
251
- else:
252
- output = data["text"] + f" (error_code: {data['error_code']})"
253
- state.messages[-1][-1] = output
254
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
255
- return
256
- time.sleep(0.03)
257
- except requests.exceptions.RequestException as e:
258
- state.messages[-1][-1] = server_error_msg
259
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
260
- return
261
-
262
- state.messages[-1][-1] = state.messages[-1][-1][:-1]
263
- yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
264
-
265
- finish_tstamp = time.time()
266
- logger.info(f"{output}")
267
-
268
- with open(get_conv_log_filename(), "a") as fout:
269
- data = {
270
- "tstamp": round(finish_tstamp, 4),
271
- "type": "chat",
272
- "model": model_name,
273
- "start": round(start_tstamp, 4),
274
- "finish": round(finish_tstamp, 4),
275
- "state": state.dict(),
276
- "images": all_image_hash,
277
- "ip": request.client.host,
278
- }
279
- fout.write(json.dumps(data) + "\n")
280
-
281
- title_markdown = ("""
282
- # 🌋 LLaVA: Large Language and Vision Assistant
283
- [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)]
284
- """)
285
-
286
- tos_markdown = ("""
287
- ### Terms of use
288
- By using this service, users are required to agree to the following terms:
289
- The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
290
- Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
291
- For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
292
- """)
293
-
294
-
295
- learn_more_markdown = ("""
296
- ### License
297
- The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
298
- """)
299
-
300
- block_css = """
301
-
302
- #buttons button {
303
- min-width: min(120px,100%);
304
- }
305
-
306
- """
307
-
308
- def build_demo(embed_mode):
309
- textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
310
- with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
311
- state = gr.State()
312
-
313
- if not embed_mode:
314
- gr.Markdown(title_markdown)
315
-
316
- with gr.Row():
317
- with gr.Column(scale=3):
318
- with gr.Row(elem_id="model_selector_row"):
319
- model_selector = gr.Dropdown(
320
- choices=models,
321
- value=models[0] if len(models) > 0 else "",
322
- interactive=True,
323
- show_label=False,
324
- container=False)
325
-
326
- imagebox = gr.Image(type="pil")
327
- image_process_mode = gr.Radio(
328
- ["Crop", "Resize", "Pad", "Default"],
329
- value="Default",
330
- label="Preprocess for non-square image", visible=False)
331
-
332
- cur_dir = os.path.dirname(os.path.abspath(__file__))
333
- gr.Examples(examples=[
334
- [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
335
- [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
336
- ], inputs=[imagebox, textbox])
337
-
338
- with gr.Accordion("Parameters", open=False) as parameter_row:
339
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
340
- top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
341
- max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
342
-
343
- with gr.Column(scale=8):
344
- chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
345
- with gr.Row():
346
- with gr.Column(scale=8):
347
- textbox.render()
348
- with gr.Column(scale=1, min_width=50):
349
- submit_btn = gr.Button(value="Send", variant="primary")
350
- with gr.Row(elem_id="buttons") as button_row:
351
- upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
352
- downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
353
- flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
354
- #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
355
- regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
356
- clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
357
-
358
- if not embed_mode:
359
- gr.Markdown(tos_markdown)
360
- gr.Markdown(learn_more_markdown)
361
- url_params = gr.JSON(visible=False)
362
-
363
- # Register listeners
364
- btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
365
- upvote_btn.click(
366
- upvote_last_response,
367
- [state, model_selector],
368
- [textbox, upvote_btn, downvote_btn, flag_btn],
369
- queue=False
370
- )
371
- downvote_btn.click(
372
- downvote_last_response,
373
- [state, model_selector],
374
- [textbox, upvote_btn, downvote_btn, flag_btn],
375
- queue=False
376
- )
377
- flag_btn.click(
378
- flag_last_response,
379
- [state, model_selector],
380
- [textbox, upvote_btn, downvote_btn, flag_btn],
381
- queue=False
382
- )
383
-
384
- regenerate_btn.click(
385
- regenerate,
386
- [state, image_process_mode],
387
- [state, chatbot, textbox, imagebox] + btn_list,
388
- queue=False
389
- ).then(
390
- http_bot,
391
- [state, model_selector, temperature, top_p, max_output_tokens],
392
- [state, chatbot] + btn_list
393
- )
394
-
395
- clear_btn.click(
396
- clear_history,
397
- None,
398
- [state, chatbot, textbox, imagebox] + btn_list,
399
- queue=False
400
- )
401
-
402
- textbox.submit(
403
- add_text,
404
- [state, textbox, imagebox, image_process_mode],
405
- [state, chatbot, textbox, imagebox] + btn_list,
406
- queue=False
407
- ).then(
408
- http_bot,
409
- [state, model_selector, temperature, top_p, max_output_tokens],
410
- [state, chatbot] + btn_list
411
- )
412
-
413
- submit_btn.click(
414
- add_text,
415
- [state, textbox, imagebox, image_process_mode],
416
- [state, chatbot, textbox, imagebox] + btn_list,
417
- queue=False
418
- ).then(
419
- http_bot,
420
- [state, model_selector, temperature, top_p, max_output_tokens],
421
- [state, chatbot] + btn_list
422
- )
423
-
424
- if args.model_list_mode == "once":
425
- demo.load(
426
- load_demo,
427
- [url_params],
428
- [state, model_selector],
429
- _js=get_window_url_params,
430
- queue=False
431
- )
432
- elif args.model_list_mode == "reload":
433
- demo.load(
434
- load_demo_refresh_model_list,
435
- None,
436
- [state, model_selector],
437
- queue=False
438
- )
439
- else:
440
- raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
441
-
442
- return demo
443
-
444
-
445
- if __name__ == "__main__":
446
- parser = argparse.ArgumentParser()
447
- parser.add_argument("--host", type=str, default="0.0.0.0")
448
- parser.add_argument("--port", type=int)
449
- parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
450
- parser.add_argument("--concurrency-count", type=int, default=10)
451
- parser.add_argument("--model-list-mode", type=str, default="once",
452
- choices=["once", "reload"])
453
- parser.add_argument("--share", action="store_true")
454
- parser.add_argument("--moderate", action="store_true")
455
- parser.add_argument("--embed", action="store_true")
456
- args = parser.parse_args()
457
- logger.info(f"args: {args}")
458
-
459
- models = get_model_list()
460
-
461
- logger.info(args)
462
- demo = build_demo(args.embed)
463
- demo.queue(
464
- concurrency_count=args.concurrency_count,
465
- api_open=False
466
- ).launch(
467
- server_name=args.host,
468
- server_port=args.port,
469
- share=args.share
470
- )
 
LLAVA_Biovil/llava/serve/model_worker.py DELETED
@@ -1,310 +0,0 @@
1
- """
2
- A model worker executes the model.
3
- """
4
- import argparse
5
- import asyncio
6
- import json
7
- import time
8
- import threading
9
- import uuid
10
-
11
- from fastapi import FastAPI, Request, BackgroundTasks
12
- from fastapi.responses import StreamingResponse
13
- import requests
14
- import torch
15
- import uvicorn
16
- from functools import partial
17
-
18
- from llava.constants import WORKER_HEART_BEAT_INTERVAL
19
- from llava.utils import (build_logger, server_error_msg,
20
- pretty_print_semaphore)
21
- from llava.model.builder import load_pretrained_model
22
- from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria, process_image_biovil, \
23
- load_image_from_base64_biovil
24
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
- from transformers import TextIteratorStreamer
26
- from threading import Thread
27
-
28
- from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, transforms
29
-
30
- from test import ExpandChannels
31
-
32
- GB = 1 << 30
33
-
34
- worker_id = str(uuid.uuid4())[:6]
35
- logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
36
- global_counter = 0
37
-
38
- model_semaphore = None
39
-
40
-
41
- def heart_beat_worker(controller):
42
-
43
- while True:
44
- time.sleep(WORKER_HEART_BEAT_INTERVAL)
45
- controller.send_heart_beat()
46
-
47
-
48
- class ModelWorker:
49
- def __init__(self, controller_addr, worker_addr,
50
- worker_id, no_register,
51
- model_path, model_base, model_name,
52
- load_8bit, load_4bit, device, vision_tower):
53
- self.controller_addr = controller_addr
54
- self.worker_addr = worker_addr
55
- self.worker_id = worker_id
56
- if model_path.endswith("/"):
57
- model_path = model_path[:-1]
58
- if model_name is None:
59
- model_paths = model_path.split("/")
60
- if model_paths[-1].startswith('checkpoint-'):
61
- self.model_name = model_paths[-2] + "_" + model_paths[-1]
62
- else:
63
- self.model_name = model_paths[-1]
64
- else:
65
- self.model_name = model_name
66
-
67
- self.device = device
68
- logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
69
- self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
70
- model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
71
- self.is_multimodal = 'llava' in self.model_name.lower()
72
-
73
- if not no_register:
74
- self.register_to_controller()
75
- self.heart_beat_thread = threading.Thread(
76
- target=heart_beat_worker, args=(self,))
77
- self.heart_beat_thread.start()
78
-
79
- self.vision_tower = vision_tower
80
- self.vis_transforms_biovil = self.create_chest_xray_transform_for_inference(512, center_crop_size=448)
81
-
82
- def create_chest_xray_transform_for_inference(self, resize: int, center_crop_size: int) -> Compose:
83
- """
84
- Defines the image transformation pipeline for Chest-Xray datasets.
85
-
86
- :param resize: The size to resize the image to. Linear resampling is used.
87
- Resizing is applied on the axis with smaller shape.
88
- :param center_crop_size: The size to center crop the image to. Square crop is applied.
89
- """
90
-
91
- transforms = [Resize(resize), CenterCrop(center_crop_size), ToTensor(), ExpandChannels()]
92
- return Compose(transforms)
93
-
94
- def register_to_controller(self):
95
- logger.info("Register to controller")
96
-
97
- url = self.controller_addr + "/register_worker"
98
- data = {
99
- "worker_name": self.worker_addr,
100
- "check_heart_beat": True,
101
- "worker_status": self.get_status()
102
- }
103
- r = requests.post(url, json=data)
104
- assert r.status_code == 200
105
-
106
- def send_heart_beat(self):
107
- logger.info(f"Send heart beat. Models: {[self.model_name]}. "
108
- f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
109
- f"global_counter: {global_counter}")
110
-
111
- url = self.controller_addr + "/receive_heart_beat"
112
-
113
- while True:
114
- try:
115
- ret = requests.post(url, json={
116
- "worker_name": self.worker_addr,
117
- "queue_length": self.get_queue_length()}, timeout=5)
118
- exist = ret.json()["exist"]
119
- break
120
- except requests.exceptions.RequestException as e:
121
- logger.error(f"heart beat error: {e}")
122
- time.sleep(5)
123
-
124
- if not exist:
125
- self.register_to_controller()
126
-
127
- def get_queue_length(self):
128
- if model_semaphore is None:
129
- return 0
130
- else:
131
- return args.limit_model_concurrency - model_semaphore._value + (len(
132
- model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
133
-
134
- def get_status(self):
135
- return {
136
- "model_names": [self.model_name],
137
- "speed": 1,
138
- "queue_length": self.get_queue_length(),
139
- }
140
-
141
- @torch.inference_mode()
142
- def generate_stream(self, params):
143
- tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
144
-
145
- prompt = params["prompt"]
146
- ori_prompt = prompt
147
- images = params.get("images", None)
148
- num_image_tokens = 0
149
- if images is not None and len(images) > 0 and self.is_multimodal:
150
- if len(images) > 0:
151
- if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
152
- raise ValueError("Number of images does not match number of <image> tokens in prompt")
153
-
154
- if self.vision_tower == 'biovil':
155
- images = [load_image_from_base64_biovil(image) for image in images]
156
- images = process_image_biovil(images, self.vis_transforms_biovil)
157
- else:
158
- images = [load_image_from_base64(image) for image in images]
159
- images = process_images(images, image_processor, model.config)
160
-
161
- if type(images) is list:
162
- images = [image.to(self.model.device, dtype=torch.bfloat16) for image in images]
163
- else:
164
- images = images.to(self.model.device, dtype=torch.bfloat16)
165
-
166
- replace_token = DEFAULT_IMAGE_TOKEN
167
- if getattr(self.model.config, 'mm_use_im_start_end', False):
168
- replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
169
- prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
170
-
171
- num_image_tokens = prompt.count(replace_token) * 196 if self.vision_tower == 'biovil' else prompt.count(replace_token) * model.get_vision_tower().num_patches
172
- else:
173
- images = None
174
- image_args = {"images": images}
175
- else:
176
- images = None
177
- image_args = {}
178
-
179
- temperature = float(params.get("temperature", 1.0))
180
- top_p = float(params.get("top_p", 1.0))
181
- max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
182
- max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
183
- stop_str = params.get("stop", None)
184
- do_sample = True if temperature > 0.001 else False
185
-
186
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
187
- keywords = [stop_str]
188
- stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
189
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
190
-
191
- max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
192
-
193
- if max_new_tokens < 1:
194
- yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
195
- return
196
-
197
- thread = Thread(target=model.generate, kwargs=dict(
198
- inputs=input_ids,
199
- do_sample=do_sample,
200
- temperature=temperature,
201
- top_p=top_p,
202
- max_new_tokens=max_new_tokens,
203
- streamer=streamer,
204
- stopping_criteria=[stopping_criteria],
205
- use_cache=True,
206
- **image_args
207
- ))
208
- thread.start()
209
-
210
- generated_text = ori_prompt
211
- for new_text in streamer:
212
- generated_text += new_text
213
- if generated_text.endswith(stop_str):
214
- generated_text = generated_text[:-len(stop_str)]
215
- yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
216
-
217
- def generate_stream_gate(self, params):
218
- try:
219
- for x in self.generate_stream(params):
220
- yield x
221
- except ValueError as e:
222
- print("Caught ValueError:", e)
223
- ret = {
224
- "text": server_error_msg,
225
- "error_code": 1,
226
- }
227
- yield json.dumps(ret).encode() + b"\0"
228
- except torch.cuda.CudaError as e:
229
- print("Caught torch.cuda.CudaError:", e)
230
- ret = {
231
- "text": server_error_msg,
232
- "error_code": 1,
233
- }
234
- yield json.dumps(ret).encode() + b"\0"
235
- except Exception as e:
236
- print("Caught Unknown Error", e)
237
- ret = {
238
- "text": server_error_msg,
239
- "error_code": 1,
240
- }
241
- yield json.dumps(ret).encode() + b"\0"
242
-
243
-
244
- app = FastAPI()
245
-
246
-
247
- def release_model_semaphore(fn=None):
248
- model_semaphore.release()
249
- if fn is not None:
250
- fn()
251
-
252
-
253
- @app.post("/worker_generate_stream")
254
- async def generate_stream(request: Request):
255
- global model_semaphore, global_counter
256
- global_counter += 1
257
- params = await request.json()
258
-
259
- if model_semaphore is None:
260
- model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
261
- await model_semaphore.acquire()
262
- worker.send_heart_beat()
263
- generator = worker.generate_stream_gate(params)
264
- background_tasks = BackgroundTasks()
265
- background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
266
- return StreamingResponse(generator, background=background_tasks)
267
-
268
-
269
- @app.post("/worker_get_status")
270
- async def get_status(request: Request):
271
- return worker.get_status()
272
-
273
-
274
- if __name__ == "__main__":
275
- parser = argparse.ArgumentParser()
276
- parser.add_argument("--host", type=str, default="localhost")
277
- parser.add_argument("--port", type=int, default=21002)
278
- parser.add_argument("--worker-address", type=str,
279
- default="http://localhost:21002")
280
- parser.add_argument("--controller-address", type=str,
281
- default="http://localhost:21001")
282
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
283
- parser.add_argument("--model-base", type=str, default=None)
284
- parser.add_argument("--model-name", type=str)
285
- parser.add_argument("--device", type=str, default="cuda")
286
- parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
287
- parser.add_argument("--limit-model-concurrency", type=int, default=5)
288
- parser.add_argument("--stream-interval", type=int, default=1)
289
- parser.add_argument("--no-register", action="store_true")
290
- parser.add_argument("--load-8bit", action="store_true")
291
- parser.add_argument("--load-4bit", action="store_true")
292
- parser.add_argument("--vision_tower", type=str, default="openai/clip-vit-large-patch14-336")
293
- args = parser.parse_args()
294
- logger.info(f"args: {args}")
295
-
296
- if args.multi_modal:
297
- logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
298
-
299
- worker = ModelWorker(args.controller_address,
300
- args.worker_address,
301
- worker_id,
302
- args.no_register,
303
- args.model_path,
304
- args.model_base,
305
- args.model_name,
306
- args.load_8bit,
307
- args.load_4bit,
308
- args.device,
309
- args.vision_tower)
310
- uvicorn.run(app, host=args.host, port=args.port, log_level="info")
 
LLAVA_Biovil/llava/serve/register_worker.py DELETED
@@ -1,26 +0,0 @@
1
- """
2
- Manually register workers.
3
-
4
- Usage:
5
- python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
6
- """
7
-
8
- import argparse
9
-
10
- import requests
11
-
12
- if __name__ == "__main__":
13
- parser = argparse.ArgumentParser()
14
- parser.add_argument("--controller-address", type=str)
15
- parser.add_argument("--worker-name", type=str)
16
- parser.add_argument("--check-heart-beat", action="store_true")
17
- args = parser.parse_args()
18
-
19
- url = args.controller_address + "/register_worker"
20
- data = {
21
- "worker_name": args.worker_name,
22
- "check_heart_beat": args.check_heart_beat,
23
- "worker_status": None,
24
- }
25
- r = requests.post(url, json=data)
26
- assert r.status_code == 200
 
LLAVA_Biovil/llava/serve/test_message.py DELETED
@@ -1,62 +0,0 @@
1
- import argparse
2
- import json
3
-
4
- import requests
5
-
6
- from LLAV.llava.conversation import default_conversation
7
-
8
-
9
- def main():
10
- if args.worker_address:
11
- worker_addr = args.worker_address
12
- else:
13
- controller_addr = args.controller_address
14
- ret = requests.post(controller_addr + "/refresh_all_workers")
15
- ret = requests.post(controller_addr + "/list_models")
16
- models = ret.json()["models"]
17
- models.sort()
18
- print(f"Models: {models}")
19
-
20
- ret = requests.post(controller_addr + "/get_worker_address",
21
- json={"model": args.model_name})
22
- worker_addr = ret.json()["address"]
23
- print(f"worker_addr: {worker_addr}")
24
-
25
- if worker_addr == "":
26
- return
27
-
28
- conv = default_conversation.copy()
29
- conv.append_message(conv.roles[0], args.message)
30
- prompt = conv.get_prompt()
31
-
32
- headers = {"User-Agent": "LLaVA Client"}
33
- pload = {
34
- "model": args.model_name,
35
- "prompt": prompt,
36
- "max_new_tokens": args.max_new_tokens,
37
- "temperature": 0.7,
38
- "stop": conv.sep,
39
- }
40
- response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41
- json=pload, stream=True)
42
-
43
- print(prompt.replace(conv.sep, "\n"), end="")
44
- for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45
- if chunk:
46
- data = json.loads(chunk.decode("utf-8"))
47
- output = data["text"].split(conv.sep)[-1]
48
- print(output, end="\r")
49
- print("")
50
-
51
-
52
- if __name__ == "__main__":
53
- parser = argparse.ArgumentParser()
54
- parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55
- parser.add_argument("--worker-address", type=str)
56
- parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57
- parser.add_argument("--max-new-tokens", type=int, default=32)
58
- parser.add_argument("--message", type=str, default=
59
- "Tell me a story with more than 1000 words.")
60
- args = parser.parse_args()
61
-
62
- main()
 
LLAVA_Biovil/llava/train/__init__.py DELETED
File without changes
LLAVA_Biovil/llava/train/llama_flash_attn_monkey_patch.py DELETED
@@ -1,115 +0,0 @@
1
- from typing import Optional, Tuple
2
- import warnings
3
-
4
- import torch
5
-
6
- import transformers
7
- from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8
-
9
- try:
10
- from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11
- except ImportError:
12
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13
- from flash_attn.bert_padding import unpad_input, pad_input
14
-
15
-
16
- def forward(
17
- self,
18
- hidden_states: torch.Tensor,
19
- attention_mask: Optional[torch.Tensor] = None,
20
- position_ids: Optional[torch.Tensor] = None,
21
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
22
- output_attentions: bool = False,
23
- use_cache: bool = False,
24
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
25
- if output_attentions:
26
- warnings.warn(
27
- "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
28
- )
29
-
30
- bsz, q_len, _ = hidden_states.size()
31
-
32
- query_states = (
33
- self.q_proj(hidden_states)
34
- .view(bsz, q_len, self.num_heads, self.head_dim)
35
- .transpose(1, 2)
36
- )
37
- key_states = (
38
- self.k_proj(hidden_states)
39
- .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
40
- .transpose(1, 2)
41
- )
42
- value_states = (
43
- self.v_proj(hidden_states)
44
- .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
45
- .transpose(1, 2)
46
- ) # shape: (b, num_heads, s, head_dim)
47
-
48
- kv_seq_len = key_states.shape[-2]
49
- if past_key_value is not None:
50
- kv_seq_len += past_key_value[0].shape[-2]
51
-
52
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
53
- query_states, key_states = apply_rotary_pos_emb(
54
- query_states, key_states, cos, sin, position_ids
55
- )
56
-
57
- if past_key_value is not None:
58
- # reuse k, v
59
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
60
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
61
-
62
- past_key_value = (key_states, value_states) if use_cache else None
63
-
64
- # repeat k/v heads if n_kv_heads < n_heads
65
- key_states = repeat_kv(key_states, self.num_key_value_groups)
66
- value_states = repeat_kv(value_states, self.num_key_value_groups)
67
-
68
- # Transform the data into the format required by flash attention
69
- qkv = torch.stack([query_states, key_states, value_states], dim=2)
70
- qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
71
- key_padding_mask = attention_mask
72
-
73
- if key_padding_mask is None:
74
- qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
75
- cu_q_lens = torch.arange(
76
- 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
77
- )
78
- max_s = q_len
79
- output = flash_attn_unpadded_qkvpacked_func(
80
- qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
81
- )
82
- output = output.view(bsz, q_len, -1)
83
- else:
84
- qkv = qkv.reshape(bsz, q_len, -1)
85
- qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
86
- qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
87
- output_unpad = flash_attn_unpadded_qkvpacked_func(
88
- qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
89
- )
90
- output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
91
- output = pad_input(output_unpad, indices, bsz, q_len)
92
-
93
- return self.o_proj(output), None, past_key_value
94
-
95
-
96
- # Disable the transformation of the attention mask in LlamaModel as the flash attention
97
- # requires the attention mask to be the same as the key_padding_mask
98
- def _prepare_decoder_attention_mask(
99
- self, attention_mask, input_shape, inputs_embeds, past_key_values_length
100
- ):
101
- # [bsz, seq_len]
102
- return attention_mask
103
-
104
-
105
- def replace_llama_attn_with_flash_attn():
106
- cuda_major, cuda_minor = torch.cuda.get_device_capability()
107
- if cuda_major < 8:
108
- warnings.warn(
109
- "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
110
- "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
111
- )
112
- transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
113
- _prepare_decoder_attention_mask
114
- )
115
- transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
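The patch above only takes effect for attention modules created after it runs, so it has to be applied before the LLaMA model is instantiated. A minimal usage sketch, assuming the package is importable as llava and using a placeholder checkpoint name (neither is prescribed by this file):

import torch
from transformers import AutoModelForCausalLM

from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

# Patch LlamaAttention.forward and the decoder-mask hook first, then build the model.
replace_llama_attn_with_flash_attn()
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",   # placeholder checkpoint, not fixed by this repo
    torch_dtype=torch.float16,    # flash-attn kernels expect fp16/bf16 activations
)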
LLAVA_Biovil/llava/train/llama_patch.py DELETED
@@ -1,139 +0,0 @@
1
- from typing import List, Optional, Tuple
2
-
3
- import torch
4
- from torch import nn
5
- import warnings
6
- import transformers
7
- from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
8
- from peft.tuners.lora import LoraLayer
9
-
10
- try:
11
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
12
- from flash_attn.bert_padding import unpad_input, pad_input
13
- except Exception:
14
- raise ModuleNotFoundError(
15
- "Please install FlashAttention first, e.g., with pip install flash-attn --no-build-isolation, Learn more at https://github.com/Dao-AILab/flash-attention#installation-and-features"
16
- )
17
-
18
- try:
19
- from einops import rearrange
20
- except Exception:
21
- raise ModuleNotFoundError("Please install einops first, e.g., with pip install einops")
22
-
23
-
24
- # ADAPTED from https://github.com/allenai/open-instruct/blob/main/open_instruct/llama_flash_attn_monkey_patch.py
25
- # AND https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
26
- # AND https://github.com/LAION-AI/Open-Assistant/blob/04fa9a24b2a58c8885b8aa6a2eb02b18de6b4961/model/model_training/models/patching_llama.py
27
- # AND Sourabh https://github.com/huggingface/transformers/commit/ee81bf5aee0d65f005d157c013777e3d27d8d6bf
28
- def forward(
29
- self,
30
- hidden_states: torch.Tensor,
31
- attention_mask: Optional[torch.Tensor] = None,
32
- position_ids: Optional[torch.Tensor] = None,
33
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
34
- output_attentions: bool = False,
35
- use_cache: bool = False,
36
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
37
- """Input shape: Batch x Time x Channel
38
-
39
- attention_mask: [bsz, q_len]
40
- """
41
- if output_attentions:
42
- warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.")
43
-
44
- bsz, q_len, _ = hidden_states.size()
45
-
46
- query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
47
- key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
48
- value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
49
- # [bsz, q_len, nh, hd]
50
- # [bsz, nh, q_len, hd]
51
-
52
- kv_seq_len = key_states.shape[-2]
53
- if past_key_value is not None:
54
- kv_seq_len += past_key_value[0].shape[-2]
55
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
56
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
57
-
58
- # Past Key value support
59
- if past_key_value is not None:
60
- # reuse k, v, self_attention
61
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
62
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
63
-
64
- past_key_value = (key_states, value_states) if use_cache else None
65
-
66
- # Flash attention codes from
67
- # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
68
-
69
- # transform the data into the format required by flash attention
70
- qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd]
71
- qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
72
- # We have disabled _prepare_decoder_attention_mask in LlamaModel
73
- # the attention_mask should be the same as the key_padding_mask
74
- key_padding_mask = attention_mask
75
-
76
- if key_padding_mask is None:
77
- qkv = rearrange(qkv, "b s ... -> (b s) ...")
78
- max_s = q_len
79
- cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
80
- output = flash_attn_varlen_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
81
- output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
82
- else:
83
- nheads = qkv.shape[-2]
84
- x = rearrange(qkv, "b s three h d -> b s (three h d)")
85
- x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
86
- x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
87
- output_unpad = flash_attn_varlen_qkvpacked_func(
88
- x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
89
- )
90
- output = rearrange(
91
- pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len),
92
- "b s (h d) -> b s h d",
93
- h=nheads,
94
- )
95
- return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
96
-
97
-
98
- # Disable the transformation of the attention mask in LlamaModel as the flash attention
99
- # requires the attention mask to be the same as the key_padding_mask
100
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
101
- # [bsz, seq_len]
102
- return attention_mask
103
-
104
-
105
- def replace_attn_with_flash_attn():
106
- cuda_major, cuda_minor = torch.cuda.get_device_capability()
107
- if cuda_major < 8:
108
- print(
109
- "Flash attention is only supported on Ampere or Hopper GPU during training due to head dim > 64 backward."
110
- "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
111
- )
112
- transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
113
- _prepare_decoder_attention_mask
114
- )
115
- transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
116
-
117
-
118
- def unplace_flash_attn_with_attn():
119
- import importlib
120
- import transformers
121
-
122
- print("Reloading llama model, unpatching flash attention")
123
- importlib.reload(transformers.models.llama.modeling_llama)
124
-
125
-
126
- # Adapted from https://github.com/tmm1/axolotl/blob/2eda9e02a9d15a7a3f92b41f257d9844d72fc220/src/axolotl/utils/models.py#L338
127
- def upcast_layer_for_flash_attention(model, torch_dtype):
128
- # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
129
- # convert them back to fp16/bf16 for flash-attn compatibility.
130
- for name, module in model.named_modules():
131
- if isinstance(module, LoraLayer):
132
- module.to(torch_dtype)
133
- if "norm" in name:
134
- module.to(torch_dtype)
135
- if "lm_head" in name or "embed_tokens" in name:
136
- if hasattr(module, "weight"):
137
- module.to(torch_dtype)
138
-
139
- return model
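A sketch of how replace_attn_with_flash_attn and upcast_layer_for_flash_attention are typically combined with k-bit (QLoRA) fine-tuning. The checkpoint name and LoRA settings below are illustrative assumptions, not values prescribed by this file:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

from llava.train.llama_patch import replace_attn_with_flash_attn, upcast_layer_for_flash_attention

replace_attn_with_flash_attn()                      # must run before the model is built
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",                     # placeholder checkpoint
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
model = prepare_model_for_kbit_training(model)      # leaves norm layers in fp32
model = get_peft_model(model, LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"]))
# Cast LoRA layers, norms and embeddings back to bf16 so flash-attn sees a single dtype.
model = upcast_layer_for_flash_attention(model, torch.bfloat16)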
LLAVA_Biovil/llava/train/llama_xformers_attn_monkey_patch.py DELETED
@@ -1,129 +0,0 @@
1
- """
2
- Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
3
- """
4
-
5
- import logging
6
- import math
7
- from typing import Optional, Tuple
8
-
9
- import torch
10
- import transformers.models.llama.modeling_llama
11
- from torch import nn
12
-
13
- try:
14
- import xformers.ops
15
- except ImportError:
16
- logging.error("xformers not found! Please install it before trying to use it.")
17
-
18
-
19
- def replace_llama_attn_with_xformers_attn():
20
- transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
21
-
22
-
23
- def xformers_forward(
24
- self,
25
- hidden_states: torch.Tensor,
26
- attention_mask: Optional[torch.Tensor] = None,
27
- position_ids: Optional[torch.LongTensor] = None,
28
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
29
- output_attentions: bool = False,
30
- use_cache: bool = False,
31
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
32
- # pylint: disable=duplicate-code
33
- bsz, q_len, _ = hidden_states.size()
34
-
35
- query_states = (
36
- self.q_proj(hidden_states)
37
- .view(bsz, q_len, self.num_heads, self.head_dim)
38
- .transpose(1, 2)
39
- )
40
- key_states = (
41
- self.k_proj(hidden_states)
42
- .view(bsz, q_len, self.num_heads, self.head_dim)
43
- .transpose(1, 2)
44
- )
45
- value_states = (
46
- self.v_proj(hidden_states)
47
- .view(bsz, q_len, self.num_heads, self.head_dim)
48
- .transpose(1, 2)
49
- )
50
-
51
- kv_seq_len = key_states.shape[-2]
52
- if past_key_value is not None:
53
- kv_seq_len += past_key_value[0].shape[-2]
54
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
55
- (
56
- query_states,
57
- key_states,
58
- ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
59
- query_states, key_states, cos, sin, position_ids
60
- )
61
- # [bsz, nh, t, hd]
62
-
63
- if past_key_value is not None:
64
- # reuse k, v, self_attention
65
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
66
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
67
-
68
- past_key_value = (key_states, value_states) if use_cache else None
69
-
70
- # We only apply xformers optimizations if we don't need to output the whole attention matrix
71
- if not output_attentions:
72
- query_states = query_states.transpose(1, 2)
73
- key_states = key_states.transpose(1, 2)
74
- value_states = value_states.transpose(1, 2)
75
-
76
- # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
77
- # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
78
- if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
79
- # input and output should be of form (bsz, q_len, num_heads, head_dim)
80
- attn_output = xformers.ops.memory_efficient_attention(
81
- query_states, key_states, value_states, attn_bias=None
82
- )
83
- else:
84
- # input and output should be of form (bsz, q_len, num_heads, head_dim)
85
- attn_output = xformers.ops.memory_efficient_attention(
86
- query_states,
87
- key_states,
88
- value_states,
89
- attn_bias=xformers.ops.LowerTriangularMask(),
90
- )
91
- attn_weights = None
92
- else:
93
- attn_weights = torch.matmul(
94
- query_states, key_states.transpose(2, 3)
95
- ) / math.sqrt(self.head_dim)
96
-
97
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
98
- raise ValueError(
99
- f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
100
- f" {attn_weights.size()}"
101
- )
102
-
103
- if attention_mask is not None:
104
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
105
- raise ValueError(
106
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
107
- )
108
- attn_weights = attn_weights + attention_mask
109
- attn_weights = torch.max(
110
- attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
111
- )
112
-
113
- # upcast attention to fp32
114
- attn_weights = nn.functional.softmax(
115
- attn_weights, dim=-1, dtype=torch.float32
116
- ).to(query_states.dtype)
117
- attn_output = torch.matmul(attn_weights, value_states)
118
-
119
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
120
- raise ValueError(
121
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
122
- f" {attn_output.size()}"
123
- )
124
-
125
- attn_output = attn_output.transpose(1, 2)
126
-
127
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
128
- attn_output = self.o_proj(attn_output)
129
- return attn_output, attn_weights, past_key_value
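The xformers patch above serves as a fallback for GPUs that the flash-attention kernels in llama_patch.py do not support. A small selection helper, sketched under the assumption that both patch modules are importable as shown; the capability threshold mirrors the check in replace_attn_with_flash_attn:

import torch


def pick_attention_patch():
    # flash-attn's backward for head_dim > 64 needs compute capability >= 8.0 (Ampere/Hopper);
    # otherwise fall back to the xformers memory-efficient attention patch.
    major, _minor = torch.cuda.get_device_capability()
    if major >= 8:
        from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
        replace_llama_attn_with_flash_attn()
    else:
        from llava.train.llama_xformers_attn_monkey_patch import replace_llama_attn_with_xformers_attn
        replace_llama_attn_with_xformers_attn()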
LLAVA_Biovil/llava/train/llava_trainer.py DELETED
@@ -1,801 +0,0 @@
1
- import json
2
- import math
3
- import os
4
- import shutil
5
- import sys
6
- import time
7
- from distutils import dist
8
-
9
- import torch
10
- from torch import nn
11
- import numpy as np
12
-
13
- from torch.utils.data import Sampler
14
- from packaging import version
15
-
16
- from transformers import Trainer, TrainerState, is_torch_tpu_available, is_apex_available
17
- from transformers.debug_utils import DebugOption
18
- from transformers.integrations import hp_params
19
- from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint
20
-
21
- from transformers.trainer import (
22
- is_sagemaker_mp_enabled,
23
- get_parameter_names,
24
- has_length,
25
- ALL_LAYERNORM_LAYERS,
26
- ShardedDDPOption,
27
- logger, TRAINER_STATE_NAME,
28
- )
29
- from typing import List, Optional
30
-
31
- from transformers.trainer_pt_utils import get_model_param_count
32
- from transformers.trainer_utils import HPSearchBackend, speed_metrics, TrainOutput
33
- from transformers.training_args import ParallelMode
34
- from transformers.utils import is_accelerate_available
35
-
36
- if is_accelerate_available():
37
- from accelerate import Accelerator, skip_first_batches
38
- from accelerate import __version__ as accelerate_version
39
- from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin
40
-
41
- if version.parse(accelerate_version) > version.parse("0.20.3"):
42
- from accelerate.utils import (
43
- load_fsdp_model,
44
- load_fsdp_optimizer,
45
- save_fsdp_model,
46
- save_fsdp_optimizer,
47
- )
48
-
49
- if is_torch_tpu_available(check_device=False):
50
- import torch_xla.core.xla_model as xm
51
- import torch_xla.debug.metrics as met
52
-
53
- if is_apex_available():
54
- from apex import amp
55
-
56
- # with open('/home/guests/chantal_pellegrini/RaDialog_LLaVA/data/train_token_freqs_radrestruct_balanced_50ep.json') as f:
57
- # token_frequencies = json.load(f)
58
- # token_weights = {k: 1 / v for k, v in token_frequencies.items()} # linear weighting
59
- # print("lin weighting")
60
-
61
- # token_weights = {k: 1 / (np.log(v) + 1) for k, v in token_frequencies.items()} # log weighting, seems to work better in this case
62
- # print("log weighting")
63
- token_weights = None # no weighting
64
- print("no weighting")
65
-
66
- if token_weights is not None:
67
- min_weight = min(token_weights.values())
68
- extra_token_weight = min_weight / 100 # 100 smaller than the smallest weight
69
-
70
-
71
- def maybe_zero_3(param, ignore_status=False, name=None):
72
- from deepspeed import zero
73
- from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
74
- if hasattr(param, "ds_id"):
75
- if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
76
- if not ignore_status:
77
- print(name, 'no ignore status')
78
- with zero.GatheredParameters([param]):
79
- param = param.data.detach().cpu().clone()
80
- else:
81
- param = param.detach().cpu().clone()
82
- return param
83
-
84
-
85
- def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
86
- to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
87
- to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
88
- return to_return
89
-
90
-
91
- def split_to_even_chunks(indices, lengths, num_chunks):
92
- """
93
- Split a list of indices into `chunks` chunks of roughly equal lengths.
94
- """
95
-
96
- if len(indices) % num_chunks != 0:
97
- return [indices[i::num_chunks] for i in range(num_chunks)]
98
-
99
- num_indices_per_chunk = len(indices) // num_chunks
100
-
101
- chunks = [[] for _ in range(num_chunks)]
102
- chunks_lengths = [0 for _ in range(num_chunks)]
103
- for index in indices:
104
- shortest_chunk = chunks_lengths.index(min(chunks_lengths))
105
- chunks[shortest_chunk].append(index)
106
- chunks_lengths[shortest_chunk] += lengths[index]
107
- if len(chunks[shortest_chunk]) == num_indices_per_chunk:
108
- chunks_lengths[shortest_chunk] = float("inf")
109
-
110
- return chunks
111
-
112
-
113
- def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
114
- # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
115
- assert all(l != 0 for l in lengths), "Should not have zero length."
116
- if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
117
- # all samples are in the same modality
118
- return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
119
- mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
120
- lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
121
-
122
- mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
123
- lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
124
- megabatch_size = world_size * batch_size
125
- mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
126
- lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
127
-
128
- last_mm = mm_megabatches[-1]
129
- last_lang = lang_megabatches[-1]
130
- additional_batch = last_mm + last_lang
131
- megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
132
- megabatch_indices = torch.randperm(len(megabatches), generator=generator)
133
- megabatches = [megabatches[i] for i in megabatch_indices]
134
-
135
- if len(additional_batch) > 0:
136
- megabatches.append(sorted(additional_batch))
137
-
138
- return [i for megabatch in megabatches for i in megabatch]
139
-
140
-
141
- def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
142
- # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
143
- indices = torch.randperm(len(lengths), generator=generator)
144
- megabatch_size = world_size * batch_size
145
- megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
146
- megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
147
- megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
148
-
149
- return [i for megabatch in megabatches for batch in megabatch for i in batch]
150
-
151
-
152
- class LengthGroupedSampler(Sampler):
153
- r"""
154
- Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
155
- keeping a bit of randomness.
156
- """
157
-
158
- def __init__(
159
- self,
160
- batch_size: int,
161
- world_size: int,
162
- lengths: Optional[List[int]] = None,
163
- generator=None,
164
- group_by_modality: bool = False,
165
- ):
166
- if lengths is None:
167
- raise ValueError("Lengths must be provided.")
168
-
169
- self.batch_size = batch_size
170
- self.world_size = world_size
171
- self.lengths = lengths
172
- self.generator = generator
173
- self.group_by_modality = group_by_modality
174
-
175
- def __len__(self):
176
- return len(self.lengths)
177
-
178
- def __iter__(self):
179
- if self.group_by_modality:
180
- indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
181
- else:
182
- indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
183
- return iter(indices)
184
-
185
-
186
- class LLaVATrainer(Trainer):
187
-
188
- def compute_loss(self, model, inputs, return_outputs=False):
189
- """
190
- How the loss is computed by Trainer. By default, all models return the loss in the first element.
191
-
192
- Subclass and override for custom behavior.
193
- """
194
- outputs = model(**inputs)
195
-
196
- # Save past state if it exists
197
- # TODO: this needs to be fixed and made cleaner later.
198
- if self.args.past_index >= 0:
199
- self._past = outputs[self.args.past_index]
200
-
201
- if token_weights is not None:
202
- # check if self has attribute vocab_weight, otherwise create
203
- if not hasattr(self, 'vocab_weight'):
204
- vocab = self.tokenizer.get_vocab()
205
- self.vocab_weight = torch.ones(len(vocab)) * extra_token_weight # default weight
206
- # map them using vocab to correct indices
207
- for k, v in token_weights.items():
208
- self.vocab_weight[vocab[k]] = v
209
- self.vocab_weight = self.vocab_weight.to(self.args.device)
210
-
211
- # Shift so that tokens < n predict n
212
- shift_logits = outputs.logits[..., :-1, :].contiguous()
213
- shift_labels = outputs.modified_labels[..., 1:].contiguous()
214
- # Flatten the tokens
215
- loss_fct = nn.CrossEntropyLoss(weight=self.vocab_weight)
216
- shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
217
- shift_labels = shift_labels.view(-1)
218
- # Enable model parallelism
219
- shift_labels = shift_labels.to(shift_logits.device)
220
- loss = loss_fct(shift_logits, shift_labels)
221
-
222
- return (loss, outputs) if return_outputs else loss
223
-
224
- else: #orginial compute_loss without weighting
225
- # We don't use .loss here since the model may return tuples instead of ModelOutput.
226
- loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
227
-
228
- return (loss, outputs) if return_outputs else loss
229
-
230
-
231
- def _inner_training_loop(
232
- self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
233
- ):
234
- self.accelerator.free_memory()
235
- self._train_batch_size = batch_size
236
- logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
237
- # Data loader and number of training steps
238
- train_dataloader = self.get_train_dataloader()
239
-
240
- # Setting up training control variables:
241
- # number of training epochs: num_train_epochs
242
- # number of training steps per epoch: num_update_steps_per_epoch
243
- # total number of training steps to execute: max_steps
244
- total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
245
-
246
- len_dataloader = None
247
- if has_length(train_dataloader):
248
- len_dataloader = len(train_dataloader)
249
- num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
250
- num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
251
- num_examples = self.num_examples(train_dataloader)
252
- if args.max_steps > 0:
253
- max_steps = args.max_steps
254
- num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
255
- args.max_steps % num_update_steps_per_epoch > 0
256
- )
257
- # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
258
- # the best we can do.
259
- num_train_samples = args.max_steps * total_train_batch_size
260
- else:
261
- max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
262
- num_train_epochs = math.ceil(args.num_train_epochs)
263
- num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
264
- elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
265
- max_steps = args.max_steps
266
- # Setting a very large number of epochs so we go as many times as necessary over the iterator.
267
- num_train_epochs = sys.maxsize
268
- num_update_steps_per_epoch = max_steps
269
- num_examples = total_train_batch_size * args.max_steps
270
- num_train_samples = args.max_steps * total_train_batch_size
271
- else:
272
- raise ValueError(
273
- "args.max_steps must be set to a positive value if dataloader does not have a length, was"
274
- f" {args.max_steps}"
275
- )
276
-
277
- # Compute absolute values for logging, eval, and save if given as ratio
278
- if args.logging_steps and args.logging_steps < 1:
279
- args.logging_steps = math.ceil(max_steps * args.logging_steps)
280
- if args.eval_steps and args.eval_steps < 1:
281
- args.eval_steps = math.ceil(max_steps * args.eval_steps)
282
- if args.save_steps and args.save_steps < 1:
283
- args.save_steps = math.ceil(max_steps * args.save_steps)
284
-
285
- if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
286
- if self.args.n_gpu > 1:
287
- # nn.DataParallel(model) replicates the model, creating new variables and module
288
- # references registered here no longer work on other gpus, breaking the module
289
- raise ValueError(
290
- "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
291
- " (torch.distributed.launch)."
292
- )
293
- else:
294
- debug_overflow = DebugUnderflowOverflow(self.model) # noqa
295
-
296
- delay_optimizer_creation = (
297
- self.sharded_ddp is not None
298
- and self.sharded_ddp != ShardedDDPOption.SIMPLE
299
- or is_sagemaker_mp_enabled()
300
- or self.fsdp is not None
301
- )
302
-
303
- # We need to reset the scheduler, as its parameters may be different on subsequent calls
304
- if self._created_lr_scheduler:
305
- self.lr_scheduler = None
306
- self._created_lr_scheduler = False
307
-
308
- if self.is_deepspeed_enabled:
309
- self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
310
-
311
- if not delay_optimizer_creation:
312
- self.create_optimizer_and_scheduler(num_training_steps=max_steps)
313
-
314
- self.state = TrainerState()
315
- self.state.is_hyper_param_search = trial is not None
316
-
317
- # Activate gradient checkpointing if needed
318
- if args.gradient_checkpointing:
319
- self.model.gradient_checkpointing_enable()
320
-
321
- model = self._wrap_model(self.model_wrapped)
322
-
323
- if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
324
- self._load_from_checkpoint(resume_from_checkpoint, model)
325
-
326
- # as the model is wrapped, don't use `accelerator.prepare`
327
- # this is for unhandled cases such as
328
- # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
329
- use_accelerator_prepare = True if model is self.model else False
330
-
331
- if delay_optimizer_creation:
332
- self.create_optimizer_and_scheduler(num_training_steps=max_steps)
333
-
334
- # prepare using `accelerator` prepare
335
- if use_accelerator_prepare:
336
- self.model.train()
337
- if hasattr(self.lr_scheduler, "step"):
338
- if self.use_apex:
339
- model = self.accelerator.prepare(self.model)
340
- else:
341
- model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
342
- else:
343
- # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
344
- model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
345
- self.model, self.optimizer, self.lr_scheduler
346
- )
347
-
348
- if self.is_fsdp_enabled:
349
- self.model = model
350
-
351
- # for the rest of this function `model` is the outside model, whether it was wrapped or not
352
- if model is not self.model:
353
- self.model_wrapped = model
354
-
355
- # backward compatibility
356
- if self.is_deepspeed_enabled:
357
- self.deepspeed = self.model_wrapped
358
-
359
- # deepspeed ckpt loading
360
- if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
361
- print(f"DeepSpeed info: Loading model from {resume_from_checkpoint}")
362
- deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
363
- # get step from opt state
364
- # Assuming `optimizer_state_dict` is the dictionary you've loaded from the checkpoint
365
- for param_tensor, state in self.lr_scheduler.optimizer.state.items():
366
- step_tensor = state['step']
367
- step_value = step_tensor.item() # Convert tensor to a Python number
368
- print(f"Step value for a parameter tensor: {step_value}")
369
- # Since all parameters should have been updated the same number of times,
370
- # you can break after the first iteration
371
- break
372
- # step scheduler to match
373
- for _ in range(int(step_value)):
374
- self.lr_scheduler.step()
375
- # Check if saved optimizer or scheduler states exist
376
- self._load_optimizer_and_scheduler(resume_from_checkpoint)
377
-
378
- # important: at this point:
379
- # self.model is the Transformers Model
380
- # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
381
-
382
- # Train!
383
- logger.info("***** Running training *****")
384
- logger.info(f" Num examples = {num_examples:,}")
385
- logger.info(f" Num Epochs = {num_train_epochs:,}")
386
- logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
387
- if self.args.per_device_train_batch_size != self._train_batch_size:
388
- logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
389
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
390
- logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
391
- logger.info(f" Total optimization steps = {max_steps:,}")
392
- logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
393
-
394
- self.state.epoch = 0
395
- start_time = time.time()
396
- epochs_trained = 0
397
- steps_trained_in_current_epoch = 0
398
- steps_trained_progress_bar = None
399
-
400
- # Check if continuing training from a checkpoint
401
- if resume_from_checkpoint is not None and os.path.isfile(
402
- os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
403
- ):
404
- self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
405
- epochs_trained = self.state.global_step // num_update_steps_per_epoch
406
- if not args.ignore_data_skip:
407
- steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
408
- steps_trained_in_current_epoch *= args.gradient_accumulation_steps
409
- else:
410
- steps_trained_in_current_epoch = 0
411
-
412
- logger.info(" Continuing training from checkpoint, will skip to saved global_step")
413
- logger.info(f" Continuing training from epoch {epochs_trained}")
414
- logger.info(f" Continuing training from global step {self.state.global_step}")
415
- if not args.ignore_data_skip:
416
- logger.info(
417
- f" Will skip the first {epochs_trained} epochs then the first"
418
- f" {steps_trained_in_current_epoch} batches in the first epoch."
419
- )
420
-
421
- # Update the references
422
- self.callback_handler.model = self.model
423
- self.callback_handler.optimizer = self.optimizer
424
- self.callback_handler.lr_scheduler = self.lr_scheduler
425
- self.callback_handler.train_dataloader = train_dataloader
426
- if self.hp_name is not None and self._trial is not None:
427
- # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
428
- # parameter to Train when using DDP.
429
- self.state.trial_name = self.hp_name(self._trial)
430
- if trial is not None:
431
- assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
432
- self.state.trial_params = hp_params(assignments)
433
- else:
434
- self.state.trial_params = None
435
- # This should be the same if the state has been saved but in case the training arguments changed, it's safer
436
- # to set this after the load.
437
- self.state.max_steps = max_steps
438
- self.state.num_train_epochs = num_train_epochs
439
- self.state.is_local_process_zero = self.is_local_process_zero()
440
- self.state.is_world_process_zero = self.is_world_process_zero()
441
-
442
- # tr_loss is a tensor to avoid synchronization of TPUs through .item()
443
- tr_loss = torch.tensor(0.0).to(args.device)
444
- # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
445
- self._total_loss_scalar = 0.0
446
- self._globalstep_last_logged = self.state.global_step
447
- model.zero_grad()
448
-
449
- self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
450
-
451
- # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
452
- if not args.ignore_data_skip:
453
- for epoch in range(epochs_trained):
454
- for _ in train_dataloader:
455
- break
456
-
457
- total_batched_samples = 0
458
- for epoch in range(epochs_trained, num_train_epochs):
459
- epoch_iterator = train_dataloader
460
-
461
- # Reset the past mems state at the beginning of each epoch if necessary.
462
- if args.past_index >= 0:
463
- self._past = None
464
-
465
- steps_in_epoch = (
466
- len(epoch_iterator)
467
- if len_dataloader is not None
468
- else args.max_steps * args.gradient_accumulation_steps
469
- )
470
- self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
471
-
472
- if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
473
- self._load_rng_state(resume_from_checkpoint)
474
-
475
- rng_to_sync = False
476
- steps_skipped = 0
477
- if steps_trained_in_current_epoch > 0:
478
- epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
479
- steps_skipped = steps_trained_in_current_epoch
480
- steps_trained_in_current_epoch = 0
481
- rng_to_sync = True
482
-
483
- step = -1
484
- for step, inputs in enumerate(epoch_iterator):
485
- total_batched_samples += 1
486
- if rng_to_sync:
487
- self._load_rng_state(resume_from_checkpoint)
488
- rng_to_sync = False
489
-
490
- # Skip past any already trained steps if resuming training
491
- if steps_trained_in_current_epoch > 0:
492
- steps_trained_in_current_epoch -= 1
493
- if steps_trained_progress_bar is not None:
494
- steps_trained_progress_bar.update(1)
495
- if steps_trained_in_current_epoch == 0:
496
- self._load_rng_state(resume_from_checkpoint)
497
- continue
498
- elif steps_trained_progress_bar is not None:
499
- steps_trained_progress_bar.close()
500
- steps_trained_progress_bar = None
501
-
502
- if step % args.gradient_accumulation_steps == 0:
503
- self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
504
-
505
- with self.accelerator.accumulate(model):
506
- tr_loss_step = self.training_step(model, inputs)
507
-
508
- if (
509
- args.logging_nan_inf_filter
510
- and not is_torch_tpu_available()
511
- and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
512
- ):
513
- # if loss is nan or inf simply add the average of previous logged losses
514
- tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
515
- else:
516
- tr_loss += tr_loss_step
517
-
518
- self.current_flos += float(self.floating_point_ops(inputs))
519
-
520
- is_last_step_and_steps_less_than_grad_acc = (
521
- steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
522
- )
523
-
524
- if (
525
- total_batched_samples % args.gradient_accumulation_steps == 0
526
- or
527
- # last step in epoch but step is always smaller than gradient_accumulation_steps
528
- is_last_step_and_steps_less_than_grad_acc
529
- ):
530
- # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
531
- # in accelerate. So, explicitly enable sync gradients to True in that case.
532
- if is_last_step_and_steps_less_than_grad_acc or (
533
- version.parse(accelerate_version) <= version.parse("0.20.3")
534
- ):
535
- self.accelerator.gradient_state._set_sync_gradients(True)
536
-
537
- # Gradient clipping
538
- if args.max_grad_norm is not None and args.max_grad_norm > 0:
539
- # deepspeed does its own clipping
540
-
541
- if self.do_grad_scaling:
542
- # Reduce gradients first for XLA
543
- if is_torch_tpu_available():
544
- gradients = xm._fetch_gradients(self.optimizer)
545
- xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
546
- # AMP: gradients need unscaling
547
- self.scaler.unscale_(self.optimizer)
548
-
549
- if is_sagemaker_mp_enabled() and args.fp16:
550
- self.optimizer.clip_master_grads(args.max_grad_norm)
551
- elif hasattr(self.optimizer, "clip_grad_norm"):
552
- # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
553
- self.optimizer.clip_grad_norm(args.max_grad_norm)
554
- elif hasattr(model, "clip_grad_norm_"):
555
- # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
556
- model.clip_grad_norm_(args.max_grad_norm)
557
- elif self.use_apex:
558
- # Revert to normal clipping otherwise, handling Apex or full precision
559
- nn.utils.clip_grad_norm_(
560
- amp.master_params(self.optimizer),
561
- args.max_grad_norm,
562
- )
563
- else:
564
- self.accelerator.clip_grad_norm_(
565
- model.parameters(),
566
- args.max_grad_norm,
567
- )
568
-
569
- # Optimizer step
570
- optimizer_was_run = True
571
- if is_torch_tpu_available():
572
- if self.do_grad_scaling:
573
- self.scaler.step(self.optimizer)
574
- self.scaler.update()
575
- else:
576
- # tpu-comment: accelerate wrapped optimizers call xm.optimizer_step
577
- self.optimizer.step()
578
- elif self.do_grad_scaling:
579
- scale_before = self.scaler.get_scale()
580
- self.scaler.step(self.optimizer)
581
- self.scaler.update()
582
- scale_after = self.scaler.get_scale()
583
- optimizer_was_run = scale_before <= scale_after
584
- else:
585
- self.optimizer.step()
586
- optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
587
-
588
- if optimizer_was_run:
589
- # Delay optimizer scheduling until metrics are generated
590
- if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
591
- self.lr_scheduler.step()
592
-
593
- model.zero_grad()
594
- self.state.global_step += 1
595
- self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
596
- self.control = self.callback_handler.on_step_end(args, self.state, self.control)
597
-
598
- self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
599
- else:
600
- self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
601
-
602
- if self.control.should_epoch_stop or self.control.should_training_stop:
603
- break
604
- if step < 0:
605
- logger.warning(
606
- "There seems to be not a single sample in your epoch_iterator, stopping training at step"
607
- f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
608
- f" num_steps ({max_steps}) higher than the number of available samples."
609
- )
610
- self.control.should_training_stop = True
611
-
612
- self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
613
- self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
614
-
615
- if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
616
- if is_torch_tpu_available():
617
- # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
618
- xm.master_print(met.metrics_report())
619
- else:
620
- logger.warning(
621
- "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
622
- "configured. Check your training configuration if this is unexpected."
623
- )
624
- if self.control.should_training_stop:
625
- break
626
-
627
- if args.past_index and hasattr(self, "_past"):
628
- # Clean the state at the end of training
629
- delattr(self, "_past")
630
-
631
- logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
632
- if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
633
- # Wait for everyone to get here so we are sure the model has been saved by process 0.
634
- if is_torch_tpu_available():
635
- xm.rendezvous("load_best_model_at_end")
636
- elif args.parallel_mode == ParallelMode.DISTRIBUTED:
637
- dist.barrier()
638
- # elif is_sagemaker_mp_enabled():
639
- # smp.barrier()
640
-
641
- self._load_best_model()
642
-
643
- # add remaining tr_loss
644
- self._total_loss_scalar += tr_loss.item()
645
- train_loss = self._total_loss_scalar / self.state.global_step
646
-
647
- metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
648
- self.store_flos()
649
- metrics["total_flos"] = self.state.total_flos
650
- metrics["train_loss"] = train_loss
651
-
652
- self.is_in_train = False
653
-
654
- self._memory_tracker.stop_and_update_metrics(metrics)
655
-
656
- self.log(metrics)
657
-
658
- run_dir = self._get_output_dir(trial)
659
- checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
660
-
661
- # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
662
- if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
663
- for checkpoint in checkpoints_sorted:
664
- if checkpoint != self.state.best_model_checkpoint:
665
- logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
666
- shutil.rmtree(checkpoint)
667
-
668
- self.control = self.callback_handler.on_train_end(args, self.state, self.control)
669
-
670
- return TrainOutput(self.state.global_step, train_loss, metrics)
671
-
672
- def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
673
- if self.train_dataset is None or not has_length(self.train_dataset):
674
- return None
675
-
676
- if self.args.group_by_modality_length:
677
- lengths = self.train_dataset.modality_lengths
678
- return LengthGroupedSampler(
679
- self.args.train_batch_size,
680
- world_size=self.args.world_size * self.args.gradient_accumulation_steps,
681
- lengths=lengths,
682
- group_by_modality=True,
683
- )
684
- else:
685
- return super()._get_train_sampler()
686
-
687
- def create_optimizer(self):
688
- """
689
- Setup the optimizer.
690
-
691
- We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
692
- Trainer's init through `optimizers`, or subclass and override this method in a subclass.
693
- """
694
- if is_sagemaker_mp_enabled():
695
- return super().create_optimizer()
696
- if self.sharded_ddp == ShardedDDPOption.SIMPLE:
697
- return super().create_optimizer()
698
-
699
- opt_model = self.model
700
-
701
- if self.optimizer is None:
702
- decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
703
- decay_parameters = [name for name in decay_parameters if "bias" not in name]
704
- if self.args.mm_projector_lr is not None:
705
- projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
706
- optimizer_grouped_parameters = [
707
- {
708
- "params": [
709
- p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
710
- ],
711
- "weight_decay": self.args.weight_decay,
712
- },
713
- {
714
- "params": [
715
- p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
716
- ],
717
- "weight_decay": 0.0,
718
- },
719
- {
720
- "params": [
721
- p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
722
- ],
723
- "weight_decay": self.args.weight_decay,
724
- "lr": self.args.mm_projector_lr,
725
- },
726
- {
727
- "params": [
728
- p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
729
- ],
730
- "weight_decay": 0.0,
731
- "lr": self.args.mm_projector_lr,
732
- },
733
- ]
734
- else:
735
- optimizer_grouped_parameters = [
736
- {
737
- "params": [
738
- p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
739
- ],
740
- "weight_decay": self.args.weight_decay,
741
- },
742
- {
743
- "params": [
744
- p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
745
- ],
746
- "weight_decay": 0.0,
747
- },
748
- ]
749
-
750
- optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
751
-
752
- if self.sharded_ddp == ShardedDDPOption.SIMPLE:
753
- self.optimizer = OSS(
754
- params=optimizer_grouped_parameters,
755
- optim=optimizer_cls,
756
- **optimizer_kwargs,
757
- )
758
- else:
759
- self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
760
- if optimizer_cls.__name__ == "Adam8bit":
761
- import bitsandbytes
762
-
763
- manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
764
-
765
- skipped = 0
766
- for module in opt_model.modules():
767
- if isinstance(module, nn.Embedding):
768
- skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
769
- logger.info(f"skipped {module}: {skipped/2**20}M params")
770
- manager.register_module_override(module, "weight", {"optim_bits": 32})
771
- logger.debug(f"bitsandbytes: will optimize {module} in fp32")
772
- logger.info(f"skipped: {skipped/2**20}M params")
773
-
774
- return self.optimizer
775
-
776
- def _save_checkpoint(self, model, trial, metrics=None):
777
- if getattr(self.args, 'tune_mm_mlp_adapter', False):
778
- from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
779
- checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
780
-
781
- run_dir = self._get_output_dir(trial=trial)
782
- output_dir = os.path.join(run_dir, checkpoint_folder)
783
-
784
- # Only save Adapter
785
- keys_to_match = ['mm_projector', 'vision_resampler']
786
- if getattr(self.args, "use_im_start_end", False):
787
- keys_to_match.extend(['embed_tokens', 'embed_in'])
788
-
789
- weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
790
-
791
- if self.args.local_rank == 0 or self.args.local_rank == -1:
792
- self.model.config.save_pretrained(output_dir)
793
- torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
794
- else:
795
- super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
796
-
797
- def _save(self, output_dir: Optional[str] = None, state_dict=None):
798
- if getattr(self.args, 'tune_mm_mlp_adapter', False):
799
- pass
800
- else:
801
- super(LLaVATrainer, self)._save(output_dir, state_dict)
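A toy illustration of what the length-grouped sampling used by LLaVATrainer does: indices are shuffled, cut into world_size * batch_size megabatches, sorted by length within each megabatch, and split into per-rank chunks of roughly equal total length. The lengths below are made up for demonstration, and the import assumes the module is reachable as llava.train.llava_trainer:

import torch

from llava.train.llava_trainer import get_length_grouped_indices

lengths = [5, 120, 7, 64, 300, 9, 18, 250]   # made-up per-sample token counts
indices = get_length_grouped_indices(
    lengths, batch_size=2, world_size=2,
    generator=torch.Generator().manual_seed(0),
)
# A permutation of range(len(lengths)) in which samples that land in the same
# megabatch (here, groups of 4) have similar lengths, which reduces padding waste.
print(indices)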