phi-2 / code /phi_predict.py
pharaouk's picture
Upload 15 files
503dc31
raw
history blame
599 Bytes
import pandas as pd
def predict(data, task, model, tokenizer, config, **kwargs):
    """Generate one decoded text output per input item.

    Each item is tokenized on its own, passed to ``model.generate``, and the
    first sequence returned by ``tokenizer.batch_decode`` is collected.

    Args:
        data: Either a ``pandas.DataFrame`` (only its first column is used)
            or an iterable of items accepted by ``tokenizer``.
        task: Unused by this function; kept for interface compatibility.
        model: Object exposing ``generate(**inputs, **generation_kwargs)``.
        tokenizer: Callable that returns a mapping of keyword arguments for
            ``model.generate``; must also expose ``batch_decode``.
        config: Unused by this function; kept for interface compatibility.
        **kwargs: May contain ``addn_args``, a dict of extra keyword
            arguments forwarded to ``model.generate``. ``max_length``
            defaults to 50 and can be overridden through ``addn_args``.

    Returns:
        A ``pandas.DataFrame`` with a single ``output`` column when ``data``
        was a DataFrame, otherwise ``{"output": [text, ...]}``.
    """
    # Always bind the flag: leaving it unset for non-DataFrame input made the
    # final check raise UnboundLocalError.
    is_df = isinstance(data, pd.DataFrame)
    if is_df:
        data = data[data.columns[0]].tolist()

    # Treat an explicit ``addn_args=None`` the same as "not provided".
    addn_args = kwargs.get("addn_args") or {}
    # Merge rather than pass ``max_length=50`` alongside ``**addn_args``:
    # a caller-supplied ``max_length`` would otherwise be a duplicate keyword
    # and raise TypeError. Caller values win over the default.
    generate_kwargs = {"max_length": 50, **addn_args}

    results = []
    for d in data:
        inputs = tokenizer(d, return_tensors="pt", return_attention_mask=False)
        outputs = model.generate(**inputs, **generate_kwargs)
        # One input per call, so only the first decoded sequence is kept.
        results.append(tokenizer.batch_decode(outputs)[0])

    if is_df:
        return pd.DataFrame(results, columns=["output"])
    return {"output": results}