# timeforge/llama_mesh.py
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


class LLaMAMesh:
    def __init__(self, model_path="Zhengyi/LLaMA-Mesh", device="cuda"):
        self.model_path = model_path
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, device_map=self.device)
        # Stop on either the regular EOS token or the Llama-3 end-of-turn token.
        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
    def generate_mesh(self, prompt, temperature=0.9, max_new_tokens=4096):
        # Build a single-turn chat prompt; add_generation_prompt opens the
        # assistant turn so the model replies directly instead of continuing
        # the user message.
        input_ids = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(self.model.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            eos_token_id=self.terminators,
        )
        if temperature == 0:
            # Sampling at temperature 0 is undefined; fall back to greedy decoding.
            generate_kwargs["do_sample"] = False
        # Run generation in a background thread and drain the streamer here,
        # so tokens can be consumed as they are produced.
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        outputs = []
        for text in streamer:
            outputs.append(text)
        t.join()
        return "".join(outputs)
if __name__ == "__main__":
    llama_mesh = LLaMAMesh()
    prompt = "Create a 3D model of a futuristic chair."
    mesh = llama_mesh.generate_mesh(prompt)
    print(mesh)
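    # LLaMA-Mesh emits meshes as plain-text OBJ data ("v x y z" vertex and
    # "f i j k" face lines), so the filtered reply can be written straight
    # to a .obj file for viewing. The filename here is illustrative.
    with open("futuristic_chair.obj", "w") as f:
        f.write(extract_obj(mesh))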