File size: 4,497 Bytes
039aebb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
from fastapi import FastAPI, Depends, status
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import time
from typing import Dict, List, Optional

import jwt
from decouple import config
from fastapi import Request, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

JWT_SECRET = config("secret")
JWT_ALGORITHM = config("algorithm")

app = FastAPI()
app.ready = False

device = ("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_lfqa')
model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_lfqa').to(device)
_ = model.eval()


class JWTBearer(HTTPBearer):
    def __init__(self, auto_error: bool = True):
        super(JWTBearer, self).__init__(auto_error=auto_error)

    async def __call__(self, request: Request):
        credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request)
        if credentials:
            if not credentials.scheme == "Bearer":
                raise HTTPException(status_code=403, detail="Invalid authentication scheme.")
            if not self.verify_jwt(credentials.credentials):
                raise HTTPException(status_code=403, detail="Invalid token or expired token.")
            return credentials.credentials
        else:
            raise HTTPException(status_code=403, detail="Invalid authorization code.")

    def verify_jwt(self, jwtoken: str) -> bool:
        isTokenValid: bool = False

        try:
            payload = decodeJWT(jwtoken)
        except:
            payload = None
        if payload:
            isTokenValid = True
        return isTokenValid


def token_response(token: str):
    return {
        "access_token": token
    }


def signJWT(user_id: str) -> Dict[str, str]:
    payload = {
        "user_id": user_id,
        "expires": time.time() + 6000
    }
    token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)

    return token_response(token)


def decodeJWT(token: str) -> dict:
    try:
        decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
        return decoded_token if decoded_token["expires"] >= time.time() else None
    except:
        return {}


class LFQAParameters(BaseModel):
    min_length: int = 50
    max_length: int = 250
    do_sample: bool = False
    early_stopping: bool = True
    num_beams: int = 8
    temperature: float = 1.0
    top_k: float = None
    top_p: float = None
    no_repeat_ngram_size: int = 3
    num_return_sequences: int = 1


class InferencePayload(BaseModel):
    model_input: str
    parameters: Optional[LFQAParameters] = LFQAParameters()


@app.on_event("startup")
def startup():
    app.ready = True


@app.get("/healthz")
def healthz():
    if app.ready:
        return PlainTextResponse("ok")
    return PlainTextResponse("service unavailable", status_code=status.HTTP_503_SERVICE_UNAVAILABLE)


@app.post("/generate/", dependencies=[Depends(JWTBearer())])
def generate(context: InferencePayload):

    model_input = tokenizer(context.model_input, truncation=True, padding=True, return_tensors="pt")
    param = context.parameters
    generated_answers_encoded = model.generate(input_ids=model_input["input_ids"].to(device),
                                               attention_mask=model_input["attention_mask"].to(device),
                                               min_length=param.min_length,
                                               max_length=param.max_length,
                                               do_sample=param.do_sample,
                                               early_stopping=param.early_stopping,
                                               num_beams=param.num_beams,
                                               temperature=param.temperature,
                                               top_k=param.top_k,
                                               top_p=param.top_p,
                                               no_repeat_ngram_size=param.no_repeat_ngram_size,
                                               num_return_sequences=param.num_return_sequences)
    answers = tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
                                     clean_up_tokenization_spaces=True)
    results = []
    for answer in answers:
        results.append({"generated_text": answer})
    return results