---
library_name: transformers
license: cc-by-4.0
---
This model is a transfer-learned version of adept/fuyu-8b, fine-tuned on a simplified scientific-image visual question answering dataset, sugiv/spiqa-simplified-for-fuyu8b-transfer-learning. Most of the model's layers are frozen. Because I am GPU poor, it was trained on only a subset of the simplified dataset, for just two epochs, on a rented 80 GB A100, at a total cost of $10.
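The card states only that most layers were frozen, not which ones stayed trainable. As a minimal sketch of that kind of partial freezing in PyTorch, the hypothetical helper `freeze_all_but` below disables gradients for every parameter except those whose names start with the given prefixes. A toy `nn.Sequential` stands in for the real model; any prefixes you would pass for Fuyu-8b are assumptions, so inspect `model.named_parameters()` before choosing them.

```python
import torch.nn as nn


def freeze_all_but(model: nn.Module, trainable_prefixes: tuple[str, ...]) -> int:
    """Freeze every parameter except those whose name starts with one of
    `trainable_prefixes` (hypothetical helper). Returns the number of
    parameter elements that remain trainable."""
    trainable = 0
    for name, param in model.named_parameters():
        # str.startswith accepts a tuple, so any matching prefix keeps the parameter trainable
        param.requires_grad = name.startswith(trainable_prefixes)
        if param.requires_grad:
            trainable += param.numel()
    return trainable


# Toy stand-in: parameter names are "0.weight", "0.bias", "2.weight", "2.bias"
toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
# Keep only the last Linear layer trainable (4*2 weights + 2 biases)
print(freeze_all_but(toy, ("2.",)))
```

Frozen parameters receive no gradients, so backpropagation and optimizer state are only needed for the trainable subset; passing `filter(lambda p: p.requires_grad, model.parameters())` to the optimizer keeps it that way.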
```python
import requests
import torch
from PIL import Image
from transformers import FuyuForCausalLM, FuyuProcessor

model_path = "sugiv/Fuyu-8b-transfer-learned-spiqa-simplified"
processor = FuyuProcessor.from_pretrained(model_path)
model = FuyuForCausalLM.from_pretrained(model_path, device_map="auto")

text_prompt = "What color is the bus?\n"
url = "https://huggingface.co/adept/fuyu-8b/resolve/main/bus.png"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=text_prompt, images=image, return_tensors="pt")

# Move inputs to the same device as the model (chosen by device_map="auto")
device = next(model.parameters()).device
inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
# If 'image_patches' is a list of tensors, move each tensor to the correct device
if "image_patches" in inputs and isinstance(inputs["image_patches"], list):
    inputs["image_patches"] = [patch.to(device) for patch in inputs["image_patches"]]

# Sample an answer
outputs = model.generate(
    **inputs,
    max_new_tokens=400,
    repetition_penalty=1.2,
    no_repeat_ngram_size=3,
    top_k=40,
    top_p=0.92,
    temperature=0.7,
    do_sample=True,
)

# Decode the output (prompt + answer)
generated_text = processor.decode(outputs[0], skip_special_tokens=True)
# Clean up the generated text: drop formatting markers, keep only the text after "\x04"
generated_text = generated_text.replace("|SPEAKER|", "").replace("|NEWLINE|", " ").strip()
if "\x04" in generated_text:
    generated_text = generated_text.split("\x04")[-1].strip()
print(generated_text)
```
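The clean-up at the end of the example can be factored into a small pure-Python function, which makes it easy to reuse across many generations and to test without loading the model. This is a sketch: `clean_fuyu_output` is a hypothetical name, and it assumes, as the example does, that the decoded text uses `|SPEAKER|` and `|NEWLINE|` markers and that the answer follows the last `"\x04"` character (Fuyu's beginning-of-answer marker, as far as I can tell).

```python
def clean_fuyu_output(text: str) -> str:
    """Strip formatting markers from decoded Fuyu text and keep only the answer.

    Same steps as the inline clean-up: remove |SPEAKER|, turn |NEWLINE| into a
    space, and if "\x04" is present keep only what follows its last occurrence.
    """
    text = text.replace("|SPEAKER|", "").replace("|NEWLINE|", " ").strip()
    if "\x04" in text:
        text = text.split("\x04")[-1].strip()
    return text


# Example decoded string with the prompt, a newline marker and the answer marker
raw = "What color is the bus?|NEWLINE|\x04 The bus is blue.|NEWLINE|"
print(clean_fuyu_output(raw))
```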