Your Name commited on
Commit
23faa2e
·
1 Parent(s): c7606da

Initial commit

Browse files
Files changed (19) hide show
  1. .gitattributes +2 -0
  2. README.md +2 -1
  3. app.py +144 -0
  4. inference_pb2.py +30 -0
  5. inference_pb2.pyi +29 -0
  6. inference_pb2_grpc.py +101 -0
  7. input/0.png +3 -0
  8. input/1.png +3 -0
  9. input/10.jpg +3 -0
  10. input/11.jpg +3 -0
  11. input/2.png +3 -0
  12. input/3.jpg +3 -0
  13. input/4.jpg +3 -0
  14. input/5.jpg +3 -0
  15. input/6.png +3 -0
  16. input/7.png +3 -0
  17. input/8.png +3 -0
  18. input/9.jpg +3 -0
  19. requirements.txt +6 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
  title: HairFastGAN
3
- emoji:
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: gradio
7
  sdk_version: 4.31.5
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: HairFastGAN
3
+ emoji: 💈
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: gradio
7
  sdk_version: 4.31.5
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from io import BytesIO
3
+
4
+ import gradio as gr
5
+ import grpc
6
+ from PIL import Image
7
+ from cachetools import LRUCache
8
+ import hashlib
9
+
10
+ from inference_pb2 import HairSwapRequest, HairSwapResponse
11
+ from inference_pb2_grpc import HairSwapServiceStub
12
+ from utils.shape_predictor import align_face
13
+
14
+
15
+ def get_bytes(img):
16
+ if img is None:
17
+ return img
18
+
19
+ buffered = BytesIO()
20
+ img.save(buffered, format="JPEG")
21
+ return buffered.getvalue()
22
+
23
+
24
+ def bytes_to_image(image: bytes) -> Image.Image:
25
+ image = Image.open(BytesIO(image))
26
+ return image
27
+
28
+
29
+ def center_crop(img):
30
+ width, height = img.size
31
+ side = min(width, height)
32
+
33
+ left = (width - side) / 2
34
+ top = (height - side) / 2
35
+ right = (width + side) / 2
36
+ bottom = (height + side) / 2
37
+
38
+ img = img.crop((left, top, right, bottom))
39
+ return img
40
+
41
+
42
+ def resize(name):
43
+ def resize_inner(img, align):
44
+ global align_cache
45
+
46
+ if name in align:
47
+ img_hash = hashlib.md5(get_bytes(img)).hexdigest()
48
+
49
+ if img_hash not in align_cache:
50
+ img = align_face(img, return_tensors=False)[0]
51
+ align_cache[img_hash] = img
52
+ else:
53
+ img = align_cache[img_hash]
54
+
55
+ elif img.size != (1024, 1024):
56
+ img = center_crop(img)
57
+ img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
58
+
59
+ return img
60
+
61
+ return resize_inner
62
+
63
+
64
+ def swap_hair(face, shape, color, blending, poisson_iters, poisson_erosion, progress=gr.Progress(track_tqdm=True)):
65
+ if not face or not shape and not color:
66
+ raise ValueError("Need to upload a face and at least a shape or color")
67
+
68
+ face_bytes, shape_bytes, color_bytes = map(lambda item: get_bytes(item), (face, shape, color))
69
+
70
+ if shape_bytes is None:
71
+ shape_bytes = b'face'
72
+ if color_bytes is None:
73
+ color_bytes = b'shape'
74
+
75
+ with grpc.insecure_channel(os.environ['SERVER']) as channel:
76
+ stub = HairSwapServiceStub(channel)
77
+
78
+ output: HairSwapResponse = stub.swap(
79
+ HairSwapRequest(face=face_bytes, shape=shape_bytes, color=color_bytes, blending=blending,
80
+ poisson_iters=poisson_iters, poisson_erosion=poisson_erosion, use_cache=True)
81
+ )
82
+
83
+ output = bytes_to_image(output.image)
84
+ return output
85
+
86
+
87
+ def get_demo():
88
+ with gr.Blocks() as demo:
89
+ gr.Markdown("## HairFastGan")
90
+ gr.Markdown(
91
+ '<div style="display: flex; align-items: center; gap: 10px;">'
92
+ '<span>Official HairFastGAN Gradio demo:</span>'
93
+ '<a href="https://arxiv.org/abs/2404.01094"><img src="https://img.shields.io/badge/arXiv-2404.01094-b31b1b.svg" height=22.5></a>'
94
+ '<a href="https://github.com/AIRI-Institute/HairFastGAN"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
95
+ '<a href="https://huggingface.co/AIRI-Institute/HairFastGAN"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md.svg" height=22.5></a>'
96
+ '<a href="https://colab.research.google.com/#fileId=https://huggingface.co/AIRI-Institute/HairFastGAN/blob/main/notebooks/HairFast_inference.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
97
+ '</div>'
98
+ )
99
+ with gr.Row():
100
+ with gr.Column():
101
+ source = gr.Image(label="Photo that you want to replace the hair", type="pil")
102
+ with gr.Row():
103
+ shape = gr.Image(label="Reference hair you want to get (optional)", type="pil")
104
+ color = gr.Image(label="Reference color hair you want to get (optional)", type="pil")
105
+ with gr.Accordion("Advanced Options", open=False):
106
+ blending = gr.Radio(["Article", "Alternative_v1", "Alternative_v2"], value='Article',
107
+ label="Blending version", info="Selects a model for hair color transfer.")
108
+ poisson_iters = gr.Slider(0, 2500, value=0, step=1, label="Poisson iters",
109
+ info="The power of blending with the original image, helps to recover more details. Not included in the article, disabled by default.")
110
+ poisson_erosion = gr.Slider(1, 100, value=15, step=1, label="Poisson erosion",
111
+ info="Smooths out the blending area.")
112
+ align = gr.CheckboxGroup(["Face", "Shape", "Color"], value=["Face", "Shape", "Color"],
113
+ label="Image cropping [recommended]", info="Selects which images to crop by face")
114
+ btn = gr.Button("Get the haircut")
115
+ with gr.Column():
116
+ output = gr.Image(label="Your result")
117
+
118
+ gr.Examples(examples=[["input/0.png", "input/1.png", "input/2.png"], ["input/6.png", "input/7.png", None],
119
+ ["input/10.jpg", None, "input/11.jpg"]],
120
+ inputs=[source, shape, color], outputs=output)
121
+
122
+ source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
123
+ shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
124
+ color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)
125
+
126
+ btn.click(fn=swap_hair, inputs=[source, shape, color, blending, poisson_iters, poisson_erosion], outputs=output)
127
+
128
+ gr.Markdown('''To cite the paper by the authors
129
+ ```
130
+ @article{nikolaev2024hairfastgan,
131
+ title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
132
+ author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
133
+ journal={arXiv preprint arXiv:2404.01094},
134
+ year={2024}
135
+ }
136
+ ```
137
+ ''')
138
+ return demo
139
+
140
+
141
+ if __name__ == '__main__':
142
+ align_cache = LRUCache(maxsize=10)
143
+ demo = get_demo()
144
+ demo.launch(server_name="0.0.0.0", server_port=7860)
inference_pb2.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: inference.proto
4
+ # Protobuf Python Version: 5.26.1
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+ # @@protoc_insertion_point(imports)
11
+
12
+ _sym_db = _symbol_database.Default()
13
+
14
+
15
+
16
+
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0finference.proto\x12\tinference\"\x92\x01\n\x0fHairSwapRequest\x12\x0c\n\x04\x66\x61\x63\x65\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x01(\x0c\x12\r\n\x05\x63olor\x18\x03 \x01(\x0c\x12\x10\n\x08\x62lending\x18\x04 \x01(\t\x12\x15\n\rpoisson_iters\x18\x05 \x01(\x05\x12\x17\n\x0fpoisson_erosion\x18\x06 \x01(\x05\x12\x11\n\tuse_cache\x18\x07 \x01(\x08\"!\n\x10HairSwapResponse\x12\r\n\x05image\x18\x01 \x01(\x0c\x32R\n\x0fHairSwapService\x12?\n\x04swap\x12\x1a.inference.HairSwapRequest\x1a\x1b.inference.HairSwapResponseb\x06proto3')
18
+
19
+ _globals = globals()
20
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
21
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'inference_pb2', _globals)
22
+ if not _descriptor._USE_C_DESCRIPTORS:
23
+ DESCRIPTOR._loaded_options = None
24
+ _globals['_HAIRSWAPREQUEST']._serialized_start=31
25
+ _globals['_HAIRSWAPREQUEST']._serialized_end=177
26
+ _globals['_HAIRSWAPRESPONSE']._serialized_start=179
27
+ _globals['_HAIRSWAPRESPONSE']._serialized_end=212
28
+ _globals['_HAIRSWAPSERVICE']._serialized_start=214
29
+ _globals['_HAIRSWAPSERVICE']._serialized_end=296
30
+ # @@protoc_insertion_point(module_scope)
inference_pb2.pyi ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from google.protobuf import descriptor as _descriptor
2
+ from google.protobuf import message as _message
3
+ from typing import ClassVar as _ClassVar, Optional as _Optional
4
+
5
+ DESCRIPTOR: _descriptor.FileDescriptor
6
+
7
+ class HairSwapRequest(_message.Message):
8
+ __slots__ = ("face", "shape", "color", "blending", "poisson_iters", "poisson_erosion", "use_cache")
9
+ FACE_FIELD_NUMBER: _ClassVar[int]
10
+ SHAPE_FIELD_NUMBER: _ClassVar[int]
11
+ COLOR_FIELD_NUMBER: _ClassVar[int]
12
+ BLENDING_FIELD_NUMBER: _ClassVar[int]
13
+ POISSON_ITERS_FIELD_NUMBER: _ClassVar[int]
14
+ POISSON_EROSION_FIELD_NUMBER: _ClassVar[int]
15
+ USE_CACHE_FIELD_NUMBER: _ClassVar[int]
16
+ face: bytes
17
+ shape: bytes
18
+ color: bytes
19
+ blending: str
20
+ poisson_iters: int
21
+ poisson_erosion: int
22
+ use_cache: bool
23
+ def __init__(self, face: _Optional[bytes] = ..., shape: _Optional[bytes] = ..., color: _Optional[bytes] = ..., blending: _Optional[str] = ..., poisson_iters: _Optional[int] = ..., poisson_erosion: _Optional[int] = ..., use_cache: bool = ...) -> None: ...
24
+
25
+ class HairSwapResponse(_message.Message):
26
+ __slots__ = ("image",)
27
+ IMAGE_FIELD_NUMBER: _ClassVar[int]
28
+ image: bytes
29
+ def __init__(self, image: _Optional[bytes] = ...) -> None: ...
inference_pb2_grpc.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ """Client and server classes corresponding to protobuf-defined services."""
3
+ import grpc
4
+ import warnings
5
+
6
+ import inference_pb2 as inference__pb2
7
+
8
+ GRPC_GENERATED_VERSION = '1.63.0'
9
+ GRPC_VERSION = grpc.__version__
10
+ EXPECTED_ERROR_RELEASE = '1.65.0'
11
+ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
12
+ _version_not_supported = False
13
+
14
+ try:
15
+ from grpc._utilities import first_version_is_lower
16
+ _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
17
+ except ImportError:
18
+ _version_not_supported = True
19
+
20
+ if _version_not_supported:
21
+ warnings.warn(
22
+ f'The grpc package installed is at version {GRPC_VERSION},'
23
+ + f' but the generated code in inference_pb2_grpc.py depends on'
24
+ + f' grpcio>={GRPC_GENERATED_VERSION}.'
25
+ + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
26
+ + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
27
+ + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
28
+ + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
29
+ RuntimeWarning
30
+ )
31
+
32
+
33
+ class HairSwapServiceStub(object):
34
+ """Missing associated documentation comment in .proto file."""
35
+
36
+ def __init__(self, channel):
37
+ """Constructor.
38
+
39
+ Args:
40
+ channel: A grpc.Channel.
41
+ """
42
+ self.swap = channel.unary_unary(
43
+ '/inference.HairSwapService/swap',
44
+ request_serializer=inference__pb2.HairSwapRequest.SerializeToString,
45
+ response_deserializer=inference__pb2.HairSwapResponse.FromString,
46
+ _registered_method=True)
47
+
48
+
49
+ class HairSwapServiceServicer(object):
50
+ """Missing associated documentation comment in .proto file."""
51
+
52
+ def swap(self, request, context):
53
+ """Missing associated documentation comment in .proto file."""
54
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
55
+ context.set_details('Method not implemented!')
56
+ raise NotImplementedError('Method not implemented!')
57
+
58
+
59
+ def add_HairSwapServiceServicer_to_server(servicer, server):
60
+ rpc_method_handlers = {
61
+ 'swap': grpc.unary_unary_rpc_method_handler(
62
+ servicer.swap,
63
+ request_deserializer=inference__pb2.HairSwapRequest.FromString,
64
+ response_serializer=inference__pb2.HairSwapResponse.SerializeToString,
65
+ ),
66
+ }
67
+ generic_handler = grpc.method_handlers_generic_handler(
68
+ 'inference.HairSwapService', rpc_method_handlers)
69
+ server.add_generic_rpc_handlers((generic_handler,))
70
+
71
+
72
+ # This class is part of an EXPERIMENTAL API.
73
+ class HairSwapService(object):
74
+ """Missing associated documentation comment in .proto file."""
75
+
76
+ @staticmethod
77
+ def swap(request,
78
+ target,
79
+ options=(),
80
+ channel_credentials=None,
81
+ call_credentials=None,
82
+ insecure=False,
83
+ compression=None,
84
+ wait_for_ready=None,
85
+ timeout=None,
86
+ metadata=None):
87
+ return grpc.experimental.unary_unary(
88
+ request,
89
+ target,
90
+ '/inference.HairSwapService/swap',
91
+ inference__pb2.HairSwapRequest.SerializeToString,
92
+ inference__pb2.HairSwapResponse.FromString,
93
+ options,
94
+ channel_credentials,
95
+ insecure,
96
+ call_credentials,
97
+ compression,
98
+ wait_for_ready,
99
+ timeout,
100
+ metadata,
101
+ _registered_method=True)
input/0.png ADDED

Git LFS Details

  • SHA256: 2250590e65e153c785683218c0e2da0c21ae104ac2190857988e05474ca89986
  • Pointer size: 132 Bytes
  • Size of remote file: 1.66 MB
input/1.png ADDED

Git LFS Details

  • SHA256: 5f67d4e98519ee4c1b0dad362bacd95dc7f0c090b1c45ebfcee74a85c660e372
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
input/10.jpg ADDED

Git LFS Details

  • SHA256: c28b0dd07e3e19d4b0bade8f496df8e822a51c054533597d8059cb9cdfa5d123
  • Pointer size: 130 Bytes
  • Size of remote file: 82.3 kB
input/11.jpg ADDED

Git LFS Details

  • SHA256: 1cc97a3c93cf58517e407677fe5fe226a073a76b015aadb57709efccf7b9e192
  • Pointer size: 130 Bytes
  • Size of remote file: 76.8 kB
input/2.png ADDED

Git LFS Details

  • SHA256: 97e0975d499216762645987955e80ee8b764c062d8b47da8f432bf11004a6443
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
input/3.jpg ADDED

Git LFS Details

  • SHA256: 825f14bc144e8ee11270387e1684472ef70767a2fe428443860d5f817158f8a5
  • Pointer size: 131 Bytes
  • Size of remote file: 129 kB
input/4.jpg ADDED

Git LFS Details

  • SHA256: 02e957c36c1276ec55431a6a8d9614dcdbd98c49d565233e7bc926ff28025050
  • Pointer size: 130 Bytes
  • Size of remote file: 76.9 kB
input/5.jpg ADDED

Git LFS Details

  • SHA256: 5853a6efc8aedf5ef3f9fcb977c829bdaf3a5bbd2c46981a6854ed6aeec709bf
  • Pointer size: 130 Bytes
  • Size of remote file: 87.2 kB
input/6.png ADDED

Git LFS Details

  • SHA256: 91bc9e71396e0e364f66b44d5c1d58d1e5036e53f019a5badffed99d04d7e413
  • Pointer size: 132 Bytes
  • Size of remote file: 1.46 MB
input/7.png ADDED

Git LFS Details

  • SHA256: 5b126e1e7858c7d73dd1a0d2b24ca1f178d93040583b4474b1458cf70482ecc5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
input/8.png ADDED

Git LFS Details

  • SHA256: 4b36754e56b501fa74e92a89fd218f4c8571975a722b0138e2897fb1a46f9790
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
input/9.jpg ADDED

Git LFS Details

  • SHA256: 6b376b2d07ca9775856f2c1ce1007bf1b7948d6f70a9738f85730a04fc252467
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ pillow==10.0.0
2
+ face_alignment==1.3.4
3
+ addict==2.4.0
4
+ git+https://github.com/openai/CLIP.git
5
+ gdown==3.12.2
6
+ dlib==19.24.1