import argparse
import sys

import torch
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
)


class AdvancedSummarizer:
    def __init__(self, model_name="facebook/bart-large-cnn"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = BartForConditionalGeneration.from_pretrained(model_name).to(self.device)
        self.tokenizer = BartTokenizer.from_pretrained(model_name)

    def summarize(self, text, max_length=150, min_length=50, length_penalty=2.0, num_beams=4):
        # Tokenize and truncate to BART's 1024-token input window
        inputs = self.tokenizer([text], max_length=1024, return_tensors="pt", truncation=True)
        inputs = inputs.to(self.device)

        summary_ids = self.model.generate(
            inputs["input_ids"],
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
            length_penalty=length_penalty
        )

        summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return summary

def main_summarizer():
    # Example usage
    summarizer = AdvancedSummarizer()
    text = """..."""  # Your text here
    summary = summarizer.summarize(text)
    print("Summary:")
    print(summary)
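Note that summarize truncates its input at 1024 tokens, so anything past that window is silently dropped. A minimal sketch of one possible workaround is to summarize fixed-size token chunks and join the partial summaries; the summarize_long helper and the 900-token chunk size below are illustrative assumptions, not part of this model card's API.

def summarize_long(summarizer, text, chunk_tokens=900):
    # Hypothetical helper: split the token sequence into chunks that fit
    # BART's 1024-token window, summarize each chunk independently, and
    # join the partial summaries into one string.
    token_ids = summarizer.tokenizer.encode(text, truncation=False)
    chunks = [token_ids[i:i + chunk_tokens] for i in range(0, len(token_ids), chunk_tokens)]
    partial_summaries = []
    for ids in chunks:
        chunk_text = summarizer.tokenizer.decode(ids, skip_special_tokens=True)
        partial_summaries.append(summarizer.summarize(chunk_text))
    return " ".join(partial_summaries)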

class AdvancedTextGenerator:
    def __init__(self, model_name="gpt2-medium"):
        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"Using device: {self.device}")
            self.model = GPT2LMHeadModel.from_pretrained(model_name).to(self.device)
            self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        except Exception as e:
            print(f"Error initializing the model: {e}")
            sys.exit(1)

    def generate_text(self, prompt, max_length=100, num_return_sequences=1,
                      temperature=1.0, top_k=50, top_p=0.95, repetition_penalty=1.0):
        try:
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

            output_sequences = self.model.generate(
                input_ids=input_ids,
                # max_length counts the prompt tokens, so extend it by the prompt length
                max_length=max_length + len(input_ids[0]),
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                num_return_sequences=num_return_sequences,
            )

            generated_sequences = []
            for generated_sequence in output_sequences:
                text = self.tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
                # Strip the prompt so only the newly generated continuation is returned
                total_sequence = text[len(self.tokenizer.decode(input_ids[0], clean_up_tokenization_spaces=True)):]
                generated_sequences.append(total_sequence)

            return generated_sequences
        except Exception as e:
            return [f"Error during text generation: {e}"]
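GPT-2 ships without a padding token, so transformers typically logs a "Setting pad_token_id to eos_token_id" warning during generation. A common way to make this explicit, shown here as a standalone sketch rather than a change to the class above, is to reuse the EOS token as padding and pass the tokenizer's attention mask:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
model = GPT2LMHeadModel.from_pretrained("gpt2-medium")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
output = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],  # mark real tokens explicitly
    pad_token_id=tokenizer.eos_token_id,      # reuse EOS as the pad token
    max_length=50,
    do_sample=True,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))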

def main_generator():
    parser = argparse.ArgumentParser(description="Advanced Text Generator")
    parser.add_argument("--prompt", type=str, help="Starting prompt for text generation")
    parser.add_argument("--max_length", type=int, default=100, help="Maximum length of generated text")
    parser.add_argument("--num_sequences", type=int, default=1, help="Number of sequences to generate")
    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")
    parser.add_argument("--top_k", type=int, default=50, help="Top-k sampling parameter")
    parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling parameter")
    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="Repetition penalty")

    args = parser.parse_args()

    generator = AdvancedTextGenerator()

    if args.prompt:
        prompt = args.prompt
    else:
        print("Please enter the prompt for text generation:")
        prompt = input().strip()

    generated_texts = generator.generate_text(
        prompt,
        max_length=args.max_length,
        num_return_sequences=args.num_sequences,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty
    )

    print("\nGenerated Text(s):")
    for i, text in enumerate(generated_texts, 1):
        print(f"\n--- Sequence {i} ---")
        print(text)

if __name__ == "__main__":
    main_summarizer()  # Call the summarizer main function
    main_generator()   # Call the text generator main function
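Because generate_text samples (do_sample=True), repeated runs on the same prompt produce different continuations. If repeatable output is needed, one option is to seed the random number generators before generating; the seed value below is arbitrary:

from transformers import set_seed

set_seed(42)  # seeds Python's random, NumPy, and torch so sampling is repeatable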

