Raniahossam33 commited on
Commit
82b369d
1 Parent(s): 23eab40

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +19 -0
  2. counting_model.pt +3 -0
  3. fish_feeding.py +164 -0
  4. length_model.pt +3 -0
  5. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from fish_feeding import FishFeeding
4
+
5
+ model = FishFeeding()
6
+ model.load_models()
7
+
8
+ def fish_feeding(images):
9
+ for i, img in enumerate(images):
10
+ images[i] = np.array(img, dtype=np.uint8)
11
+
12
+ total_feed, times = model.final_fish_feed(images)
13
+ return {"total_feed": total_feed, "times": times}
14
+
15
+ inputs = gr.Image(type='numpy', label="Upload fish images")
16
+ outputs = gr.JSON(label="Fish Feeding Results")
17
+
18
+ app = gr.Interface(fish_feeding, inputs=inputs, outputs=outputs, title="Fish Feeding Predictor")
19
+ app.launch()
counting_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9dc64af07f0ed80fe8f9c9b9d286353136a649accf794c924af6b8832ae07a7
3
+ size 6238297
fish_feeding.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from transformers import pipeline
5
+ from ultralytics import YOLO
6
+
7
+ class FishFeeding:
8
+
9
+ def __init__(self, focal_length: float = 27.4) -> None:
10
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ self.collected_lengths = []
12
+ self.focal_length = focal_length
13
+ self.final_weight = None
14
+ self.length_model_name = "length_model.pt"
15
+ self.depth_model_name = "vinvino02/glpn-nyu"
16
+ self.counting_model_name = "counting_model.pt"
17
+
18
+ def load_models(self) -> None:
19
+ self.fish_keypoints_model = YOLO(self.length_model_name)
20
+ self.depth_model = pipeline(task="depth-estimation", model=self.depth_model_name, device=self.device)
21
+ self.fish_detection_model = YOLO(self.counting_model_name)
22
+
23
+ def predict_fish_length(self, frame):
24
+ image_obj = Image.fromarray(frame)
25
+ image_obj = image_obj.resize((640, 640)) # Adjust size as per requirement
26
+ depth = self.depth_model(image_obj)
27
+ depth = depth["predicted_depth"]
28
+ depth = np.array(depth).squeeze()
29
+
30
+ results = self.fish_detection_model(frame)[0]
31
+ if (results.keypoints == None):
32
+ raise ValueError("No fish detected in the image")
33
+ keypoints = results.keypoints.xyn[0].detach().cpu().numpy()
34
+ head = keypoints[0]
35
+ back = keypoints[1]
36
+ belly = keypoints[2]
37
+ tail = keypoints[3]
38
+
39
+ depth_w, depth_h = depth.shape[:2]
40
+
41
+ head_x = int(head[0] * depth_w)
42
+ head_y = int(head[1] * depth_h)
43
+ tail_x = int(tail[0] * depth_w)
44
+ tail_y = int(tail[1] * depth_h)
45
+
46
+ back_x = int(back[0] * depth_w)
47
+ back_y = int(back[1] * depth_h)
48
+ belly_x = int(belly[0] * depth_w)
49
+ belly_y = int(belly[1] * depth_h)
50
+
51
+ head_depth = depth[head_y, head_x]
52
+ tail_depth = depth[tail_y, tail_x]
53
+
54
+ fish_length = (
55
+ np.sqrt(
56
+ (head_x * head_depth - tail_x * tail_depth) ** 2
57
+ + (head_y * head_depth - tail_y * tail_depth) ** 2
58
+ )
59
+ / self.focal_length
60
+ )
61
+ # girth = (
62
+ # np.sqrt(
63
+ # (back_x * head_depth - belly_x * tail_depth) ** 2
64
+ # + (back_y * head_depth - belly_y * tail_depth) ** 2
65
+ # )
66
+ # / self.focal_length
67
+ # )
68
+ return fish_length
69
+
70
+ # def videocapture(self):
71
+ # cap = cv2.VideoCapture(self.video_path)
72
+ # assert cap.isOpened(), "Error reading video file"
73
+ # while True:
74
+ # ret, frame = cap.read()
75
+ # if not ret:
76
+ # break
77
+ # output = self.predict_fish_length(frame)
78
+ # self.collected_lengths.append(output)
79
+ # cap.release()
80
+ # return self.collected_lengths
81
+
82
+ def get_average_weight(self):
83
+ if not self.collected_lengths:
84
+ return 0
85
+ length_average = sum(self.collected_lengths) / len(self.collected_lengths)
86
+ final_weight = 0.014 * length_average ** 3.02
87
+ return final_weight
88
+
89
+ def fish_counting(self, images):
90
+ counting_output = 0
91
+ for im0 in images:
92
+ tracks = self.fish_detection_model(im0)
93
+ counting_output = max(counting_output, len(tracks))
94
+
95
+ return counting_output
96
+
97
+ def final_fish_feed(self, images: list):
98
+ for image in images:
99
+ try:
100
+ output = self.predict_fish_length(image)
101
+ except ValueError:
102
+ continue
103
+ self.collected_lengths.append(output)
104
+
105
+ average_weight = self.get_average_weight()
106
+ if 0 <= average_weight <= 50:
107
+ feed, times = 3.3, 2
108
+ elif 50 < average_weight <= 100:
109
+ feed, times = 4.8, 2
110
+ elif 100 < average_weight <= 250:
111
+ feed, times = 5.8, 2
112
+ elif 250 < average_weight <= 500:
113
+ feed, times = 8.4, 2
114
+ elif 500 < average_weight <= 750:
115
+ feed, times = 9.4, 1
116
+ elif 750 < average_weight <= 1000:
117
+ feed, times = 10.5, 1
118
+ elif 1000 < average_weight <= 1500:
119
+ feed, times = 11.0, 1
120
+ else:
121
+ feed, times = 12.0, 1
122
+
123
+ fish_count = self.fish_counting(images)
124
+ total_feed = feed * fish_count
125
+ return total_feed, times
126
+
127
+
128
+ # if __name__ == "__main__":
129
+ # to_collect = 6
130
+ # collected = []
131
+ # video_path = "object_counting.mp4"
132
+ # cap = cv2.VideoCapture(video_path)
133
+
134
+ # fish_feeding = FishFeeding()
135
+ # fish_feeding.load_models()
136
+
137
+ # d = {"images": []}
138
+
139
+ # while True:
140
+ # ret, frame = cap.read()
141
+ # if not ret:
142
+ # break
143
+
144
+ # if len(collected) == to_collect:
145
+ # total_feed, times = fish_feeding.final_fish_feed(collected)
146
+ # print(f"Total feed: {total_feed}, Feed times: {times}")
147
+ # collected = []
148
+
149
+ # break
150
+
151
+ # collected.append(frame)
152
+ # d["images"].append(frame.tolist())
153
+
154
+ # if cv2.waitKey(1) & 0xFF == ord("q"):
155
+ # break
156
+
157
+ # cap.release()
158
+ # cv2.destroyAllWindows()
159
+
160
+ # # save d to json file
161
+ # import json
162
+ # with open("data.json", "w") as f:
163
+ # json.dump(d, f)
164
+
length_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:208dbde300963cfe83f5f4a9fdfd7a91e640d0410d01ce4e70c96d440cbc03d1
3
+ size 6403287
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ pillow
2
+ ultralytics
3
+ transformers
4
+ fastapi
5
+ dill==0.3.8
6
+ gradio
7
+ torch