Model Card for Llama-3.1-8B-spelling-fa

A Persian spelling correction model based on Llama 3.1 Instruct. It was fine-tuned on user search queries submitted to the Basalam.com marketplace.

Model Details

Model Description

spell_checking_prompt = """You are tasked with correcting spelling mistakes in the queries that users submitted to a Persian marketplace.

Output the corrected query in the following JSON format:
- If the input requires correction, use:
  {"correction": "<corrected version of the query>"}
- If the input is correct, use:
  {"correction": ""}
Here are some examples:
"query": "ندل چسبی زنانه" Your answer: {"correction": "صندل چسبی زنانه"}
"query": "بادکنک جشن تواد"  Your answer: {"correction": "بادکنک جشن تولد"}
"query": "صندلی بادی"  Your answer: {"correction": ""}\n"""

Uses

It is intended for correcting spelling mistakes in short Persian text, in particular e-commerce search queries.

Direct Use

import json

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Output structuring: extract the "correction" field from the model's raw completion
def extract_json(text):
    """Scan `text` for the first {"correction": ...} object and return its value."""
    try:
        correction = None
        pos = 0
        decoder = json.JSONDecoder()
        while pos < len(text):
            match = text.find('{"correction":', pos)
            if match == -1:
                break
            try:
                # raw_decode parses one JSON object and reports where it ends
                result, index = decoder.raw_decode(text[match:])
                correction = result.get('correction')
                if correction:
                    return correction
                pos = match + index
            except json.JSONDecodeError:
                # Not valid JSON at this offset; keep scanning
                pos = match + 1
        return correction
    except Exception:
        # Fall back to the raw text if anything unexpected happens
        return text
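A quick sanity check of extract_json on a completion that has extra text around the JSON (the wrapper text is illustrative):

sample = 'Your answer: {"correction": "صندل چسبی زنانه"} '
print(extract_json(sample))  # -> صندل چسبی زنانه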

        
# Load the base model and apply the spelling-correction adapter
BASE_MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model_name_or_path = "mfsadi/Llama-3.1-8B-spelling-fa"

device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH, return_dict=True)
spelling_model = PeftModel.from_pretrained(base_model, model_name_or_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
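Optionally, the adapter weights can be merged into the base model for slightly faster inference (a sketch using peft's merge_and_unload; this step is not required for the snippet below):

# Optional: fold the LoRA weights into the base model
spelling_model = spelling_model.merge_and_unload()
spelling_model.eval()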


# Inference: `query` is the user query string to be corrected
query = "ندل چسبی زنانه"  # example query from the prompt above
prompt = f"""### Human: {spell_checking_prompt} query: {query}\n ### Assistant:"""
batch = tokenizer(prompt, return_tensors='pt').to(device)
prompt_length = len(batch['input_ids'][0])
max_new_tokens = 50
with torch.no_grad():
    output_tokens = spelling_model.generate(**batch, max_new_tokens=max_new_tokens,
                                            repetition_penalty=1.1,
                                            do_sample=True,
                                            num_beams=2,
                                            temperature=0.1,
                                            top_k=10,
                                            top_p=0.5,
                                            length_penalty=-1
                                            )
    # Decode only the newly generated tokens, then pull out the JSON correction
    output = tokenizer.decode(output_tokens[0][prompt_length:], skip_special_tokens=True)
    correction = extract_json(output)
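For repeated use, the steps above can be wrapped in a helper (a minimal sketch; correct_query is a hypothetical name, not part of the released code):

def correct_query(query):
    # Hypothetical convenience wrapper around the snippet above
    prompt = f"""### Human: {spell_checking_prompt} query: {query}\n ### Assistant:"""
    batch = tokenizer(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        output_tokens = spelling_model.generate(**batch, max_new_tokens=50)
    output = tokenizer.decode(output_tokens[0][len(batch['input_ids'][0]):],
                              skip_special_tokens=True)
    correction = extract_json(output)
    return correction or query  # empty correction means the input was already correct

print(correct_query("بادکنک جشن تواد"))  # expected: بادکنک جشن تولد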

Model Card Contact

Majid F. Sadi

[email protected]

https://www.linkedin.com/in/mfsadi/
