from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

class LLaMAMesh:
    """Thin wrapper around the Zhengyi/LLaMA-Mesh chat model for text-to-mesh generation."""

    def __init__(self, model_path="Zhengyi/LLaMA-Mesh", device="cuda"):
        self.model_path = model_path
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        # bfloat16 halves memory versus the float32 default; this assumes a
        # GPU with bfloat16 support (Ampere or newer).
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path, device_map=self.device, torch_dtype=torch.bfloat16
        )
        # Stop on either the generic EOS token or Llama-3's end-of-turn token.
        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

    def generate_mesh(self, prompt, temperature=0.9, max_new_tokens=4096):
        # Build the Llama-3 chat prompt; add_generation_prompt appends the
        # assistant header so the model starts a fresh reply.
        input_ids = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(self.model.device)
        # Stream tokens out of the background generation thread; the timeout
        # guards against a stalled generate() call.
        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            eos_token_id=self.terminators,
        )
        if temperature == 0:
            # Fall back to greedy decoding when sampling is disabled.
            generate_kwargs["do_sample"] = False

        # Run generation in a worker thread so this thread can consume the streamer.
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()

        outputs = []
        for text in streamer:
            outputs.append(text)
        t.join()  # reap the worker once the stream is exhausted
        return "".join(outputs)


if __name__ == "__main__":
    llama_mesh = LLaMAMesh()
    prompt = "Create a 3D model of a futuristic chair."
    mesh = llama_mesh.generate_mesh(prompt)
    print(mesh)
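    # LLaMA-Mesh emits meshes as plain OBJ text ("v x y z" vertex lines and
    # "f i j k" face lines), so the reply can be written straight to disk.
    # A minimal sketch: the reply may wrap the OBJ in surrounding prose, and
    # the filename here is illustrative.
    with open("futuristic_chair.obj", "w") as f:
        f.write(mesh)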