Model Card for mfsadi/Llama-3.1-8B-spelling-fa
A Persian spelling-correction model: a PEFT adapter on top of Llama 3.1 8B Instruct, trained on queries that users submitted to the Basalam.com marketplace.
Model Details
Model Description
# Instruction prompt placed before each query at inference time. It asks the
# model to answer with a {"correction": ...} JSON object, using an empty string
# when the query is already correct.
Base_prompt = """You are tasked with correcting spelling mistakes in the queries that users submitted to a Persian marketplace.
Output the corrected query in the following JSON format:
- If the input requires correction, use:
{"correction": "<corrected version of the query>"}
- If the input is correct, use:
{"correction": ""}
Here are some examples:
"query": "ندل چسبی زنانه" Your answer: {"correction": "صندل چسبی زنانه"}
"query": "بادکنک جشن تواد" Your answer: {"correction": "بادکنک جشن تولد"}
"query": "صندلی بادی" Your answer: {"correction": ""}\n"""
# Alias kept for code that refers to the prompt as `spell_checking_prompt`.
spell_checking_prompt = Base_prompt
Uses
It is intended for correcting spelling mistakes in Persian text, such as the queries users submit to a Persian online marketplace.
Direct Use
# Output structuring: pull the correction out of the raw generated text.
import json
import re

# Start of a {"correction": ...} object. Whitespace is tolerated around the
# key because generated text does not always match the prompt's spacing.
_CORRECTION_START = re.compile(r'\{\s*"correction"\s*:')


def extract_json(text):
    """Return the correction contained in the model's generated ``text``.

    Scans ``text`` left to right for JSON objects whose first key is
    ``"correction"`` and decodes each one in turn. Positions that do not
    decode as JSON are skipped.

    Returns:
        The first truthy ``"correction"`` value found. If every decoded object
        has a falsy value (e.g. ``""``, meaning no change is needed), the value
        from the last decoded object. ``None`` if no such object decodes.
        If scanning fails unexpectedly (e.g. ``text`` is not a ``str``),
        ``text`` itself is returned unchanged as a best-effort fallback.
    """
    try:
        decoder = json.JSONDecoder()
        correction = None
        pos = 0
        while True:
            found = _CORRECTION_START.search(text, pos)
            if found is None:
                break
            start = found.start()
            try:
                # raw_decode accepts trailing text after the object and returns
                # the end index of the object within the slice.
                result, length = decoder.raw_decode(text[start:])
            except json.JSONDecodeError:
                # Not valid JSON here; resume the scan just past this match.
                pos = start + 1
                continue
            correction = result.get("correction")
            if correction:
                return correction
            pos = start + length
        return correction
    except Exception:
        # Deliberate best-effort: hand the raw text back instead of failing.
        return text
# Load model
# Also requires (not shown in this card):
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from peft import PeftModel
import torch

BASE_MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Spelling-correction adapter applied on top of the base model via PEFT.
model_name_or_path = "mfsadi/Llama-3.1-8B-spelling-fa"

# Use the GPU when one is available. Inference inputs must be moved to this
# same device, so the model is placed on it here.
device = "cuda" if torch.cuda.is_available() else "cpu"

base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH, return_dict=True)
spelling_model = PeftModel.from_pretrained(base_model, model_name_or_path)
spelling_model.to(device)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
# Inference: call correct_spelling(query) with the user's query string.
import torch


def correct_spelling(query, max_new_tokens=50):
    """Ask the spelling model to correct ``query`` and parse its reply.

    Builds the prompt from ``Base_prompt`` and ``query``, generates at most
    ``max_new_tokens`` new tokens, strips the prompt tokens from the output,
    and returns ``extract_json`` applied to the decoded continuation. Per the
    prompt's format, that is normally the corrected query, or ``""`` when the
    model reports the query is already correct.

    Sampling is enabled (``do_sample=True``), so repeated calls may differ.

    Args:
        query: The user query to check.
        max_new_tokens: Upper bound on the number of generated tokens.
    """
    prompt = f"""### Human: {Base_prompt} query: {query}\n ### Assistant:"""
    # NOTE(review): this tokenizes the repr of a one-element list
    # ("['### Human: ...']", with newlines shown as a literal backslash-n),
    # not the prompt itself. Kept as-is in case the adapter was trained on this
    # exact formatting; confirm against the training code before changing it.
    batch = tokenizer(str([prompt]), return_tensors="pt")
    # Put the inputs on whichever device the model's weights live on.
    batch = batch.to(spelling_model.device)
    prompt_length = batch["input_ids"].shape[1]
    with torch.no_grad():
        output_tokens = spelling_model.generate(
            **batch,
            max_new_tokens=max_new_tokens,
            repetition_penalty=1.1,
            do_sample=True,
            num_beams=2,
            temperature=0.1,
            top_k=10,
            top_p=0.5,
            length_penalty=-1,
        )
    # The generated sequence starts with the prompt; decode only the new part.
    output = tokenizer.decode(output_tokens[0][prompt_length:], skip_special_tokens=True)
    return extract_json(output)
Model Card Contact
Majid F. Sadi