Story_make / story.py
Yasin Yousif
Add application file
9cf758a
raw
history blame
912 Bytes
import re
import numpy as np
import pickle
def load_sentense_data(path='reddit_short_stories.txt', min_sent_length=20):
    """Split a stories file into lower-cased sentence fragments.

    Each line of *path* is treated as one story. Anything after the
    ``'<nl> _____________ <nl>'`` separator is dropped. A leading ``<sos>``
    token, a trailing ``<eos>`` token, and surrounding double quotes/spaces
    are removed. The remainder is split on ``"; "``, ``", "``, ``"."``,
    ``"*"``, newlines and ``<nl>`` tokens.

    Args:
        path: text file with one story per line (UTF-8 assumed).
        min_sent_length: fragments must be strictly longer than this many
            characters to be kept.

    Returns:
        list of lower-cased fragment strings, in file order.
    """
    # Raw string: "\." and "\*" are regex escapes, not Python string escapes.
    splitter = re.compile(r'; |, |\.|\*|\n|<nl>')
    all_sentenses = []
    with open(path, 'r', encoding='utf-8') as f:
        for story in f:
            body = story.split('<nl> _____________ <nl>')[0].strip()
            # Remove the marker *tokens*. str.strip('<sos>') would instead strip
            # any of the characters '<', 's', 'o', '>' and eat real text.
            if body.startswith('<sos>'):
                body = body[len('<sos>'):]
            if body.endswith('<eos>'):
                body = body[:-len('<eos>')]
            body = body.strip('" ')
            all_sentenses.extend(
                x.lower() for x in splitter.split(body) if len(x) > min_sent_length
            )
    return all_sentenses
def story_model(preds, res_len=5, sentenses=None):
    """Rank candidate sentences by the predicted words they contain.

    A sentence's score is the sum of ``prob`` over every ``(word, prob)``
    pair in *preds* whose ``word`` occurs in the sentence as a substring.

    Args:
        preds: iterable of ``(word, prob)`` pairs. ``prob`` must support
            multiplication by a bool (e.g. a float).
        res_len: how many top-scoring sentences to return. Values <= 0
            return an empty list.
        sentenses: optional list of candidate sentences. When omitted, the
            list is unpickled from ``out.bin`` in the current working
            directory on every call.

    Returns:
        list of ``(score, sentence)`` tuples for the ``res_len`` highest
        scores, in ascending score order (best match last). Equal scores
        keep corpus order.
        NOTE(review): the original trailing comment described the result as
        a "list of sentenses"; confirm whether callers want bare sentences.
    """
    if sentenses is None:
        # Presumably a pickled cache of load_sentense_data() output — TODO confirm.
        # pickle.load can execute arbitrary code: only load a trusted local file.
        with open("out.bin", "rb") as fp:
            sentenses = pickle.load(fp)
    if res_len <= 0:
        # Without this guard, ranked[-0:] would return every sentence.
        return []
    dists = np.zeros(len(sentenses))
    for word, prob in preds:
        # NOTE(review): substring test, so "cat" also matches "category".
        dists += np.array([prob * (word in sent) for sent in sentenses])
    # sorted() rather than list.sort(), which returns None. Key on the score,
    # not the sentence text.
    ranked = sorted(zip(dists, sentenses), key=lambda pair: pair[0])
    return ranked[-res_len:]