|
from typing import Dict, List, Any |
|
from gliner import GLiNER |
|
|
|
class EndpointHandler:
    """Inference handler that runs GLiNER zero-shot entity extraction.

    The model is loaded once at construction. Each call predicts entities for
    a set of labels and groups the matched texts and their scores by label.
    """

    # Labels used when a request does not supply its own (or supplies an
    # empty) "labels" entry. A tuple so the shared default cannot be mutated.
    DEFAULT_LABELS = ("party", "document title")

    def __init__(self, path="", model_name="urchade/gliner_multi-v2.1"):
        """Load the GLiNER checkpoint.

        Args:
            path: Path supplied by the hosting runtime. Accepted for interface
                compatibility but not used; the model is loaded from
                ``model_name``.
            model_name: Model id or local path passed to
                ``GLiNER.from_pretrained``. Defaults to the checkpoint this
                handler has always loaded.
        """
        self.model = GLiNER.from_pretrained(model_name)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Dict[str, List[Any]]]:
        """Extract entities from ``data["inputs"]`` and group them by label.

        Args:
            data: Request payload containing:
                - "inputs": the text to extract entities from (default "").
                - "labels": optional label, or list of labels, to predict.
                  Falls back to ``DEFAULT_LABELS`` when missing or empty.

        Returns:
            A mapping from label to ``{"labels": [...], "scores": [...]}``.
            ``labels`` holds the matched entity texts and ``scores`` the
            matching model scores, in the order the model returned them.
            Every requested label is present even if nothing matched. A label
            reported by the model but not requested gets its own entry.
        """
        inputs = data.get("inputs", "")
        labels = data.get("labels") or self.DEFAULT_LABELS
        if isinstance(labels, str):
            # A bare string would otherwise be iterated character by character.
            labels = [labels]
        labels = list(labels)

        organized_entities = {label: {"labels": [], "scores": []} for label in labels}

        # Nothing to extract from; skip the model call entirely.
        if not inputs:
            return organized_entities

        for entity in self.model.predict_entities(inputs, labels):
            # setdefault avoids a KeyError if the model reports a label that
            # was not among the requested ones.
            bucket = organized_entities.setdefault(
                entity["label"], {"labels": [], "scores": []}
            )
            bucket["labels"].append(entity["text"])
            bucket["scores"].append(entity["score"])

        return organized_entities
|
|