"""Gradio demo that solves math problems with tool-integrated reasoning.

The model writes Python code blocks, the app executes them in a subprocess,
feeds the output back to the model, and streams the growing solution to the UI.
"""

import gradio as gr
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from huggingface_hub import InferenceClient
import os
import re
import subprocess
import tempfile
import json
import datasets
from datasets import load_dataset
from datasets import Value, Features
import random
import time
from typing import Tuple, Dict, Any, List
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
#from openai import OpenAI
import base64
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import AutoTokenizer, AutoModelForPreTraining

#client = OpenAI(
#    base_url=os.environ.get("SERVER_URL"),
#    api_key=os.environ.get("HF_TOKEN"),
#)

client = InferenceClient("mistralai/mathstral-7B-v0.1")


@dataclass
class Config:
    """Generation and evaluation settings for one run of the solver."""

    debug: bool = False
    push_to_hub: bool = False
    model_id: str | None = None
    revision: str | None = None
    system_prompt: str | None = None
    validation_set: str | None = None
    is_quantized: bool = False
    restart_on_fail: bool = False
    is_submission: bool = False
    num_samples: int = 1
    num_generations: int = 1
    do_sample: bool = True
    temperature: float = 1.0
    top_p: float = 0.9
    top_k: int = 50
    max_new_tokens: int = 100


# Load the pre-trained math BERT model and tokenizer.
# NOTE(review): neither object is referenced anywhere else in this file; the
# download/load cost is paid at import time. Confirm they are still needed.
tokenizer = AutoTokenizer.from_pretrained("AnReu/math_pretrained_bert")
model = AutoModelForPreTraining.from_pretrained("AnReu/math_pretrained_bert")


class PythonREPL:
    """Run a snippet of Python in a fresh ``python3`` subprocess.

    SECURITY: the snippets come from a language model and are executed
    without sandboxing beyond the naive keyword filter in
    ``execute_completion``. Do not expose this to untrusted deployments
    without real isolation.
    """

    def __init__(self, timeout=5):
        # Wall-clock limit (seconds) applied to both the subprocess and the
        # worker thread that waits for it.
        self.timeout = timeout

    @staticmethod
    def _trim_traceback(stderr: str, temp_dir: str, temp_file_path: str) -> str:
        """Reduce a traceback to its header, the script frames and the final error."""
        msgs = stderr.strip().split("\n")
        kept = []
        want_next = False
        for m in msgs:
            if "Traceback" in m or m == msgs[-1]:
                kept.append(m)
            elif temp_file_path in m:
                # Hide the random temporary directory but keep the file name and
                # line number. The previous slicing logic had an inverted
                # condition and cut the frame down to just `File "`.
                kept.append(m.replace(temp_dir + os.sep, ""))
                want_next = True
            elif want_next:
                # The source line that follows a `File ...` frame.
                kept.append(m)
                want_next = False
        return "\n".join(kept).strip()

    def execute(self, query: str) -> Tuple[bool, str]:
        """Execute ``query`` and return ``(success, output_or_error)``.

        Common math imports are prepended, and the last line is wrapped in
        ``print(...)`` when it does not already print, so that a trailing
        expression produces output.
        """
        query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
        lines = query.strip().split("\n")
        if "print(" not in lines[-1]:
            if "#" in lines[-1]:
                # Drop a trailing comment so it does not swallow the closing paren.
                lines[-1] = lines[-1].split("#")[0]
            lines[-1] = "print(" + lines[-1] + ")"
        script = "\n".join(lines)

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            with open(temp_file_path, "w") as f:
                f.write(script)
            try:
                result = subprocess.run(
                    ["python3", temp_file_path],
                    capture_output=True,
                    check=False,
                    text=True,
                    timeout=self.timeout,
                )
            except subprocess.TimeoutExpired:
                # TimeoutExpired is not a concurrent.futures.TimeoutError, so it
                # would otherwise escape every handler upstream.
                return False, f"Timed out after {self.timeout} seconds."

            if result.returncode == 0:
                return True, result.stdout.strip()
            return False, self._trim_traceback(result.stderr, temp_dir, temp_file_path)

    def __call__(self, query: str) -> Tuple[bool, str]:
        """Run ``execute`` in a worker thread, bounded by ``self.timeout``."""
        with ThreadPoolExecutor() as executor:
            future = executor.submit(self.execute, query)
            try:
                return future.result(timeout=self.timeout)
            except TimeoutError:
                return False, f"Timed out after {self.timeout} seconds."


def execute_completion(
    executor: PythonREPL,
    completion: str,
    return_status: bool = False,
    last_code_block: bool = False,
) -> str | Tuple[str, bool]:
    """Execute the ```python blocks found in ``completion``.

    Args:
        executor: Callable that runs code and returns ``(success, output)``.
        completion: Model output that may contain fenced Python blocks.
        return_status: When True, return ``(output, success)`` instead of just
            the output string.
        last_code_block: When True, only the final code block is executed.

    Returns:
        The output of the last executed block (and its success flag when
        ``return_status`` is set). If there is no code block, ``completion``
        itself is returned, with ``False`` as the status when requested.
    """
    # executions = ["!" + code for code in re.findall(r"```bash(.*?)```", completion, re.DOTALL) if "!" not in code]
    executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)

    if len(executions) == 0:
        # Directly return the chain-of-thought result. The parentheses matter:
        # without them the conditional only applied to the second tuple element.
        return (completion, False) if return_status else completion

    if last_code_block:
        executions = [executions[-1]]

    execution_outputs = []
    successes = []
    for code in executions:
        # Naive keyword filter; refuse rather than run suspicious snippets.
        blocked = next((word for word in ("subprocess", "venv") if word in code), None)
        if blocked is not None:
            execution_outputs.append(f"{blocked} is not allowed")
            successes.append(False)
            continue

        success = False
        try:
            success, output = executor(code)
        except TimeoutError as e:
            print("time out")
            output = str(e)

        if not success and not return_status:
            # Callers that do not ask for the status only want clean output.
            output = ""

        execution_outputs.append(output)
        successes.append(success)

    output = str(execution_outputs[-1]).strip()
    success = successes[-1]
    if return_status:
        return output, success
    return output


def postprocess_completion(
    text: str, return_status: bool = False, last_code_block=False, timeout=5
) -> str | Tuple[str, bool]:
    """Run the code blocks in ``text`` with a fresh ``PythonREPL``."""
    executor = PythonREPL(timeout=timeout)
    return execute_completion(executor, text, return_status=return_status, last_code_block=last_code_block)


def apply_template(example: Dict[str, Any], prompt: str) -> str:
    """Format ``prompt`` with the example's ``"prompt"`` field and return the string."""
    return prompt.format(example["prompt"], "{}")


def last_boxed_only_string(string):
    """
    Extracts the last LaTeX boxed or framed expression from a string.

    Args:
        string (str): The input string containing LaTeX expressions.

    Returns:
        str or None: The last boxed or framed expression, if found; otherwise, None.
    """
    idx = string.rfind("\\boxed")
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    # Walk forward and track brace depth to find the matching closing brace.
    depth = 0
    for i in range(idx, len(string)):
        if string[i] == "{":
            depth += 1
        if string[i] == "}":
            depth -= 1
            if depth == 0:
                return string[idx : i + 1]
    return None


def remove_boxed(s):
    """
    Removes the LaTeX boxed command, returning the content inside the braces.

    Args:
        s (str): The string containing a LaTeX boxed expression.

    Returns:
        str or None: The content inside the boxed command, if valid; otherwise, None.
    """
    left = "\\boxed{"
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not isinstance(s, str) or not s.startswith(left) or not s.endswith("}"):
        return None
    return s[len(left):-1]


def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
    """
    Extracts the answer from a LaTeX boxed expression within a prediction string.

    Args:
        pred_str (str): The string containing one or more LaTeX boxed expressions.
        strip_double_curly_brace (bool): If True, removes an additional layer of braces.

    Returns:
        str or None: The extracted answer, if any; otherwise, None.
    """
    boxed_str = last_boxed_only_string(pred_str)
    if boxed_str is None:
        return None
    answer = remove_boxed(boxed_str)
    if answer is None:
        return None
    if strip_double_curly_brace:
        match = re.match(r"^\{(.*)\}$", answer)
        if match:
            answer = match.group(1)
    return answer


def normalize_final_answer(final_answer: str) -> str:
    """
    Normalizes a final answer to a quantitative reasoning question by removing
    or replacing various LaTeX and text elements.

    Args:
        final_answer (str): The answer string to normalize.

    Returns:
        str: The normalized answer string.
    """
    match = re.search(r"(.*?)Problem:", final_answer, flags=re.S)
    if match:
        # Keep only the text that precedes the first "Problem:" marker.
        final_answer = match.group(1)

    # final_answer = final_answer.split('=')[-1]
    SUBSTITUTIONS = [
        ("an ", ""),
        ("a ", ""),
        (".$", "$"),
        ("\\$", ""),
        (r"\ ", ""),
        (" ", ""),
        ("mbox", "text"),
        (",\\text{and}", ","),
        ("\\text{and}", ","),
        ("\\text{m}", "\\text{}"),
        ("\\le", "<"),
    ]
    REMOVED_EXPRESSIONS = [
        "square",
        "ways",
        "integers",
        "dollars",
        "mph",
        "inches",
        "ft",
        "hours",
        "km",
        "units",
        "\\ldots",
        "sue",
        "points",
        "feet",
        "minutes",
        "digits",
        "cents",
        "degrees",
        "cm",
        "gm",
        "pounds",
        "meters",
        "meals",
        "edges",
        "students",
        "childrentickets",
        "multiples",
        "\\text{s}",
        "\\text{.}",
        "\\text{\ns}",
        "\\text{}^2",
        "\\text{}^3",
        "\\text{\n}",
        "\\text{}",
        r"\mathrm{th}",
        r"^\circ",
        r"^{\circ}",
        r"\;",
        r",\!",
        "{,}",
        '"',
        "\\dots",
        "\n",
        "\r",
        "\f",
        "\\%",
    ]
    for before, after in SUBSTITUTIONS:
        final_answer = final_answer.replace(before, after)
    for expr in REMOVED_EXPRESSIONS:
        final_answer = final_answer.replace(expr, "")

    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    final_answer = re.sub(r"(\\text\{)(.*?)(\})", r"\2", final_answer)
    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", r"\2", final_answer)
    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", r"\2", final_answer)
    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", r"\2", final_answer)

    # Each pattern is applied to the result of the previous one; the last
    # match wins. Spaces were removed above, hence "finalansweris".
    for pattern in (
        r"finalansweris(.*)",
        r"answer?is:?(.*)",
        r"oxed\{(.*?)\}",
        r"\$(.*?)\$",
    ):
        found = re.findall(pattern, final_answer)
        if found:
            final_answer = found[-1]
    final_answer = final_answer.strip()

    # Repair "\frac" that lost its prefix (e.g. "\f" removed as a form feed).
    if "rac" in final_answer and "\\frac" not in final_answer:
        final_answer = final_answer.replace("rac", "\\frac")

    # Normalize shorthand: "frac12" -> "frac{1}{2}", "sqrt2" -> "sqrt{2}".
    final_answer = re.sub(r"(frac)([^{])(.)", r"frac{\2}{\3}", final_answer)
    final_answer = re.sub(r"(sqrt)([^{])", r"sqrt{\2}", final_answer)
    final_answer = final_answer.replace("$", "")

    # Drop thousands separators from plain integers.
    if final_answer.replace(",", "").isdigit():
        final_answer = final_answer.replace(",", "")

    return final_answer


def naive_parse(answer: str) -> str:
    """
    Returns the last contiguous run of ASCII digits in the input string.

    Args:
        answer (str): The input string to parse.

    Returns:
        str: The digits of the last number in the input, in their original
        order, or an empty string if there are none.

    Example:
        >>> naive_parse("abc123def")
        '123'
        >>> naive_parse("def456ghi")
        '456'
        >>> naive_parse("no numbers here")
        ''
    """
    runs = re.findall(r"[0-9]+", answer)
    return runs[-1] if runs else ""


def validate_answer_is_numeric(x: str | int | float) -> int:
    """Convert ``x`` to the nearest integer, or return -1 if that is not sensible.

    -1 is returned when ``x`` cannot be parsed as a finite number or when it is
    further than ``FLOAT_TOLERANCE`` from the nearest integer.
    """
    FLOAT_TOLERANCE = 0.2
    try:
        value = float(x)
        rounded = round(value)
    except (TypeError, ValueError, OverflowError):
        return -1
    # Compare against the value *before* rounding; comparing after rounding
    # (as the code used to do) can never exceed the tolerance.
    if abs(rounded - value) > FLOAT_TOLERANCE:
        return -1
    return rounded


def filter_answers(answers: List[str]) -> List[int]:
    """Keep the non-negative numeric answers and reduce them modulo 1000."""
    formatted_answers = [validate_answer_is_numeric(a) for a in answers]

    # Filter for non-negative answers
    formatted_answers = [a for a in formatted_answers if a >= 0]
    # Compute modulo
    formatted_answers = [a % 1_000 for a in formatted_answers]
    # Range guard; always satisfied after the modulo above.
    formatted_answers = [a for a in formatted_answers if a <= 999]
    return formatted_answers


def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool:
    """Return True when two LaTeX expressions are equal according to sympy.

    Any parsing or evaluation failure is printed and treated as "not equal".
    """
    try:
        diff = simplify(parse_latex(ref_answer) - parse_latex(model_answer))
        return bool(diff.is_zero or -1e-12 < N(diff) < 1e-12)
    except Exception as e:
        # parse_latex/simplify can fail in many ways on model output.
        print(e)
        return False


def check_string_match(ref_answer: str, model_answer: str) -> bool:
    """Return True when the two answers are identical strings."""
    return ref_answer == model_answer


def check_answer(ref_answer: str, model_answer: str) -> bool:
    """Return True on an exact string match or, failing that, sympy equivalence."""
    # check if strings are the same
    if check_string_match(ref_answer, model_answer):
        return True

    # use the sympy library to check if the expressions are the same
    return check_sympy_equivalence(ref_answer, model_answer)


debug = False
# NOTE(review): this id looks garbled (compare "mathstral-7B-v0.1" used for the
# client above). It is only stored in the config and printed, so it is left
# untouched here - confirm the intended value.
model_id = "athstral-7B-v0.m1"
revision = "main"
system_prompt = "{}"
validation_set = "kaggle-validation-set-medium"
is_submission = True
num_samples = 4
num_generations = 4
temperature = 0.8
is_quantized = False
restart_on_fail = False
top_p = 1.0
top_k = 0
max_new_tokens = 2048
# Papermill related variables
push_to_hub = False
notebook_name = ""

config = Config(
    debug=False,
    push_to_hub=False,
    model_id=model_id,
    revision=revision,
    system_prompt=system_prompt,
    validation_set=validation_set,
    is_quantized=is_quantized,
    restart_on_fail=restart_on_fail,
    is_submission=is_submission,
    num_samples=num_samples,
    num_generations=num_generations,
    do_sample=True,
    temperature=temperature,
    top_p=top_p,
    top_k=top_k,
    max_new_tokens=max_new_tokens,
)
print(f"=== Running submission with config ===\n\n{config}")


def parse_data_chunk(data_chunk):
    """
    Parse a given data chunk string into a list of individual data entries.

    The function splits the input string by the delimiter "data:" and removes
    any leading or trailing whitespace from each resulting chunk. Empty chunks
    are filtered out from the final list.

    Parameters:
        data_chunk (str): The input string containing data chunks separated by "data:".

    Returns:
        list: A list of individual data entries with whitespace stripped.
    """
    stripped_chunks = (chunk.strip() for chunk in data_chunk.split("data:"))
    return [chunk for chunk in stripped_chunks if chunk]


def generate(message, temperature):
    """
    Streams a chat completion from the inference client.

    Generation stops at the "```output\\n" marker so that the caller can run
    the code the model just wrote and append the real output.

    Parameters:
        message (list): The chat messages to send to the model.
        temperature (float): The sampling temperature to use. Higher values
            mean the model will take more risks.

    Yields:
        tuple: ``(content, error)``. For normal output ``error`` is False and
        ``content`` is the next piece of response text. If the request fails,
        a single ``(error message, True)`` pair is yielded and the stream ends.
    """
    try:
        # stream=True makes the client return an iterator of typed chunks. The
        # previous code asked for stream=False and then read a raw
        # `.response` attribute, which this client does not provide.
        stream = client.chat.completions.create(
            model="tgi",
            messages=message,
            stream=True,
            max_tokens=1024,
            stop=["```output\n"],
            temperature=temperature,
            #timeout=30,
        )
        for chunk in stream:
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content, False
    except Exception as e:
        # Surface the failure to the UI through the documented error flag.
        print(f"func: generate error occurred\nerror:{e}")
        yield str(e), True


def get_majority_text(data):
    """Return the generated text that belongs to the most common model answer."""
    from collections import Counter

    # Count the frequency of each answer in model_answers
    answer_counts = Counter(data["model_answers"])

    # Find the majority response
    majority_response = answer_counts.most_common(1)[0][0]

    # Find the index of the first occurrence of the majority response
    majority_index = data["model_answers"].index(majority_response)

    # Return the corresponding text in gen_texts
    return data["gen_texts"][majority_index]


def extract_solution(text):
    """Return the text after the first "### Solution:" marker, or "" if absent."""
    _, marker, tail = text.partition("### Solution:")
    return tail.strip() if marker else ""


def process_code(
    example: Dict[str, Any],
    config: Config,
    restart_on_fail: bool = False,
    last_step: bool = False,
) -> Dict[str, Any]:
    """Advance one sample by a reasoning step.

    Depending on the generated text this either records a final answer and
    marks the sample for pruning, restarts or prunes a broken sample, or
    executes the last code block and appends its output for the next round.

    Args:
        example: The sample dict; mutated in place and returned.
        config: Run configuration (``is_submission`` controls answer checking).
        restart_on_fail: Reset the text instead of pruning when generation fails.
        last_step: True on the final generation round.

    Returns:
        The updated ``example``.
    """
    gen_text = example["gen_texts"]
    num_python_blocks = len(re.findall(r"```python(.*?)```", gen_text, re.DOTALL))

    if num_python_blocks == 0:
        if restart_on_fail:
            print("no code has ever been generated, RESTARTING")
            # reset the text to the original
            example["gen_texts"] = example["text"]
        else:
            print("no code has ever been generated, STOP")
            example["should_prune"] = True
            example["has_code"] = False
        return example

    tail = gen_text[-100:]
    awaiting_output = gen_text.endswith("```output\n")

    if not awaiting_output and ("answer is" in tail or "\\boxed" in tail):
        num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL))
        if num_output_blocks == 0:
            print("the model hallucinated the code answer")
            example["should_prune"] = True
            return example

        if "boxed" in tail:
            try:
                answer = normalize_final_answer(extract_boxed_answer(tail))
            except Exception:
                # e.g. the box was cut off by the 100-character window.
                answer = "-1"
        else:
            answer = normalize_final_answer(tail)

        example["model_answers"] = answer
        if not config.is_submission:
            example["corrects"] = check_answer(example["ground_truth"], answer)
        example["should_prune"] = True
        print("Answer is: ", answer, example["ground_truth"], example["corrects"])
        return example

    if last_step:
        # no point in continuing if we are at the last step
        return example

    if not awaiting_output:
        # something else has gone wrong with the generation
        print("warning: output block not found: ", gen_text[-40:])
        if restart_on_fail:
            example["gen_texts"] = example["text"]
        else:
            example["should_prune"] = True
        return example

    code_result, _status = postprocess_completion(gen_text, return_status=True, last_code_block=True)
    # add the code result for the next round of generation
    TRUNCATION_LIMIT = 200
    if len(code_result) > TRUNCATION_LIMIT:
        code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)"
    example["gen_texts"] = gen_text + f"{code_result}\n```"

    return example


def solve_problem(problem, temperature, progress=gr.Progress()):
    """
    Solve ``problem`` step by step, streaming the solution text.

    Yields:
        tuple: ``(text_so_far, stop)`` where ``stop`` is True only for the
        final item.
    """
    problem = apply_template({"prompt": problem}, prompt=config.system_prompt)
    print(f"Problem: {problem}")

    sample = {
        "problem": problem,  # not used for the submission TODO Remove
        "ground_truth": "unknown",  # not used for the submission TODO Remove
        "text": "## Solution:\n",
        "gen_texts": "",  # used to store all the generated text
        "should_prune": False,
        "problem_index": -1,  # not used for the submission TODO Remove
        "model_answers": "-1",
        "has_code": True,
        "corrects": False,  # not used for the submission TODO Remove
    }

    for step in progress.tqdm(
        range(config.num_generations), desc="Generating candidates"
    ):  # Depth of the tree (e.g. 6 steps = 5 code blocks)
        step_response = sample["gen_texts"]

        messages = [
            {"role": "user", "content": sample["problem"]},
            {"role": "assistant", "content": sample["gen_texts"]},
        ]

        stop = False

        for response_message, error in generate(messages, temperature):
            if response_message is not None:
                step_response += response_message
                yield step_response, False

            if error:
                stop = True

        sample["gen_texts"] = step_response

        # TODO: Maybe it should just return the result of running the code
        sample = process_code(
            sample,
            config=config,
            restart_on_fail=config.restart_on_fail,
            last_step=(step == (config.num_generations - 1)),
        )
        sample["gen_texts"] = sample["gen_texts"] + "\n"

        # Stream whatever process_code appended (the code output), one
        # character at a time, so that the UI updates smoothly.
        run_code_response = sample["gen_texts"].replace(step_response, "")
        for output_char in run_code_response:
            step_response += output_char
            yield step_response, False

        if sample["should_prune"] or stop:
            break

    yield sample["gen_texts"], True


features = Features({
    'id': Value('int64'),
    'problem': Value('string'),
    'answer': Value('string'),
    #'prompt': Value('string'),  # Ensure this matches the actual data type of 'prompt' in your dataset
    #'level': Value('string')
})

# Now load the dataset using the defined schema
# NOTE(review): `use_auth_token` is deprecated in recent `datasets` releases in
# favour of `token`; switch once the minimum supported version is confirmed.
example_data = datasets.load_dataset(
    "AI-MO/aimo-validation-math-level-5",
    split="train",
    use_auth_token=os.environ.get("HF_DATASET_TOKEN", None),
    features=features  # Pass the schema definition here
)

# NOTE(review): absolute, environment-specific path; startup fails if it is missing.
with open("/teamspace/studios/this_studio/.lightning_studio/math/app.css", "r") as f:
    css = f.read()


latex_delimiters = [
    {"left": "[", "right": "]", "display": True},
]


def get_random_problem():
    """Return the problem statement of a randomly chosen example."""
    # Index directly instead of copying the whole dataset into a list per call.
    example = example_data[random.randrange(len(example_data))]
    return example["problem"]


def update_example_problem():
    """Pick a new example; returned twice for the visible and hidden components."""
    problem_example_text = get_random_problem()
    return problem_example_text, problem_example_text


def clear():
    """Reset input, temperature and output, and show a fresh example problem."""
    problem_example_text = get_random_problem()
    return "", 0.1, "", problem_example_text, problem_example_text


def preprocess_output(text):
    """Double the backslash of ``\\(`` and ``\\)`` delimiters in ``text``."""
    return text.replace(r"\(", r"\\(").replace(r"\)", r"\\)")


with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
    btn_list = []
    problem_input_ele_list = []

    problem_example_text = get_random_problem()

    with gr.Row(elem_classes="title"):
        gr.HTML("Math Olympiad Solver", elem_classes="title-content")

    with gr.Row(elem_classes="sub-title"):
        gr.HTML(
            "\nDemo of the Numina-Math-7B-TIR. Example data are drawn randomly from AMC12, year 2022-2023.\n",
            elem_classes="sub-title-content",
        )

    with gr.Row(elem_classes="main-area"):
        with gr.Column(scale=1, elem_classes="left"):
            with gr.Row(elem_classes="probelm-example-container"):
                with gr.Blocks(elem_classes="probelm-example-title"):
                    gr.HTML("Problem example", elem_classes="probelm-example-title-content")

                with gr.Blocks(elem_classes="action-container"):
                    another_btn = gr.Button(
                        "",
                        elem_classes="probelm-example-another",
                        icon="./static/images/reset.png",
                    )
                    copy_btn = gr.Button("Copy", elem_classes="probelm-example-copy")

            problem_example = gr.HTML(
                problem_example_text,
                elem_classes="probelm-example-content",
            )

            with gr.Row(elem_classes="probelm-input-container"):
                inp = gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=True)
                problem_markdown = gr.Markdown(
                    visible=False,
                    latex_delimiters=[
                        {"left": "[", "right": "]", "display": True},
                        {"left": "$", "right": "$", "display": False},
                        {"left": r"\(", "right": r"\)", "display": False},
                    ],
                )

                # Mirror the textbox into the (initially hidden) rendered view.
                inp.change(fn=lambda text: text, inputs=[inp], outputs=[problem_markdown])
                problem_input_ele_list.append(inp)
                problem_input_ele_list.append(problem_markdown)

            with gr.Accordion("Advanced Options", open=False):
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.1, label="Temperature")

            with gr.Row() as btn_area:
                btn_clear = gr.Button("Clear", elem_classes="clear-btn")
                btn_run = gr.Button("Run", elem_classes="run-btn")
                btn_list.append(btn_clear)
                btn_list.append(btn_run)

        with gr.Column(scale=1, elem_classes="right"):
            gr.HTML("Solution", elem_classes="solution-title-content")
            out = gr.Markdown(
                elem_classes="solution-content",
                latex_delimiters=[
                    {"left": "[", "right": "]", "display": True},
                    {"left": "$", "right": "$", "display": False},
                    {"left": r"\(", "right": r"\)", "display": False},
                ],
            )

    problem_example_text_hidden = gr.Markdown(value=problem_example_text, visible=False)

    def solve_problem_wrapper(inp_text, temperature):
        """Stream solver output to the UI and toggle the run button's state."""
        new_running_btn = gr.Button("", elem_classes="run-btn running-btn")

        for after_tokens, stop in solve_problem(inp_text, temperature):
            yield preprocess_output(after_tokens), new_running_btn

            if stop:
                # Restore the normal "Run" button once the solver is done.
                btn_run = gr.Button("Run", elem_classes="run-btn")
                yield preprocess_output(after_tokens), btn_run

    def mount_run_btn(btn):
        """Wire the run button to the solver and to the input-view switch."""
        btn.click(fn=solve_problem_wrapper, inputs=[inp, temperature], outputs=[out, btn_list[1]])
        btn.click(get_run_after_problem_input, None, outputs=problem_input_ele_list)

    def get_run_after_problem_input():
        """Hide the textbox and show the rendered problem while solving."""
        return gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=False), gr.Markdown(
            visible=True,
            latex_delimiters=[
                {"left": "[", "right": "]", "display": True},
                {"left": "$", "right": "$", "display": False},
            ],
            elem_classes="problem-input-markdown",
        )

    def get_init_problem_input():
        """Show the editable textbox again and hide the rendered problem."""
        return gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=True), gr.Markdown(
            visible=False,
            latex_delimiters=[
                {"left": "[", "right": "]", "display": True},
                {"left": "$", "right": "$", "display": False},
            ],
        )

    copy_btn.click(fn=lambda example: example, inputs=[problem_example_text_hidden], outputs=[inp])

    btn_clear.click(
        fn=clear,
        inputs=[],
        outputs=[
            inp,
            temperature,
            out,
            problem_example,
            problem_example_text_hidden,
        ],
    )
    btn_clear.click(get_init_problem_input, None, outputs=problem_input_ele_list)

    mount_run_btn(btn_run)

    demo.load(
        update_example_problem,
        inputs=None,
        outputs=[
            problem_example,
            problem_example_text_hidden,
        ],
    )

    another_btn.click(
        fn=update_example_problem,
        inputs=[],
        outputs=[
            problem_example,
            problem_example_text_hidden,
        ],
    )

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=5).launch(share=True)