amaye15 committed on
Commit 7589a7b
1 Parent(s): 6fda347

Update README.md

Files changed (1)
  1. README.md +15 -22
README.md CHANGED
@@ -20,30 +20,23 @@ DaViT (Dual-Attention Vision Transformer) is designed to handle image classifica
 Here is an example of how to use the DaViT model for image classification:
 
 ```python
- import torch
- from transformers import AutoModel, AutoConfig
- # Load the configuration and model
- config = AutoConfig.from_pretrained("your-username/DaViT")
- model = AutoModel.from_pretrained("your-username/DaViT")
- # Generate a random sample input tensor with shape (batch_size, channels, height, width)
- batch_size = 2
- channels = 3
- height = 224
- width = 224
- sample_input = torch.randn(batch_size, channels, height, width)
- # Pass the sample input through the model
- output = model(sample_input)
- # Print the output shape
- print(f"Output shape: {output.shape}")
- ```
-
- ## Files
-
- - `configuration_davit.py`: Contains the `DaViTConfig` class.
- - `modeling_davit.py`: Contains the `DaViTModel` class.
- - `test_davit_model.py`: Script to test the model.
- - `config.json`: Configuration file for the model.
- - `model.safetensors`: Pretrained weights of the DaViT model.
+ # Load model directly
+ from transformers import AutoModel, AutoProcessor
+ from PIL import Image
+ import requests
+ import os
+
+ model = AutoModel.from_pretrained("amaye15/DaViT-Florence-2-large-ft", trust_remote_code=True, cache_dir=os.getcwd())
+ processor = AutoProcessor.from_pretrained("amaye15/DaViT-Florence-2-large-ft", trust_remote_code=True, cache_dir=os.getcwd())
+
+ # Example image and the OCR task prompt
+ prompt = "<OCR>"
+ url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
+ image = Image.open(requests.get(url, stream=True).raw)
+
+ # Preprocess the inputs and run a forward pass
+ inputs = processor(text=prompt, images=image, return_tensors="pt")
+ model(inputs["pixel_values"])
+ ```
 
 ## Credits
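
As a quick check of the updated snippet, here is a minimal sketch (not part of the commit) for capturing and inspecting what the forward pass returns; the exact return type depends on the repository's remote code, so the tensor check below is an assumption:

```python
# Minimal sketch, not part of the commit: capture and inspect the features
# produced by the DaViT encoder, reusing `model` and `inputs` from the
# snippet above. The return type depends on the repo's remote code, so the
# isinstance check below is an assumption.
import torch

with torch.no_grad():
    outputs = model(inputs["pixel_values"])

print(type(outputs))
if isinstance(outputs, torch.Tensor):
    print("Feature shape:", outputs.shape)
```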