initial commit
- .gitignore +1 -0
- README.md +2 -1
- app.py +155 -12
- requirements.txt +7 -1
- styles.css +45 -0
- vid2persona/gen/gemini.py +61 -0
- vid2persona/gen/local_openllm.py +42 -0
- vid2persona/gen/tgi_openllm.py +25 -0
- vid2persona/gen/utils.py +37 -0
- vid2persona/init.py +31 -0
- vid2persona/pipeline/llm.py +75 -0
- vid2persona/pipeline/vlm.py +15 -0
- vid2persona/prompts/llm.toml +21 -0
- vid2persona/prompts/vlm.toml +19 -0
- vid2persona/utils.py +7 -0
.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Vid2persona
-emoji:
+emoji: 🎥🤾
 colorFrom: green
 colorTo: indigo
 sdk: gradio
@@ -8,6 +8,7 @@ sdk_version: 4.21.0
 app_file: app.py
 pinned: false
 license: mit
+short_description: Let's talk to person from video clip
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
CHANGED
@@ -1,18 +1,161 @@
[The previous 18-line demo is removed in full; the only lines recoverable from this view are its imports (os, gradio, google.auth, vertexai) and a closing demo.launch() call. The new app.py follows.]
import gradio as gr

from vid2persona import init
from vid2persona.pipeline import vlm
from vid2persona.pipeline import llm

init.auth_gcp()
init.get_env_vars()
prompt_tpl_path = "vid2persona/prompts"

async def extract_traits(video_path):
    traits = await vlm.get_traits(
        init.gcp_project_id,
        init.gcp_project_location,
        video_path,
        prompt_tpl_path
    )
    if 'characters' in traits:
        traits = traits['characters'][0]

    return [
        traits, [],
        gr.Textbox("", interactive=True),
        gr.Button(interactive=True),
        gr.Button(interactive=True),
        gr.Button(interactive=True)
    ]

async def conversation(
    message: str, messages: list, traits: dict,
    model_id: str, max_input_token_length: int,
    max_new_tokens: int, temperature: float,
    top_p: float, top_k: float, repetition_penalty: float,
):
    messages = messages + [[message, ""]]
    yield [messages, message, gr.Button(interactive=False), gr.Button(interactive=False)]

    async for partial_response in llm.chat(
        message, messages, traits,
        prompt_tpl_path, model_id,
        max_input_token_length, max_new_tokens,
        temperature, top_p, top_k,
        repetition_penalty, hf_token=init.hf_access_token
    ):
        last_message = messages[-1]
        last_message[1] = last_message[1] + partial_response
        messages[-1] = last_message
        yield [messages, "", gr.Button(interactive=False), gr.Button(interactive=False)]

    yield [messages, "", gr.Button(interactive=True), gr.Button(interactive=True)]

async def regen_conversation(
    messages: list, traits: dict,
    model_id: str, max_input_token_length: int,
    max_new_tokens: int, temperature: float,
    top_p: float, top_k: float, repetition_penalty: float,
):
    if len(messages) > 0:
        message = messages[-1][0]
        messages = messages[:-1]
        messages = messages + [[message, ""]]
        yield [messages, "", gr.Button(interactive=False), gr.Button(interactive=False)]

        async for partial_response in llm.chat(
            message, messages, traits,
            prompt_tpl_path, model_id,
            max_input_token_length, max_new_tokens,
            temperature, top_p, top_k,
            repetition_penalty, hf_token=init.hf_access_token
        ):
            last_message = messages[-1]
            last_message[1] = last_message[1] + partial_response
            messages[-1] = last_message
            yield [messages, "", gr.Button(interactive=False), gr.Button(interactive=False)]

        yield [messages, "", gr.Button(interactive=True), gr.Button(interactive=True)]

with gr.Blocks(css="styles.css", theme=gr.themes.Soft()) as demo:
    gr.Markdown("Vid2Persona", elem_classes=["md-center", "h1-font"])
    gr.Markdown("This project breathes life into video characters by using AI to describe their personality and then chat with you as them.")

    with gr.Column(elem_classes=["group"]):
        with gr.Row():
            video = gr.Video(label="upload short video clip")
            traits = gr.Json(label="extracted traits")

        with gr.Row():
            trait_gen = gr.Button("generate traits")

    with gr.Column(elem_classes=["group"]):
        chatbot = gr.Chatbot([], label="chatbot", elem_id="chatbot", elem_classes=["chatbot-no-label"])
        with gr.Row():
            clear = gr.Button("clear conversation", interactive=False)
            regen = gr.Button("regenerate the last", interactive=False)
            stop = gr.Button("stop", interactive=False)
        user_input = gr.Textbox(placeholder="ask anything", interactive=False, elem_classes=["textbox-no-label", "textbox-no-top-bottom-borders"])

    with gr.Accordion("parameters' control pane", open=False):
        model_id = gr.Dropdown(choices=init.ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS, value="HuggingFaceH4/zephyr-7b-beta", label="Model ID")

        with gr.Row():
            max_input_token_length = gr.Slider(minimum=1024, maximum=4096, value=4096, label="max-input-tokens")
            max_new_tokens = gr.Slider(minimum=128, maximum=2048, value=256, label="max-new-tokens")

        with gr.Row():
            temperature = gr.Slider(minimum=0, maximum=2, step=0.1, value=0.6, label="temperature")
            top_p = gr.Slider(minimum=0, maximum=2, step=0.1, value=0.9, label="top-p")
            top_k = gr.Slider(minimum=0, maximum=2, step=0.1, value=50, label="top-k")
            repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.1, value=1.2, label="repetition-penalty")

    with gr.Row():
        gr.Markdown(
            "[![GitHub Repo](https://img.shields.io/badge/GitHub%20Repo-gray?style=for-the-badge&logo=github&link=https://github.com/deep-diver/Vid2Persona)](https://github.com/deep-diver/Vid2Persona) "
            "[![Chansung](https://img.shields.io/badge/Chansung-blue?style=for-the-badge&logo=twitter&link=https://twitter.com/algo_diver)](https://twitter.com/algo_diver) "
            "[![Sayak](https://img.shields.io/badge/Sayak-blue?style=for-the-badge&logo=twitter&link=https://twitter.com/RisingSayak)](https://twitter.com/RisingSayak)",
            elem_id="bottom-md"
        )

    trait_gen.click(
        extract_traits,
        [video],
        [traits, chatbot, user_input, clear, regen, stop]
    )

    conv = user_input.submit(
        conversation,
        [
            user_input, chatbot, traits,
            model_id, max_input_token_length,
            max_new_tokens, temperature,
            top_p, top_k, repetition_penalty,
        ],
        [chatbot, user_input, clear, regen]
    )

    clear.click(
        lambda: [
            gr.Chatbot([]),
            gr.Button(interactive=False),
            gr.Button(interactive=False),
        ],
        None, [chatbot, clear, regen]
    )

    conv_regen = regen.click(
        regen_conversation,
        [
            chatbot, traits,
            model_id, max_input_token_length,
            max_new_tokens, temperature,
            top_p, top_k, repetition_penalty,
        ],
        [chatbot, user_input, clear, regen]
    )

    stop.click(
        None, None, None,
        cancels=[conv, conv_regen]
    )

demo.launch()
requirements.txt
CHANGED
@@ -1,2 +1,8 @@
+toml
 google-auth
-google-cloud-aiplatform
+google-cloud-aiplatform
+transformers
+accelerate
+bitsandbytes
+openai
+gradio
styles.css
ADDED
@@ -0,0 +1,45 @@
.textbox-no-label > label > span {
    display: none;
}

.textbox-no-top-bottom-borders > label > textarea {
    border: none !important;
}

.chatbot-no-label > div > label {
    display: none;
}

.md-center {
    text-align: center;
    display: block;
}

.h1-font > span {
    font-size: xx-large;
    font-weight: bold;
}

.json-holder {
    overflow: scroll;
    height: 500px;
}

.group {
    padding-top: 10px;
    padding-left: 10px;
    padding-right: 10px;
    padding-bottom: 10px;
    border: 2px dashed gray;
    border-radius: 20px;
    box-shadow: 5px 3px 10px 1px rgba(0, 0, 0, 0.4) !important;
}

#bottom-md a {
    float: left;
    margin-right: 10px;
}

#chatbot {
    height: 600px !important;
}
vid2persona/gen/gemini.py
ADDED
@@ -0,0 +1,61 @@
from typing import Union, Iterable

import vertexai
from vertexai.generative_models import (
    GenerativeModel, Part,
    GenerationResponse, GenerationConfig
)

from .utils import parse_first_json_snippet

def _default_gen_config():
    return GenerationConfig(
        max_output_tokens=2048,
        temperature=0.4,
        top_p=1,
        top_k=32
    )

def init_vertexai(project_id: str, location: str) -> None:
    vertexai.init(project=project_id, location=location)

async def _ask_about_video(
    prompt: str="What is in the video?",
    gen_config: dict=_default_gen_config(),
    model_name: str="gemini-1.0-pro-vision",
    gcs: str=None,
    base64_content: bytes=None
) -> Union[GenerationResponse, Iterable[GenerationResponse]]:
    if gcs is None and base64_content is None:
        raise ValueError("Either a GCS bucket path or base64_encoded string of the video must be provided")

    if gcs is not None and base64_content is not None:
        raise ValueError("Only one of gcs or base64_encoded must be provided")

    if gcs is not None:
        video = Part.from_uri(gcs, mime_type="video/mp4")
    else:
        video = Part.from_data(data=base64_content, mime_type="video/mp4")

    model = GenerativeModel(model_name)
    return await model.generate_content_async(
        [video, prompt],
        generation_config=gen_config
    )

async def ask_about_video(prompt: str, video_clip: bytes, retry_num: int=10):
    json_content = None
    cur_retry = 0

    while json_content is None and cur_retry < retry_num:
        try:
            resps = await _ask_about_video(
                prompt=prompt, base64_content=video_clip
            )

            json_content = parse_first_json_snippet(resps.text)
        except Exception as e:
            cur_retry = cur_retry + 1
            print(f"......retry {e}")

    return json_content
vid2persona/gen/local_openllm.py
ADDED
@@ -0,0 +1,42 @@
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TextIteratorStreamer

model = None
tokenizer = None

def send_message(
    messages: list,
    model_id: str,
    max_input_token_length: int,
    parameters: dict
):
    global tokenizer
    global model

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.use_default_system_prompt = False
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
    if input_ids.shape[1] > max_input_token_length:
        input_ids = input_ids[:, -max_input_token_length:]
        print(f"Trimmed input from conversation as it was longer than {max_input_token_length} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        do_sample=True,
        num_beams=1,
        **parameters
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    for text in streamer:
        yield text.replace("<|assistant|>", "")
vid2persona/gen/tgi_openllm.py
ADDED
@@ -0,0 +1,25 @@
from openai import AsyncOpenAI

async def send_messages(
    messages: list,
    model_id: str,
    hf_token: str,
    parameters: dict
):
    parameters.pop('repetition_penalty')
    parameters['max_tokens'] = parameters.pop('max_new_tokens')
    parameters['logprobs'] = True
    parameters['top_logprobs'] = parameters.pop('top_k')
    # parameters['presence_penalty'] = parameters.pop('repetition_penalty')

    client = AsyncOpenAI(
        base_url=f"https://api-inference.huggingface.co/models/{model_id}/v1",
        api_key=hf_token,
    )

    responses = await client.chat.completions.create(
        model="tgi", messages=messages, stream=True, **parameters
    )

    async for response in responses:
        yield response.choices[0].delta.content
vid2persona/gen/utils.py
ADDED
@@ -0,0 +1,37 @@
import json

def find_json_snippet(raw_snippet):
    json_parsed_string = None

    json_start_index = raw_snippet.find('{')
    json_end_index = raw_snippet.rfind('}')

    if json_start_index >= 0 and json_end_index >= 0:
        json_snippet = raw_snippet[json_start_index:json_end_index+1]
        try:
            json_parsed_string = json.loads(json_snippet, strict=False)
        except:
            raise ValueError('......failed to parse string into JSON format')
    else:
        raise ValueError('......No JSON code snippet found in string.')

    return json_parsed_string

def parse_first_json_snippet(snippet):
    json_parsed_string = None

    if isinstance(snippet, list):
        for snippet_piece in snippet:
            try:
                json_parsed_string = find_json_snippet(snippet_piece)
                return json_parsed_string
            except:
                pass
    else:
        try:
            json_parsed_string = find_json_snippet(snippet)
        except Exception as e:
            print(e)
            raise ValueError()

    return json_parsed_string
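For orientation (not part of the commit): these helpers pull the first JSON object out of whatever text the model returns. A hypothetical call could look like this; the sample string is invented.

from vid2persona.gen.utils import parse_first_json_snippet

# Hypothetical model output: prose wrapped around a JSON object.
raw = 'Sure, here is the profile: {"characters": [{"name": "Mina"}]} Hope this helps!'

traits = parse_first_json_snippet(raw)
print(traits["characters"][0]["name"])  # -> Mina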
vid2persona/init.py
ADDED
@@ -0,0 +1,31 @@
import os
import google.auth

# https://huggingface.co/blog/inference-pro
ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "HuggingFaceH4/zephyr-7b-beta",
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Llama-2-70b-chat-hf",
    "openchat/openchat-3.5-0106"
]

def auth_gcp():
    gcp_credentials = os.getenv("GCP_CREDENTIALS")
    with open("gcp-credentials.json", "w") as f:
        f.write(gcp_credentials)

    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = './gcp-credentials.json'
    google.auth.default()

def get_env_vars():
    global gcp_project_id, gcp_project_location
    global hf_access_token

    gcp_project_id = os.getenv("GCP_PROJECT_ID")
    gcp_project_location = os.getenv("GCP_PROJECT_LOCATION")
    hf_access_token = os.getenv("HF_TOKEN", None)
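For reference (not part of the commit): this module reads GCP_CREDENTIALS, GCP_PROJECT_ID, GCP_PROJECT_LOCATION, and optionally HF_TOKEN from the environment (Space secrets in practice). A minimal local-run sketch with placeholder values might look like:

import os
from vid2persona import init

# Placeholder values for illustration only; the file name and project details are hypothetical.
os.environ["GCP_CREDENTIALS"] = open("my-service-account.json").read()
os.environ["GCP_PROJECT_ID"] = "my-gcp-project"
os.environ["GCP_PROJECT_LOCATION"] = "us-central1"
os.environ["HF_TOKEN"] = "hf_..."  # optional; leave unset to fall back to the local transformers path

init.auth_gcp()
init.get_env_vars()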
vid2persona/pipeline/llm.py
ADDED
@@ -0,0 +1,75 @@
import toml
from string import Template
from transformers import AutoTokenizer

from vid2persona.gen import tgi_openllm
from vid2persona.gen import local_openllm

tokenizer = None

def _get_system_prompt(
    personality_json_dict: dict,
    prompt_tpl_path: str
) -> str:
    """Assumes a single character is passed."""
    prompt_tpl_path = f"{prompt_tpl_path}/llm.toml"
    system_prompt = Template(toml.load(prompt_tpl_path)['conversation']['system'])

    name = personality_json_dict["name"]
    physcial_description = personality_json_dict["physicalDescription"]
    personality_traits = [str(trait) for trait in personality_json_dict["personalityTraits"]]
    likes = [str(like) for like in personality_json_dict["likes"]]
    dislikes = [str(dislike) for dislike in personality_json_dict["dislikes"]]
    background = [str(info) for info in personality_json_dict["background"]]
    goals = [str(goal) for goal in personality_json_dict["goals"]]
    relationships = [str(relationship) for relationship in personality_json_dict["relationships"]]

    system_prompt = system_prompt.substitute(
        name=name,
        physcial_description=physcial_description,
        personality_traits=', '.join(personality_traits),
        likes=', '.join(likes),
        background=', '.join(background),
        goals=', '.join(goals),
        relationships=', '.join(relationships)
    )

    return system_prompt

async def chat(
    message: str,
    chat_history: list[tuple[str, str]],
    personality_json_dict: dict,
    prompt_tpl_path: str,

    model_id: str,
    max_input_token_length: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,

    hf_token: str,
):
    messages = []
    system_prompt = _get_system_prompt(personality_json_dict, prompt_tpl_path)
    messages.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        messages.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    messages.append({"role": "user", "content": message})

    parameters = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty
    }

    if hf_token is None:
        for response in local_openllm.send_message(messages, model_id, max_input_token_length, parameters):
            yield response
    else:
        async for response in tgi_openllm.send_messages(messages, model_id, hf_token, parameters):
            yield response
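As a rough usage sketch (not part of the commit, mirroring how app.py drives this module): the traits dict and the token below are placeholders.

import asyncio
from vid2persona.pipeline import llm

# Invented character profile matching the keys _get_system_prompt expects.
traits = {
    "name": "Mina", "physicalDescription": "tall, silver hair",
    "personalityTraits": ["curious"], "likes": ["rain"], "dislikes": ["crowds"],
    "background": ["street photographer"], "goals": ["publish a photo book"],
    "relationships": ["Jun, an old friend"],
}

async def main():
    async for chunk in llm.chat(
        "Hi! What do you do?", [], traits,
        "vid2persona/prompts", "HuggingFaceH4/zephyr-7b-beta",
        4096, 256, 0.6, 0.9, 50, 1.2,
        hf_token="hf_...",  # placeholder; pass None to use the local transformers path instead
    ):
        print(chunk, end="")

asyncio.run(main())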
vid2persona/pipeline/vlm.py
ADDED
@@ -0,0 +1,15 @@
import toml
from vid2persona.gen.gemini import init_vertexai, ask_about_video
from vid2persona.utils import get_base64_content

async def get_traits(
    gcp_project_id: str, gcp_project_location: str,
    video_clip_path: str, prompt_tpl_path: str,
):
    prompt_tpl_path = f"{prompt_tpl_path}/vlm.toml"
    prompt = toml.load(prompt_tpl_path)['extraction']['traits']
    init_vertexai(gcp_project_id, gcp_project_location)
    video_clip = get_base64_content(video_clip_path)

    response = await ask_about_video(prompt=prompt, video_clip=video_clip)
    return response
vid2persona/prompts/llm.toml
ADDED
@@ -0,0 +1,21 @@
[conversation]
system = """
You are acting as the character detailed below. The details of the character contain different traits, starting from its inherent personality traits to its background.

* Name: $name
* Physical description: $physcial_description
* Personality traits: $personality_traits
* Likes: $likes
* Background: $background
* Goals: $goals
* Relationships: $relationships

While generating your responses, you must consider the information above.
"""

examples = [
    ["Hello there! How are you doing?"],
    ["Recite me a short poem."],
    ["Explain the plot of Cinderella in a sentence."],
    ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
]
vid2persona/prompts/vlm.toml
ADDED
@@ -0,0 +1,19 @@
[extraction]
traits = """
Carefully analyze the provided video clip to identify and extract detailed information about the main character(s) featured. Pay attention to visual elements, spoken dialogue, character interactions, and any narrative cues that reveal aspects of the character's personality, physical appearance, behaviors, and background.

Your task is to construct a rich, imaginative character profile based on your observations, and where explicit information is not available, you are encouraged to use your creativity to fill in the gaps. The goal is to create a vivid, believable character profile that can be used to simulate conversation with a language model as if it were the character itself.

Format the extracted data as a structured JSON object containing the following fields for each main character:

name(text): The character's name as mentioned or inferred in the video. If not provided, create a suitable name that matches the character's traits and context.
physicalDescription(text): Describe the character's appearance, including hair color, eye color, height, and distinctive features. Use imaginative details if necessary to provide a complete picture.
personalityTraits(list): List descriptive adjectives or phrases that capture the character's personality, based on their actions and dialogue. Invent traits as needed to ensure a well-rounded personality.
likes(list): Specify things, activities, or concepts the character enjoys or values, deduced or imagined from their behavior and interactions.
dislikes(list): Note what the character appears to dislike or avoid, filling in creatively where direct evidence is not available.
background(list): Provide background information such as occupation, family ties, or significant life events, inferring where possible or inventing details to add depth to the character's story.
goals(list): Describe the character's apparent motivations and objectives, whether explicitly stated or implied. Where not directly observable, construct plausible goals that align with the character's portrayed or inferred traits.
relationships(list): Detail the character's relationships with other characters, including the nature of each relationship and the names of other characters involved. Use creative license to elaborate on these relationships if the video provides limited information.

Ensure the JSON object is well-structured and comprehensive, ready for integration with a language model to facilitate engaging conversations as if with the character itself. For multiple main characters, provide a distinct profile for each within the same JSON object.
"""
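For illustration only (not part of the commit): a response following this prompt might be shaped like the dict below, which app.py unwraps via traits['characters'][0]. Every value here is invented.

# Invented example of the profile structure the prompt asks the model to return.
example_traits = {
    "characters": [
        {
            "name": "Mina",
            "physicalDescription": "tall, silver hair, round glasses",
            "personalityTraits": ["curious", "soft-spoken"],
            "likes": ["rainy streets", "film cameras"],
            "dislikes": ["crowds"],
            "background": ["works as a street photographer"],
            "goals": ["publish a photo book"],
            "relationships": ["Jun, a childhood friend who appears in her photos"],
        }
    ]
}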
vid2persona/utils.py
ADDED
@@ -0,0 +1,7 @@
import base64

def get_base64_content(file_path, decode=True):
    with open(file_path, 'rb') as f:
        data = f.read()

    return base64.b64decode(base64.b64encode(data)) if decode else base64.b64encode(data)
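One observation (not part of the commit): with decode=True the b64encode/b64decode round trip simply returns the original file bytes, which is what Part.from_data in gemini.py consumes; only decode=False yields actual base64-encoded bytes. An illustrative call, with a hypothetical file name:

raw = get_base64_content("clip.mp4")                     # raw video bytes (the round trip is a no-op)
encoded = get_base64_content("clip.mp4", decode=False)   # base64-encoded bytes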