dragonjump committed
Commit 0410591 · 1 Parent(s): 35f3879
Files changed (1):
  1. main.py +29 -8
main.py CHANGED
@@ -4,6 +4,8 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
 )
+from transformers import Qwen2_5_VLForConditionalGeneration
+
 from qwen_vl_utils import process_vision_info
 import torch
 import logging
@@ -13,22 +15,41 @@ logging.basicConfig(level=logging.INFO)
 app = FastAPI()
 
 # Qwen2.5-VL Model Setup
-qwen_checkpoint = "Qwen/Qwen2.5-VL-7B-Instruct"
-min_pixels = 256 * 28 * 28
-max_pixels = 1280 * 28 * 28
+# qwen_checkpoint = "Qwen/Qwen2.5-VL-7B-Instruct"
+# min_pixels = 256 * 28 * 28
+# max_pixels = 1280 * 28 * 28
+
+# processor = AutoProcessor.from_pretrained(
+#     qwen_checkpoint,
+#     min_pixels=min_pixels,
+#     max_pixels=max_pixels,
+# )
+
+# qwen_model = AutoModelForCausalLM.from_pretrained(
+#     qwen_checkpoint,
+#     torch_dtype=torch.bfloat16,
+#     device_map="auto",
+# )
+
 
+
+checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
+min_pixels = 256*28*28
+max_pixels = 1280*28*28
 processor = AutoProcessor.from_pretrained(
-    qwen_checkpoint,
+    checkpoint,
     min_pixels=min_pixels,
-    max_pixels=max_pixels,
+    max_pixels=max_pixels
 )
-
-qwen_model = AutoModelForCausalLM.from_pretrained(
-    qwen_checkpoint,
+qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    checkpoint,
     torch_dtype=torch.bfloat16,
     device_map="auto",
+    # attn_implementation="flash_attention_2",
 )
 
+
+
 # LLaMA Model Setup
 llama_model_name = "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2"
 llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_name)
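
Qwen2_5_VLForConditionalGeneration is the class the Qwen2.5-VL checkpoints map to (AutoModelForCausalLM does not resolve the vision-language architecture), and the min_pixels/max_pixels bounds cap each image at roughly 256-1280 visual tokens, one per 28x28 pixel region after resizing (256*28*28 = 200,704 up to 1280*28*28 = 1,003,520 pixels). A minimal inference sketch follows, based on the usage pattern from the Qwen2.5-VL model card and assuming the processor and qwen_model defined above; the image URL and prompt are placeholders, not part of the commit:

# Hypothetical usage sketch; image URL and prompt are placeholders.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://example.com/demo.jpg"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# Render the chat template and collect the image/video tensors.
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)

inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to(qwen_model.device)

# Generate, then strip the prompt tokens before decoding.
generated_ids = qwen_model.generate(**inputs, max_new_tokens=128)
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
print(processor.batch_decode(trimmed, skip_special_tokens=True)[0])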