Spaces:
Runtime error
Runtime error
import sys | |
from typing import Any, Dict, List, Optional, Generic, TypeVar, cast | |
from types import TracebackType | |
if sys.version_info >= (3, 8): | |
from importlib.metadata import entry_points | |
else: | |
from importlib_metadata import entry_points | |
from toolz import curry | |
# Type parameter for the plugin objects held by a PluginRegistry; it is used
# as ``Generic[PluginType]`` and as the value type of the registry's plugins.
PluginType = TypeVar("PluginType")
class NoSuchEntryPoint(Exception):
    """Signal that a requested entry point could not be located.

    Attributes
    ----------
    group : str
        The entry point group that was searched.
    name : str
        The entry point name that was requested.
    """

    def __init__(self, group, name):
        # The constructor arguments are also kept in ``self.args`` by
        # BaseException.__new__; they are stored as attributes for __str__.
        self.group, self.name = group, name

    def __str__(self):
        return f"No {self.name!r} entry point found in group {self.group!r}"
class PluginEnabler:
    """Context manager for enabling plugins

    This object lets you use enable() as a context manager to
    temporarily enable a given plugin::

        with plugins.enable('name'):
            do_something()  # 'name' plugin temporarily enabled
        # plugins back to original state

    The plugin is enabled as soon as this object is constructed (i.e. when
    ``enable()`` is called), not on ``__enter__``. Leaving the ``with``
    block restores the registry state captured at construction time.
    """

    def __init__(self, registry: "PluginRegistry", name: str, **options):
        self.registry = registry  # type: PluginRegistry
        self.name = name  # type: str
        self.options = options  # type: Dict[str, Any]
        # Snapshot the registry *before* enabling so __exit__ can roll back.
        self.original_state = registry._get_state()  # type: Dict[str, Any]
        self.registry._enable(name, **options)

    def __enter__(self) -> "PluginEnabler":
        return self

    def __exit__(
        self,
        typ: Optional[type],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # All three arguments are None when the block exits normally.
        # State is restored either way; returning None lets any exception
        # raised inside the block propagate to the caller.
        self.registry._set_state(self.original_state)

    def __repr__(self) -> str:
        return "{}.enable({!r})".format(self.registry.__class__.__name__, self.name)
class PluginRegistry(Generic[PluginType]):
    """A registry for plugins.

    This is a plugin registry that allows plugins to be loaded/registered
    in two ways:

    1. Through an explicit call to ``.register(name, value)``.
    2. By looking for other Python packages that are installed and provide
       a setuptools entry point group.

    When you create an instance of this class, provide the name of the
    entry point group to use::

        reg = PluginRegistry('my_entrypoint_group')
    """

    # this is a mapping of name to error message to allow custom error messages
    # in case an entrypoint is not found
    entrypoint_err_messages = {}  # type: Dict[str, str]

    # global settings is a key-value mapping of settings that are stored globally
    # in the registry rather than passed to the plugins
    _global_settings = {}  # type: Dict[str, Any]

    def __init__(self, entry_point_group: str = "", plugin_type: type = object):
        """Create a PluginRegistry for a named entry point group.

        Parameters
        ==========
        entry_point_group: str
            The name of the entry point group.
        plugin_type: type
            A type that will optionally be used for runtime type checking of
            loaded plugins using isinstance.
        """
        self.entry_point_group = entry_point_group  # type: str
        self.plugin_type = plugin_type  # type: Optional[type]
        self._active = None  # type: Optional[PluginType]
        self._active_name = ""  # type: str
        self._plugins = {}  # type: Dict[str, PluginType]
        self._options = {}  # type: Dict[str, Any]
        # Per-instance (shallow) copy so changes made through this registry
        # do not modify the class-level defaults.
        self._global_settings = self.__class__._global_settings.copy()  # type: dict

    def register(self, name: str, value: Optional[PluginType]) -> Optional[PluginType]:
        """Register a plugin by name and value.

        This method is used for explicit registration of a plugin and shouldn't be
        used to manage entry point managed plugins, which are auto-loaded.

        Parameters
        ==========
        name: str
            The name of the plugin.
        value: PluginType or None
            The actual plugin object to register or None to unregister that plugin.

        Returns
        =======
        plugin: PluginType or None
            The plugin that was registered or unregistered.
        """
        if value is None:
            return self._plugins.pop(name, None)
        # NOTE: this runtime type check is an ``assert`` and is therefore
        # skipped when Python runs with ``-O``.
        assert isinstance(value, self.plugin_type)  # type: ignore[arg-type] # Should ideally be fixed by better annotating plugin_type
        self._plugins[name] = value
        return value

    def names(self) -> List[str]:
        """List the names of the registered and entry points plugins.

        Returns the sorted, de-duplicated union of explicitly registered
        plugin names and the entry point names in ``entry_point_group``.
        """
        exts = list(self._plugins.keys())
        e_points = importlib_metadata_get(self.entry_point_group)
        more_exts = [ep.name for ep in e_points]
        exts.extend(more_exts)
        return sorted(set(exts))

    def _get_state(self) -> Dict[str, Any]:
        """Return a dictionary representing the current state of the registry"""
        return {
            "_active": self._active,
            "_active_name": self._active_name,
            "_plugins": self._plugins.copy(),
            "_options": self._options.copy(),
            "_global_settings": self._global_settings.copy(),
        }

    def _set_state(self, state: Dict[str, Any]) -> None:
        """Reset the state of the registry"""
        assert set(state.keys()) == {
            "_active",
            "_active_name",
            "_plugins",
            "_options",
            "_global_settings",
        }
        for key, val in state.items():
            setattr(self, key, val)

    def _enable(self, name: str, **options) -> None:
        """Activate plugin ``name``, loading it from entry points if needed.

        Options whose keys match an existing global setting are stored in
        ``_global_settings``. The remaining options are kept in ``_options``
        and later bound to the plugin by ``get()``.

        Raises
        ------
        ValueError
            If no unique entry point exists and ``entrypoint_err_messages``
            has a custom message for ``name``.
        NoSuchEntryPoint
            If no unique entry point exists and there is no custom message.
        """
        if name not in self._plugins:
            try:
                # Single-element unpacking raises ValueError for zero *or*
                # several matches; both are reported as a missing entry point.
                (ep,) = [
                    ep
                    for ep in importlib_metadata_get(self.entry_point_group)
                    if ep.name == name
                ]
            except ValueError as err:
                if name in self.entrypoint_err_messages:
                    raise ValueError(self.entrypoint_err_messages[name]) from err
                else:
                    raise NoSuchEntryPoint(self.entry_point_group, name) from err
            value = cast(PluginType, ep.load())
            self.register(name, value)
        self._active_name = name
        self._active = self._plugins[name]
        for key in set(options.keys()) & set(self._global_settings.keys()):
            self._global_settings[key] = options.pop(key)
        self._options = options

    def enable(self, name: Optional[str] = None, **options) -> PluginEnabler:
        """Enable a plugin by name.

        This can be either called directly, or used as a context manager.

        Parameters
        ----------
        name : string (optional)
            The name of the plugin to enable. If not specified, then use the
            current active name.
        **options :
            Any additional parameters will be passed to the plugin as keyword
            arguments

        Returns
        -------
        PluginEnabler:
            An object that allows enable() to be used as a context manager
        """
        if name is None:
            # Read the stored name directly. ``active`` is defined below as a
            # plain method, so ``self.active`` would be the bound method rather
            # than the plugin name, and the lookup in ``_enable`` would fail.
            name = self._active_name
        return PluginEnabler(self, name, **options)

    # NOTE(review): ``enable`` previously read ``self.active`` as an attribute,
    # which suggests ``active``/``options`` were meant to be @property. They
    # are kept as methods here to preserve the existing call interface; confirm.
    def active(self) -> str:
        """Return the name of the currently active plugin"""
        return self._active_name

    def options(self) -> Dict[str, Any]:
        """Return the current options dictionary"""
        return self._options

    def get(self) -> Optional[PluginType]:
        """Return the currently active plugin.

        If options are set, the plugin is returned wrapped with
        ``curry(plugin, **options)`` so the options are bound as keyword
        arguments.
        """
        if self._options:
            return curry(self._active, **self._options)
        else:
            return self._active

    def __repr__(self) -> str:
        return "{}(active={!r}, registered={!r})".format(
            self.__class__.__name__, self._active_name, self.names()
        )
def importlib_metadata_get(group):
    """Return the entry points registered under *group*.

    Newer ``importlib.metadata`` (and the ``importlib_metadata`` backport)
    expose ``select`` on the object returned by ``entry_points()``, while
    older versions return a mapping of group name to entry points that
    supports ``get``.
    """
    all_eps = entry_points()
    # Feature-detect instead of checking the Python version: the backport
    # deprecated ``get`` on a different schedule than the stdlib.
    select = getattr(all_eps, "select", None)
    if select is None:
        return all_eps.get(group, [])
    return select(group=group)