Spaces:
Sleeping
Sleeping
Raniahossam33
committed on
Commit
•
82b369d
1
Parent(s):
23eab40
Upload 5 files
Browse files- app.py +19 -0
- counting_model.pt +3 -0
- fish_feeding.py +164 -0
- length_model.pt +3 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr
import numpy as np

from fish_feeding import FishFeeding

# Load the length / depth / counting models once at startup so each
# request only pays for inference, not model loading.
model = FishFeeding()
model.load_models()


def fish_feeding(images):
    """Predict total feed amount and daily feeding frequency from fish image(s).

    Note: ``gr.Image(type='numpy')`` delivers a SINGLE HxWxC uint8 array,
    not a list. The original code iterated over the array itself (i.e. over
    its pixel rows) and fed those rows to the model, which silently produced
    meaningless input — so a lone image is wrapped in a list here.

    Returns:
        dict with ``total_feed`` (grams) and ``times`` (feedings per day).
    """
    if isinstance(images, np.ndarray):
        images = [images]
    # Normalize every frame to a uint8 ndarray without mutating the input.
    images = [np.asarray(img, dtype=np.uint8) for img in images]

    total_feed, times = model.final_fish_feed(images)
    return {"total_feed": total_feed, "times": times}


inputs = gr.Image(type='numpy', label="Upload fish images")
outputs = gr.JSON(label="Fish Feeding Results")

app = gr.Interface(fish_feeding, inputs=inputs, outputs=outputs, title="Fish Feeding Predictor")
app.launch()
counting_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c9dc64af07f0ed80fe8f9c9b9d286353136a649accf794c924af6b8832ae07a7
|
3 |
+
size 6238297
|
fish_feeding.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
from transformers import pipeline
|
5 |
+
from ultralytics import YOLO
|
6 |
+
|
7 |
+
class FishFeeding:
    """Estimate fish feed quantity and daily feeding frequency from images.

    Pipeline:
      1. A YOLO pose model (``length_model.pt``) locates fish keypoints
         (head, back, belly, tail).
      2. A monocular depth-estimation pipeline converts the head-tail pixel
         distance into an approximate real-world length.
      3. A YOLO detection model (``counting_model.pt``) counts fish.
      4. The average weight (length-weight relation) selects a row of a
         fixed feed table; total feed = per-fish feed * fish count.
    """

    def __init__(self, focal_length: float = 27.4) -> None:
        # Prefer GPU when available; the HF pipeline accepts a device string.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Lengths measured during the most recent final_fish_feed() call.
        self.collected_lengths: list = []
        # Camera focal length used to scale back-projected distances.
        self.focal_length = focal_length
        self.final_weight = None
        self.length_model_name = "length_model.pt"
        self.depth_model_name = "vinvino02/glpn-nyu"
        self.counting_model_name = "counting_model.pt"

    def load_models(self) -> None:
        """Load the keypoint, depth and counting models (heavy; call once)."""
        self.fish_keypoints_model = YOLO(self.length_model_name)
        self.depth_model = pipeline(task="depth-estimation", model=self.depth_model_name, device=self.device)
        self.fish_detection_model = YOLO(self.counting_model_name)

    def predict_fish_length(self, frame):
        """Return the estimated length of the fish in ``frame``.

        Args:
            frame: HxWxC uint8 image array.

        Raises:
            ValueError: if no fish keypoints are detected in the frame.
        """
        image_obj = Image.fromarray(frame)
        image_obj = image_obj.resize((640, 640))  # depth model input size
        depth = self.depth_model(image_obj)["predicted_depth"]
        depth = np.array(depth).squeeze()

        # BUG FIX: keypoints come from the pose model (length_model.pt);
        # the original queried the detection/counting model here, whose
        # results carry no keypoints at all.
        results = self.fish_keypoints_model(frame)[0]
        if results.keypoints is None:  # identity check, not `== None`
            raise ValueError("No fish detected in the image")
        keypoints = results.keypoints.xyn[0].detach().cpu().numpy()
        # Keypoint order per the trained model: head, back, belly, tail.
        # back/belly are only needed for the (removed) girth estimate.
        head = keypoints[0]
        tail = keypoints[3]

        # numpy shape is (rows, cols) == (height, width); the original
        # unpacked these the other way round, skewing the projected points
        # on non-square depth maps.
        depth_h, depth_w = depth.shape[:2]

        # xyn coordinates are normalized [0, 1]; scale into depth-map pixels.
        head_x = int(head[0] * depth_w)
        head_y = int(head[1] * depth_h)
        tail_x = int(tail[0] * depth_w)
        tail_y = int(tail[1] * depth_h)

        head_depth = depth[head_y, head_x]
        tail_depth = depth[tail_y, tail_x]

        # Pinhole-camera style back-projection: weight pixel coordinates by
        # their depth, then Euclidean distance scaled by the focal length.
        fish_length = (
            np.sqrt(
                (head_x * head_depth - tail_x * tail_depth) ** 2
                + (head_y * head_depth - tail_y * tail_depth) ** 2
            )
            / self.focal_length
        )
        return fish_length

    def get_average_weight(self):
        """Average fish weight via the length-weight relation W = 0.014 * L^3.02."""
        if not self.collected_lengths:
            return 0
        length_average = sum(self.collected_lengths) / len(self.collected_lengths)
        return 0.014 * length_average ** 3.02

    def fish_counting(self, images):
        """Return the maximum number of fish detected in any single image."""
        counting_output = 0
        for im0 in images:
            # YOLO returns one Results object per input image; count the
            # detections (boxes) inside it. The original took len() of the
            # results LIST, which is always 1 for a single image.
            result = self.fish_detection_model(im0)[0]
            counting_output = max(counting_output, len(result.boxes))
        return counting_output

    def final_fish_feed(self, images: list):
        """Compute feeding plan for a batch of images.

        Args:
            images: list of HxWxC uint8 image arrays.

        Returns:
            (total_feed, times): grams of feed for the whole school, and
            number of feedings per day.
        """
        # Reset per call so repeated invocations don't mix measurements
        # from earlier batches (the original accumulated forever).
        self.collected_lengths = []
        for image in images:
            try:
                self.collected_lengths.append(self.predict_fish_length(image))
            except ValueError:
                continue  # skip frames where no fish was detected

        average_weight = self.get_average_weight()
        # Feed table: (upper weight bound in grams, feed per fish, feedings/day).
        feed_table = [
            (50, 3.3, 2),
            (100, 4.8, 2),
            (250, 5.8, 2),
            (500, 8.4, 2),
            (750, 9.4, 1),
            (1000, 10.5, 1),
            (1500, 11.0, 1),
        ]
        for bound, feed, times in feed_table:
            if average_weight <= bound:
                break
        else:
            # Heavier than every table row.
            feed, times = 12.0, 1

        fish_count = self.fish_counting(images)
        total_feed = feed * fish_count
        return total_feed, times
|
126 |
+
|
127 |
+
|
128 |
+
# if __name__ == "__main__":
|
129 |
+
# to_collect = 6
|
130 |
+
# collected = []
|
131 |
+
# video_path = "object_counting.mp4"
|
132 |
+
# cap = cv2.VideoCapture(video_path)
|
133 |
+
|
134 |
+
# fish_feeding = FishFeeding()
|
135 |
+
# fish_feeding.load_models()
|
136 |
+
|
137 |
+
# d = {"images": []}
|
138 |
+
|
139 |
+
# while True:
|
140 |
+
# ret, frame = cap.read()
|
141 |
+
# if not ret:
|
142 |
+
# break
|
143 |
+
|
144 |
+
# if len(collected) == to_collect:
|
145 |
+
# total_feed, times = fish_feeding.final_fish_feed(collected)
|
146 |
+
# print(f"Total feed: {total_feed}, Feed times: {times}")
|
147 |
+
# collected = []
|
148 |
+
|
149 |
+
# break
|
150 |
+
|
151 |
+
# collected.append(frame)
|
152 |
+
# d["images"].append(frame.tolist())
|
153 |
+
|
154 |
+
# if cv2.waitKey(1) & 0xFF == ord("q"):
|
155 |
+
# break
|
156 |
+
|
157 |
+
# cap.release()
|
158 |
+
# cv2.destroyAllWindows()
|
159 |
+
|
160 |
+
# # save d to json file
|
161 |
+
# import json
|
162 |
+
# with open("data.json", "w") as f:
|
163 |
+
# json.dump(d, f)
|
164 |
+
|
length_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:208dbde300963cfe83f5f4a9fdfd7a91e640d0410d01ce4e70c96d440cbc03d1
|
3 |
+
size 6403287
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pillow
|
2 |
+
ultralytics
|
3 |
+
transformers
|
4 |
+
fastapi
|
5 |
+
dill==0.3.8
|
6 |
+
gradio
|
7 |
+
torch
|