G-Rost commited on
Commit
908bd46
1 Parent(s): 54deee2

Upload roop_face_analyser (1).py

Browse files
Files changed (1) hide show
  1. roop/roop_face_analyser (1).py +55 -0
roop/roop_face_analyser (1).py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from typing import Any, Optional, List
3
+ import insightface
4
+ import numpy
5
+ import spaces
6
+
7
+ import roop.globals
8
+ from roop.typing import Frame, Face
9
+
10
+ FACE_ANALYSER = None
11
+ THREAD_LOCK = threading.Lock()
12
+
13
+ @spaces.GPU
14
+ def get_face_analyser() -> Any:
15
+ global FACE_ANALYSER
16
+
17
+ with THREAD_LOCK:
18
+ if FACE_ANALYSER is None:
19
+ FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.execution_providers)
20
+ FACE_ANALYSER.prepare(ctx_id=0)
21
+ return FACE_ANALYSER
22
+
23
+
24
+ def clear_face_analyser() -> Any:
25
+ global FACE_ANALYSER
26
+
27
+ FACE_ANALYSER = None
28
+
29
+
30
+ def get_one_face(frame: Frame, position: int = 0) -> Optional[Face]:
31
+ many_faces = get_many_faces(frame)
32
+ if many_faces:
33
+ try:
34
+ return many_faces[position]
35
+ except IndexError:
36
+ return many_faces[-1]
37
+ return None
38
+
39
+
40
+ def get_many_faces(frame: Frame) -> Optional[List[Face]]:
41
+ try:
42
+ return get_face_analyser().get(frame)
43
+ except ValueError:
44
+ return None
45
+
46
+
47
+ def find_similar_face(frame: Frame, reference_face: Face) -> Optional[Face]:
48
+ many_faces = get_many_faces(frame)
49
+ if many_faces:
50
+ for face in many_faces:
51
+ if hasattr(face, 'normed_embedding') and hasattr(reference_face, 'normed_embedding'):
52
+ distance = numpy.sum(numpy.square(face.normed_embedding - reference_face.normed_embedding))
53
+ if distance < roop.globals.similar_face_distance:
54
+ return face
55
+ return None