"""Play scripted multi-turn prompts against a remote text-generation endpoint.

The endpoint address is read from the ``PAPERSPACE_IP`` environment variable
(optionally loaded from a ``.env`` file via python-dotenv).
"""
import contextlib
import os

from dotenv import load_dotenv
from text_generation import Client

load_dotenv()

PAPERSPACE_IP = os.getenv("PAPERSPACE_IP")
# NOTE(review): if PAPERSPACE_IP is unset this passes None to Client; the
# failure presumably only surfaces on the first request — confirm.
client = Client(PAPERSPACE_IP)


def generate_text(input_text, max_new_tokens=20, temperature=1):
    """Send one prompt to ``client`` and return only the generated text.

    Args:
        input_text: Full prompt string sent to the endpoint.
        max_new_tokens: Upper bound on tokens generated for this request.
        temperature: Sampling temperature forwarded to ``client.generate``.

    Returns:
        The ``generated_text`` field of the client's response.
    """
    return client.generate(
        input_text, max_new_tokens=max_new_tokens, temperature=temperature
    ).generated_text


def generate_multi_text(input_text, file_path, max_new_tokens=20, temperature=1,
                        out_path=None, earlystop=None):
    """Run one generation per line of ``file_path``, each prefixed with ``input_text``.

    Every line of ``file_path`` is stripped and wrapped by ``formatter`` into a
    ``[User]: ... [You]:`` turn. Each turn is generated independently, so no
    earlier responses are fed back into the prompt.

    Args:
        input_text: Shared prompt prefix prepended to every turn.
        file_path: Text file with one user utterance per line.
        max_new_tokens: Forwarded to ``generate_text``.
        temperature: Forwarded to ``generate_text``.
        out_path: If given, each response is written there as ``Turn N: ...``.
            If None, nothing is written to disk. (Previously None crashed
            in ``open``.)
        earlystop: If not None, only the first ``earlystop`` lines are used.

    Returns:
        List of generated responses, one per turn, in file order.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        rows = file.readlines()
    if earlystop is not None:
        rows = rows[:earlystop]
    multi_turns = [formatter(row.strip()) for row in rows]
    print("You are playing " + str(len(multi_turns)) + " turns.")

    if out_path is not None:
        # Ensure the destination folder exists; a bare filename has no dirname.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        out_ctx = open(out_path, "w", encoding="utf-8")
    else:
        # nullcontext() yields None, which disables file writing below.
        out_ctx = contextlib.nullcontext()

    generated_text = []
    with out_ctx as out_file:
        for i, turn in enumerate(multi_turns):
            single_turn_resp = generate_text(
                input_text + turn,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
            )
            generated_text.append(single_turn_resp)
            if out_file is not None:
                out_file.write(f"Turn {i+1}: {single_turn_resp}\n")
            print(turn)
            print(single_turn_resp)
            print("-----------")
    return generated_text


def read_text_file(file_path):
    """Return the entire contents of ``file_path`` as a string (UTF-8)."""
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()


def formatter(user_prompt):
    """Wrap a user utterance in the ``[User]: ... [You]:`` turn template."""
    return f"[User]: {user_prompt.strip()} \n [You]: \n"


def main():
    """Run the inappropriate-prompt batch with the attitude system prompt.

    Paths are resolved relative to the current working directory.
    """
    cwd = os.getcwd()
    input_text = read_text_file(os.path.join(cwd, 'utils/prompts/prompt_attitude.txt'))
    max_new_tokens = 40
    temperature = 0.3
    multi_path = os.path.join(cwd, 'inappropriate.txt')
    out_path = os.path.join(cwd, f'utils/user_turns/multi_turns_conversation_t{temperature}_m{max_new_tokens}_promptatt_mistral_inapp.txt')
    generate_multi_text(
        input_text,
        multi_path,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        out_path=out_path,
    )


if __name__ == "__main__":
    main()