|
import pandas as pd |
|
|
|
def predict(data, task, model, tokenizer, config, **kwargs):
    """Generate text for each input item using ``model`` and ``tokenizer``.

    Args:
        data: Either a ``pandas.DataFrame`` (only its first column is used)
            or an iterable of inputs accepted by ``tokenizer``.
        task: Unused by this function; kept for interface compatibility.
        model: Object exposing ``generate(**inputs, **generation_kwargs)``.
        tokenizer: Callable that encodes one input (called with
            ``return_tensors="pt"`` and ``return_attention_mask=False``) and
            exposes ``batch_decode(outputs)``.
        config: Unused by this function; kept for interface compatibility.
        **kwargs: May contain ``addn_args``, a dict of extra keyword
            arguments forwarded to ``model.generate``. ``max_length``
            defaults to 50 and can be overridden through ``addn_args``.

    Returns:
        A ``pandas.DataFrame`` with a single ``output`` column when ``data``
        was a DataFrame; otherwise a dict ``{"output": [...]}``.
    """
    # Decide the return shape up front so the flag is always defined,
    # including for non-DataFrame input.
    is_df = isinstance(data, pd.DataFrame)
    if is_df:
        data = data[data.columns[0]].tolist()

    # Treat a missing or None ``addn_args`` as "no extra arguments".
    addn_args = kwargs.get("addn_args") or {}
    # Build a fresh dict: the caller's dict is not mutated, and a
    # caller-supplied ``max_length`` replaces the default instead of
    # colliding with it as a duplicate keyword.
    generation_kwargs = {"max_length": 50, **addn_args}

    results = []
    for item in data:
        inputs = tokenizer(item, return_tensors="pt", return_attention_mask=False)
        outputs = model.generate(**inputs, **generation_kwargs)
        # One input per call, so only the first decoded sequence is kept.
        results.append(tokenizer.batch_decode(outputs)[0])

    if is_df:
        return pd.DataFrame(results, columns=["output"])
    return {"output": results}