|
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration |
|
from datasets import load_dataset |
|
import torch |
|
def load_rag_model(model_name="nklomp/rag-example", dataset_name="your_dataset"):
    """Load a RAG tokenizer, retriever, and generator model.

    Args:
        model_name: Hugging Face Hub id used for the tokenizer, the
            retriever configuration, and the generator weights.
        dataset_name: Dataset id passed to ``load_dataset`` and handed to
            the retriever.  NOTE(review): the original hard-coded
            ``"your_dataset"`` is a placeholder — supply a real dataset id
            before running this in earnest.

    Returns:
        A ``(tokenizer, model)`` tuple; the retriever is attached to the
        model so callers only need these two objects.
    """
    tokenizer = RagTokenizer.from_pretrained(model_name)
    # The retriever needs the indexed passages dataset to look up documents.
    retriever = RagRetriever.from_pretrained(
        model_name, dataset=load_dataset(dataset_name)
    )
    model = RagTokenForGeneration.from_pretrained(model_name, retriever=retriever)
    return tokenizer, model
|
|
|
def query_model(tokenizer, model, query):
    """Generate an answer for *query* with a RAG tokenizer/model pair.

    Tokenizes the query, runs generation without gradient tracking
    (inference only), and returns the first generated sequence decoded
    to text with special tokens stripped.
    """
    encoded = tokenizer(query, return_tensors="pt")
    with torch.no_grad():
        generated = model.generate(**encoded)
    answer = tokenizer.decode(generated[0], skip_special_tokens=True)
    return answer
|
|
|
|
|
def main():
    """Script entry point: load the RAG model and answer a sample query."""
    tokenizer, model = load_rag_model()
    user_query = "I am looking for companies that can handle a large construction project."
    response = query_model(tokenizer, model, user_query)
    print(response)


# Guard the entry point so importing this module does not trigger the
# (slow, network-bound) model download and generation as a side effect.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|