import re
import pickle

import numpy as np

# Delimiters used to cut a story into sentence-like fragments: "; ", ", ",
# ".", "*" and newline. The original pattern ended with an empty alternative
# ("...|\n|"), which on Python 3.7+ makes re.split cut between every single
# character, so no fragment could ever pass the length filter.
_FRAGMENT_DELIMITERS = re.compile(r'; |, |\.|\*|\n')


def load_sentense_data(path='reddit_short_stories.txt', min_sent_length=20):
    """Read stories from a text file and return their long fragments.

    Each line of the file is treated as one story. Only the text before the
    first ' _____________ ' separator is used. That text is stripped of
    surrounding whitespace, double quotes and spaces, split on the
    delimiters in ``_FRAGMENT_DELIMITERS``, and lower-cased.

    Args:
        path: File to read; defaults to 'reddit_short_stories.txt'.
        min_sent_length: Fragments must be strictly longer than this many
            characters to be kept.

    Returns:
        A list of lower-cased fragments, in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    all_sentences = []
    with open(path, 'r') as f:
        for story in f:
            text = story.split(' _____________ ')[0]
            # strip() first so a trailing newline does not shield a closing
            # quote from the second strip (the original strip('') was a no-op).
            text = text.strip().strip('" ')
            fragments = _FRAGMENT_DELIMITERS.split(text)
            all_sentences.extend(
                frag.lower() for frag in fragments
                if len(frag) > min_sent_length
            )
    return all_sentences


def story_model(preds, res_len=5, sentences_path="out.bin"):
    """Rank stored sentences by how strongly they match predicted words.

    Every sentence receives the sum of ``prob`` over all ``(word, prob)``
    pairs in ``preds`` whose ``word`` occurs as a substring of the sentence.

    Args:
        preds: Iterable of ``(word, prob)`` pairs.
        res_len: Maximum number of results to return; values <= 0 yield an
            empty list.
        sentences_path: Pickle file holding the list of sentences;
            defaults to "out.bin".
            NOTE(review): presumably this is the pickled output of
            load_sentense_data() - confirm against whatever writes the file.

    Returns:
        A list of up to ``res_len`` ``(score, sentence)`` tuples with the
        highest scores, in ascending score order (best match last). Ties
        keep the order the sentences have in the stored list.
        NOTE(review): the original trailing comment said "list of
        sentenses", but the expression built (score, sentence) pairs; the
        pairs are kept here - confirm what callers expect.

    Raises:
        OSError: If the pickle file cannot be opened.
    """
    with open(sentences_path, "rb") as fp:
        # SECURITY: pickle.load can execute arbitrary code from the file.
        # Only load files this project produced itself.
        all_sentences = pickle.load(fp)

    scores = np.zeros(len(all_sentences))
    for word, prob in preds:
        hits = np.array([word in sent for sent in all_sentences], dtype=float)
        scores += prob * hits

    if res_len <= 0:
        # Guard: a slice of [-0:] would return every sentence, not none.
        return []

    # The original called list.sort() (which returns None) and keyed on the
    # sentence text; rank by score instead. A stable sort keeps tie order.
    best = np.argsort(scores, kind="stable")[-res_len:]
    return [(float(scores[i]), all_sentences[i]) for i in best]