Spaces:
Runtime error
Runtime error
"""
Ways to transform interfaces to produce new interfaces
"""
import asyncio | |
import warnings | |
import gradio | |
from gradio.documentation import document, set_documentation_group | |
set_documentation_group("mix_interface") | |
class Parallel(gradio.Interface):
    """
    Builds one Interface that runs several Interfaces side by side on the same inputs so that their outputs can be compared.
    Every Interface placed in a Parallel must use the same input components; the output components are allowed to differ.
    Demos: interface_parallel, interface_parallel_load
    Guides: advanced_interface_features
    """

    def __init__(self, *interfaces: gradio.Interface, **options):
        """
        Parameters:
            interfaces: any number of Interface objects whose outputs are to be shown together
            options: extra keyword arguments forwarded to the new Interface; they take precedence over the fn/inputs/outputs assembled here
        Returns:
            an Interface object comparing the given models
        """
        # Collect every output component in order. A non-Interface argument
        # only triggers a warning; its output_components are still read.
        outputs = []
        for candidate in interfaces:
            if not isinstance(candidate, gradio.Interface):
                warnings.warn(
                    "Parallel requires all inputs to be of type Interface. "
                    "May not work as expected."
                )
            outputs.extend(candidate.output_components)

        async def parallel_fn(*args):
            # One call per interface, each given its own list copy of the
            # arguments, all awaited together through asyncio.gather.
            pending_calls = [
                member.call_function(0, list(args)) for member in interfaces
            ]
            call_results = await asyncio.gather(*pending_calls)
            predictions = [result["prediction"] for result in call_results]

            flattened = []
            for member, prediction in zip(interfaces, predictions):
                # With exactly one output component the prediction is kept as
                # a single item; otherwise its elements are spliced in.
                has_single_output = len(member.output_components) == 1
                flattened.extend([prediction] if has_single_output else prediction)

            # A lone overall output is returned bare instead of as a list.
            return flattened[0] if len(outputs) == 1 else flattened

        parallel_fn.__name__ = " | ".join(member.__name__ for member in interfaces)

        # Inputs come from the first interface; caller options override defaults.
        settings = {
            "fn": parallel_fn,
            "inputs": interfaces[0].input_components,
            "outputs": outputs,
            **options,
        }
        super().__init__(**settings)
class Series(gradio.Interface):
    """
    Builds one Interface by chaining several Interfaces end to end: whatever one Interface outputs becomes the input of the next,
    which means the output components of each Interface must agree with the input components of the one that follows it.
    Demos: interface_series, interface_series_load
    Guides: advanced_interface_features
    """

    def __init__(self, *interfaces: gradio.Interface, **options):
        """
        Parameters:
            interfaces: any number of Interface objects that are to be chained one after another
            options: extra keyword arguments forwarded to the new Interface; they take precedence over the fn/inputs/outputs assembled here
        Returns:
            an Interface object connecting the given models
        """
        # A non-Interface argument only triggers a warning; it is still used.
        for candidate in interfaces:
            if not isinstance(candidate, gradio.Interface):
                warnings.warn(
                    "Series requires all inputs to be of type Interface. "
                    "May not work as expected."
                )

        async def connected_fn(*data):
            final_position = len(interfaces) - 1
            for position, stage in enumerate(interfaces):
                # skip preprocessing for the first stage (the Series interface
                # already includes it) and for stages in api_mode
                if not (position == 0 or stage.api_mode):
                    data = [
                        component.preprocess(data[slot])
                        for slot, component in enumerate(stage.input_components)
                    ]

                # stages run strictly one after another
                outcome = await stage.call_function(0, list(data))
                data = outcome["prediction"]
                if len(stage.output_components) == 1:
                    # wrap so the steps below can always index into data
                    data = [data]

                # skip postprocessing for the final stage (the Series interface
                # already includes it) and for stages in api_mode
                if not (position == final_position or stage.api_mode):
                    data = [
                        component.postprocess(data[slot])
                        for slot, component in enumerate(stage.output_components)
                    ]

            # `stage` is still bound to the last interface here; unwrap a
            # single-output result before returning it.
            single_final_output = len(stage.output_components) == 1  # type: ignore
            return data[0] if single_final_output else data

        connected_fn.__name__ = " => ".join(stage.__name__ for stage in interfaces)

        # Inputs come from the first interface and outputs from the last;
        # caller options override these defaults.
        settings = {
            "fn": connected_fn,
            "inputs": interfaces[0].input_components,
            "outputs": interfaces[-1].output_components,
            "_api_mode": interfaces[0].api_mode,  # TODO: set api_mode per-interface
            **options,
        }
        super().__init__(**settings)