|
import gradio as gr |
|
from transformers import pipeline |
|
from flask import Flask, request, jsonify |
|
import os |
|
from huggingface_hub import login |
|
|
|
|
|
# Hugging Face access token, read from the environment.
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

# Treat an empty value (e.g. `HUGGINGFACE_TOKEN=` in a .env file) the same as
# an unset variable: neither gives a usable credential, and failing here is
# clearer than failing later inside login().
if not HF_TOKEN:
    raise ValueError("Hugging Face token is not set in environment variables")

# Authenticate with the Hub before the pipeline below downloads the model
# (presumably required because the repo is private/gated — confirm).
# NOTE(review): add_to_git_credential=True also stores the token in the local
# git credential helper, which persists it beyond this process.
try:
    login(token=HF_TOKEN, add_to_git_credential=True)
except ValueError as e:
    print(f"Error during login: {e}")
    raise
|
|
|
|
|
# Hub repository of the text-generation model served by this app.
model_id = "rish13/polymers2"
# Built once at import time. generate_text() uses it, and both the Gradio UI
# and the Flask /search endpoint go through generate_text().
model = pipeline(task='text-generation', model=model_id)
|
|
|
|
|
def generate_text(prompt):
    """Run the module-level text-generation pipeline on *prompt*.

    Args:
        prompt: Input passed unchanged to ``model``.

    Returns:
        The ``'generated_text'`` field of the first result the pipeline yields.
    """
    outputs = model(prompt)
    first_output = outputs[0]
    return first_output['generated_text']
|
|
|
# Simple web UI: one text box in, one text box out, backed by generate_text().
gradio_interface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
)


# Flask application exposing the JSON API routes defined below.
app = Flask(import_name=__name__)
|
|
|
@app.route('/search', methods=['POST'])
def predict_endpoint():
    """Generate text for the prompt supplied in a JSON POST body.

    Expects a JSON object such as ``{"prompt": "..."}``. A missing ``prompt``
    key falls back to the empty string.

    Returns:
        ``{"result": <generated text>}`` with status 200 on success, or
        ``{"error": <message>}`` with status 400 when the body is not a JSON
        object or ``prompt`` is not a string.
    """
    # silent=True returns None instead of raising/aborting when the body is
    # missing, malformed, or not sent as JSON. That lets us answer with a
    # clear 400 instead of crashing on `None.get(...)` with a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    prompt = data.get('prompt', '')
    # A list or other non-string would be passed to the pipeline as-is. The
    # result would then not match the `[0]['generated_text']` shape that
    # generate_text() indexes into, so reject it up front.
    if not isinstance(prompt, str):
        return jsonify({"error": "'prompt' must be a string"}), 400

    generated_text = generate_text(prompt)
    return jsonify({"result": generated_text})
|
|
|
@app.route('/')
def home():
    """Serve a plain-text greeting that points clients at the /search route."""
    welcome = "Welcome to the text generation API. Use the /search endpoint to generate text."
    return welcome
|
|
|
if __name__ == "__main__":
    from threading import Thread

    # Gradio's launch() blocks its calling thread until it sees a
    # KeyboardInterrupt, so it runs in a background thread while Flask serves
    # from the main thread. Only the main thread receives KeyboardInterrupt.
    # A non-daemon Gradio thread would therefore keep the process alive after
    # Flask stops. daemon=True lets the interpreter exit with the main thread.
    gradio_thread = Thread(
        target=lambda: gradio_interface.launch(share=False, inbrowser=True),
        name="gradio-ui",
        daemon=True,
    )
    gradio_thread.start()

    # Serve the Flask API on all interfaces; blocks until the server stops.
    app.run(host='0.0.0.0', port=5000)
|
|