File size: 6,269 Bytes
d59aeff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import importlib
import inspect
import re
from typing import Any, Callable, Type, Union, get_type_hints

from pydantic import BaseModel, parse_raw_as
from pydantic.tools import parse_obj_as


def name_to_title(name: str) -> str:
    """Turn a camelCase or snake_case identifier into a Title Case label.

    Word boundaries found in camelCase input are first marked with
    underscores; underscores are then replaced by spaces and every word is
    capitalised.
    """
    # Patterns that locate camelCase word boundaries, applied in this order.
    boundary_patterns = ("(.)([A-Z][a-z]+)", "([a-z0-9])([A-Z])")

    snake = name
    for pattern in boundary_patterns:
        snake = re.sub(pattern, r"\1_\2", snake)

    # Underscores become spaces; surrounding whitespace is dropped.
    words = snake.lower().replace("_", " ")
    return words.strip().title()


def is_compatible_type(type: Type) -> bool:
    """Tell whether `type` can be used as an opyrator input/output type.

    A type is compatible when it is a subclass of the Pydantic `BaseModel`,
    or a list annotation whose first type argument is such a subclass.
    Anything that cannot be inspected this way is reported as incompatible.
    """
    # Probe 1: the annotation itself is a model class. `issubclass` raises for
    # non-class arguments, which simply means "not a model".
    try:
        is_model = issubclass(type, BaseModel)
    except Exception:
        is_model = False
    if is_model:
        return True

    # Probe 2: a list annotation whose item type is a model class. Objects
    # without `__origin__` / `__args__` raise here and count as incompatible.
    try:
        is_model_list = type.__origin__ is list and issubclass(
            type.__args__[0], BaseModel
        )
    except Exception:
        is_model_list = False

    return bool(is_model_list)


def get_input_type(func: Callable) -> Type:
    """Look up and validate the annotation of a callable's `input` parameter.

    Args:
        func: The callable whose `input` parameter annotation is wanted.

    Returns:
        The resolved type annotation of the `input` parameter.

    Raises:
        ValueError: If there is no annotated `input` parameter, or if its
            annotation is not an opyrator-compatible type.
    """
    annotations = get_type_hints(func)

    has_input = "input" in annotations
    if not has_input:
        raise ValueError(
            "The callable MUST have a parameter with the name `input` with typing annotation. "
            "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
        )

    # TODO: return warning if more than one input parameters

    candidate = annotations["input"]
    if is_compatible_type(candidate):
        return candidate

    raise ValueError(
        "The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
    )


def get_output_type(func: Callable) -> Type:
    """Returns the output type of a given function (callable).

    Args:
        func: The function for which to get the output type.

    Returns:
        The resolved return annotation of `func`.

    Raises:
        ValueError: If the function does not have a valid output type annotation.
    """
    type_hints = get_type_hints(func)
    if "return" not in type_hints:
        # The trailing space keeps the two sentences apart once the adjacent
        # string literals are concatenated.
        raise ValueError(
            "The return type of the callable MUST be annotated with type hints. "
            "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
        )

    output_type = type_hints["return"]

    if not is_compatible_type(output_type):
        raise ValueError(
            "The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
        )

    return output_type


def get_callable(import_string: str) -> Callable:
    """Import a callable from an import string.

    Two notations are accepted:

    * ``"package.module:attribute"`` - everything before the last colon is
      the module; the part after it may be a dotted attribute path such as
      ``"package.module:Class.method"``.
    * ``"package.module.attribute"`` - used when no colon is present; the
      last dot separates the module from the attribute.

    Args:
        import_string: Location of the callable in one of the notations above.

    Returns:
        The object found at that location. It is not checked for
        callability here.

    Raises:
        ValueError: If the string contains neither ``:`` nor ``.``.
        ImportError: If the module part cannot be imported.
        AttributeError: If the attribute path does not exist on the module.
    """
    # Prefer the explicit colon notation; fall back to the last dot.
    separator = ":" if ":" in import_string else "."
    if separator not in import_string:
        raise ValueError("The callable path MUST specify the function. ")

    mod_name, attr_path = import_string.rsplit(separator, 1)
    target = importlib.import_module(mod_name)

    # Walk the attribute path one segment at a time so nested attributes work
    # with the colon notation. With the dot notation the path is always a
    # single segment, because the split happened on the last dot.
    for attr_name in attr_path.split("."):
        target = getattr(target, attr_name)
    return target


class Opyrator:
    """Wraps a function or callable instance that takes a typed `input`.

    On construction the callable's `input` parameter and return annotation
    are validated via `get_input_type` / `get_output_type`, and a display
    name and an action text are derived from its name and docstring.
    """

    # Text that `inspect.getdoc` may report for a `__call__` method without a
    # docstring of its own (the generic slot documentation); it is treated as
    # "no docstring".
    _GENERIC_CALL_DOC = "Call self as a function."

    def __init__(self, func: Union[Callable, str]) -> None:
        """Validate `func` and collect its metadata.

        Args:
            func: The callable to wrap, or an import string that is resolved
                with `get_callable`.

        Raises:
            ValueError: If `func` is not callable, is a class instead of an
                instance, or lacks valid `input` / return annotations.
        """
        if isinstance(func, str):
            # Try to load the function from a string notation
            self.function = get_callable(func)
        else:
            self.function = func

        # Defaults. The name/docstring lookups below are best-effort, so every
        # attribute read by the properties must already exist.
        self._name = "Opyrator"
        self._action = "Execute"
        self._input_type = None
        self._output_type = None

        if not callable(self.function):
            raise ValueError("The provided function parameters is not a callable.")

        if inspect.isclass(self.function):
            raise ValueError(
                "The provided callable is an uninitialized Class. This is not allowed."
            )

        if inspect.isfunction(self.function):
            # The provided callable is a function
            self._input_type = get_input_type(self.function)
            self._output_type = get_output_type(self.function)

            try:
                # Get name
                self._name = name_to_title(self.function.__name__)
            except Exception:
                # Best effort: keep the default name.
                pass

            try:
                # Get action text from the function docstring
                doc_string = inspect.getdoc(self.function)
                if doc_string:
                    self._action = doc_string
            except Exception:
                # Best effort: keep the default action.
                pass
        elif hasattr(self.function, "__call__"):
            # The provided callable is an instance of a class defining __call__
            self._input_type = get_input_type(self.function.__call__)  # type: ignore
            self._output_type = get_output_type(self.function.__call__)  # type: ignore

            try:
                # Get name from the instance's class
                self._name = name_to_title(type(self.function).__name__)
            except Exception:
                # Best effort: keep the default name.
                pass

            try:
                # Get action text from the __call__ docstring
                doc_string = inspect.getdoc(self.function.__call__)  # type: ignore
                if doc_string == self._GENERIC_CALL_DOC:
                    # Not written by the author of the class; ignore it.
                    doc_string = None

                if doc_string:
                    self._action = doc_string

                if not doc_string or self._action == "Call":
                    # Get docstring from class instead of __call__ function
                    class_doc = inspect.getdoc(self.function)
                    if class_doc:
                        self._action = class_doc
            except Exception:
                # Best effort: keep whatever action text was found so far.
                pass
        else:
            raise ValueError("Unknown callable type.")

    @property
    def name(self) -> str:
        """Title-cased display name derived from the function or class name."""
        return self._name

    @property
    def action(self) -> str:
        """Action text: the callable's docstring, or "Execute" if none was found."""
        return self._action

    @property
    def input_type(self) -> Any:
        """Validated type annotation of the callable's `input` parameter."""
        return self._input_type

    @property
    def output_type(self) -> Any:
        """Validated return type annotation of the callable."""
        return self._output_type

    def __call__(self, input: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped callable.

        Args:
            input: A JSON `str` (parsed with `parse_raw_as`), a `dict`
                (parsed with `parse_obj_as`), or any other object, which is
                passed through unchanged.
            **kwargs: Forwarded to the wrapped callable as-is.

        Returns:
            Whatever the wrapped callable returns.
        """
        input_obj = input

        if isinstance(input, str):
            # Allow json input
            input_obj = parse_raw_as(self.input_type, input)
        elif isinstance(input, dict):
            # Allow dict input
            input_obj = parse_obj_as(self.input_type, input)

        return self.function(input_obj, **kwargs)