# Custom Hugging Face Inference Endpoints handler by ceiteach
# (commit fa1167f, "Created custom Inference Handler", 12.4 kB).
import json
import logging

# Unsloth first: it patches transformers at import time.
from unsloth import FastLanguageModel
import torch
logger = logging.getLogger(__name__)

# Static middle section of the system instruction: tool schemas, the metric
# name -> metric id map and the usage rules. Kept byte-for-byte as in the
# original handler. Presumably this is the wording the model was fine-tuned on,
# so edits here may change model behaviour.
# NOTE(review): the text has quirks that are deliberately left untouched:
# unbalanced `"` section delimiters, a rule mentioning a "removeHighlight" tool
# that is not defined, duplicate metric names (e.g. 'Vitamin A', 'Zinc'),
# 'LDL/HDL Ratio' mapped to `LDL - HDL`, the odd "fourth data" -> "fort data"
# example, and the default previous metric CPFORD, which is not in this map.
_INSTRUCTION_TOOLS_METRICS_AND_RULES = """You have access to the following functions. Your response must use at least one of these functions.
Functions: "
- {"name":"draw","description":"Use this tool to add or update a dataset in the chart.","parameters":{"type":"OBJECT","description":"The parameters for the draw function.","properties":{"chartType":{"type":"STRING","description":"The type of chart to draw. It can be either 'line' or 'bar'. If unspecified use the previous chart type which will be provided in the prompt."},"metric":{"type":"STRING","description":"The metric to be charted. If unspecified use the previous metric which will be provided in the prompt."}},"required":["chartType","metric"]}}
- {"name":"updateDateRange","description":"Use this tool to update the date range of the chart.","parameters":{"type":"OBJECT","description":"The parameters for the updateDateRange function.","properties":{"startDate":{"type":"STRING","description":"The start date of the chart. If unspecified use the previous start date which will be provided in the prompt."}},"required":["startDate"]}}
- {"name":"removeDataset","description":"Use this tool to remove a dataset from the chart.","parameters":{"type":"OBJECT","description":"The parameters for the removeDataset function.","properties":{"metric":{"type":"STRING","description":"The metric to be removed. If unspecified use the previous metric which will be provided in the prompt."}},"required":["metric"]}}
- {"name":"clearChart","description":"Use this tool to clear the chart.","parameters":{"type":"OBJECT","description":"The parameters for the clearChart function.","properties":{}}}
"
"Metrics" is a map of metric names to metric ids. The format is '<metric name>': '<metric id>'.
Metrics: "
'Accelerations': accelerationCount,
'Decelerations': decelerationCount,
'Total Distance': totalDistance,
'High Speed Distance': highSpeedDistance,
'Total Sprints': totalSprints,
'Max Heart Rate': maxHeartRate,
'Acceleration Distance': accelerationDistance,
'Deceleration Distance': decelerationDistance,
'Total Jumps': totalJumps,
'Muscle Soreness': muscleSoreness,
'Fatigue': fatigue,
'Stress': stress,
'Sleep Duration': sleepDuration,
'Deep Sleep Duration': deepSleepDuration,
'REM Sleep Duration': remSleepDuration,
'Light Sleep Duration': lightSleepDuration,
'Awake Duration': awakeDuration,
'Minutes Played': minutesPlayed,
'Goals': goals,
'Assists': assists,
'Shots': shots,
'Anti Mullerian Hormone': AMH2,
'Arachidic (20:0)': ARA20,
'Behenic (22:0)': BEHE,
'Total Carotene': CAROTENE,
'cis-Monounsaturated Fatty Acids': CISMONO,
'Cortisol': CORT,
'17 Hydroxyprogesterone': CP17HYD,
'a-Linolenic (ALA) 18:3 n3': CP183AL,
'Linoleic (LA) 18:2n6': CP186LI,
'Ferritin (HS)': CP1FERR,
'Omega-3 Index (HS)': CP1OMEG,
'Testosterone (HS)': CP1TEST,
'Vitamin A': CP1VITA,
'Vitamin E ': CP1VITAM,
'Eicosapentaenoic (EPA) 20:5 n3': CP203EPT,
'Arachidonic (AA) 20:4n6': CP206ARA,
'Docosahexaen (DHA) 22:6 n3': CP223DHC,
'25-Hydroxy Vitamin D3': CP25HDV3,
'25-Hydroxy Vitamin D': CP25HVD,
'25-Hydroxy Vitamin D2': CP25HVD2,
'Homocysteine v2': CP2HOMO,
'Vitamin B2 - SpectraCell': CP2VITA,
'Vitamin B3 - SpectraCell': CP3VITA,
'Vitamin B12 - SpectraCell': CP4VITA,
'Vitamin B6 - SpectraCell': CP6VITA,
'AA:EPA Ratio': CPAAEPA,
'Active B12': CPACTIV,
'Active Vitamin B12 (HS)': CPACTIVE,
'Albumin': CPALBUM,
'Albumin (SD1)': CPALBUMI,
'Alpha-Carotene': CPALCAR,
'Albumin/Globulin ratio': CPALGLR,
'Alkaline Phosphatase': CPALPHO,
'Alanine Transaminase': CPALTRA,
'Amylase': CPAMYLA,
'Anti Inflammatory Index': CPANTII,
'Anti-mullerian hormone': CPANTIM,
'Apollpoprotein B, P': CPAPOLL,
'Activated Partial Thromboplastin Time Ratio': CPAPTTR,
'Arginine': CPARGIN,
'Arginine (HS)': CPARGINI,
'Aspartate Aminotransferase': CPASAM,
'Asparagine - SpectraCell': CPASPAR,
'Asparagine (HS)': CPASPARA,
'AST (SD1)': CPASTSD,
'Atypical Lymphocyte': CPATLYMP,
'Alpha-Tocopherol (Vit. E)': CPATVTE,
'Avg Mins/Game': CPAVGMI,
'Active B12 (TH)': CPB12,
'Basophil, Absolute': CPBASOA,
'Basophil, %': CPBASOP,
'Basophil, % only': CPBASOPH,
'Bioavailable Testosterone': CPBATT,
'B Cells, Absolute': CPBCEAB,
'B-Cell %': CPBCELL,
'B-Cell Absolute Count US': CPBCELLA,
'Coenzyme Q10 - SpectraCell': CPBCOEN,
'Bermuda grass': CPBERGR,
'Beta Globulin': CPBETAG,
'Blood Glucose - Random': CPBGLR,
'Serum Bicarbonate': CPBICAR,
'Bilirubin, Indirect': CPBILIRU,
'Biotin - SpectraCell': CPBIOTI,
'Birch pollen': CPBIRPOL,
'Blood Glucose - Fasting': CPBLGF,
'Diastolic Blood Pressure': CPBPDIA,
'Systolic Blood Pressure': CPBPSYS,
'Branched-chain amino acids BCAA (HS)': CPBRANC,
'Beta-Carotene': CPBTCAR,
'Blood Urea Nitrogen/Creatinine Ratio': CPBUNCRR,
'BUN/Urea': CPBUREA,
'Calcium, ionized': CPCALCI,
'Calcium - SpectraCell': CPCALCIU,
'Calcium': CPCALCM,
'Calcium Osmolality': CPCALCO,
'Carnitine - SpectraCell': CPCARNI,
'Carnitine (HS)': CPCARNIT,
'Cat dander': CPCATDAN,
'Calcium for PTH, Intact': CPCCALC,
'Adjusted Calcium': CPCCALCM,
'Copper - SpectraCell': CPCCOPP,
'Choline - SpectraCell': CPCHOLI,
'Chloride': CPCHOLR,
'Chromium': CPCHROM,
'Citrulline (HS)': CPCITRU,
'CK (SD1)': CPCKSD1,
'Copper': CPCOPP,
'Copper (serum) (HS)': CPCOPPE,
'Cortisol (HS)': CPCORTI,
'Creatinine': CPCREAT,
'Creatine Kinase': CPCREATK,
'C Reactive Protein ': CPCRREPR,
'DHEA-Sulfate Serum': CPDHEAS,
'Direct Bilirubin': CPDIRBIL,
'Vitamin D3 - SpectraCell': CPDVITA,
'Vitamin E (Alpha Tocopherol)': CPEALPH,
'Vitamin E (Gamma Tocopherol)': CPEGAMM,
'Erythropoietin (EPO)': CPERYTH,
'Oestradiol': CPESTR,
'Vitamin C': CPEVITA,
'Folic Acid Red Blood Cell': CPFACRBC,
'Ferritin': CPFERRI,
'Fibrinogen': CPFIBRI,
'Follicle-Stimulating Hormone': CPFLSH,
'Folic Acid': CPFOLAC,
'Free Triiodothyronine': CPFRTRII,
'Free Testosterone': CPFRTTTE,
'Fructose Sensitivity': CPFRUCT,
'FSH': CPFSH,
'Free Thyroxine': CPFTHYR,
'Gamma Globulin': CPGAMMAG,
'EGFR Non-African American': CPGFRNAA,
'GRA': CPGGGGG,
'Gamma-Glutamyl Transpeptidase': CPGGLTRA,
'Glomerular Filtration Rate': CPGLFR,
'EGFR African American': CPGLFRAA,
'Globulins': CPGLOBU,
'Glutathione, Total': CPGLU,
'Glucose': CPGLUC,
'Glutathione - Red Cell': CPGLURC,
'Glutamine': CPGLUTA,
'Glutamic acid (HS)': CPHGLUT,
'Histidine (HS)': CPHISTI,
'Homocysteine': CPHOMOL,
'hs-CRP': CPHSCRP,
'HS-Omega 3': CPHSOM3,
'Omega-6 Fatty Acids': CPHSOME,
'Insulin': CPINSUL,
'Iodine': CPIODIN,
'Inorganic Phosphorus': CPIPHOSP,
'Iron (TH)': CPIRONT,
'Iron Saturation': CPIRSAT,
'Isoleucine': CPISOLE,
'Glutamine (HS)': CPKGLUT,
'Leucine': CPLEUCI,
'Magnesium': CPMAGM,
'Magnesium (erythrocytes) (HS)': CPMAGNE,
'Magnesium, RBC (BRF)': CPMAGNES,
'Manganese': CPMANGA,
'Manganese - SpectraCell': CPMANGAN,
'Mean Corpuscular Haemoglobin': CPMCH,
'Mercury': CPMERCU,
'Methionine': CPMETHI,
'Magnesium': CPMMAGN,
'Neutrophil': CPNEUTR,
'Non High-Density Lipoprotein Cholesterol': CPNHDLCH,
'NKCA %': CPNKCA,
'NKCA Per Cell': CPNKCAP,
'NK Cell %': CPNKCEL,
'NK Cell Absolute': CPNKCELL,
'Natural Killer Cells, Absolute': CPNKNCA,
'Nucleated RBC': CPNRBC,
'Zinc': CPNZINC,
'Oleic Acid': CPOLEICA,
'Omega 6:3': CPOMEGA,
'Omega6 : Omega3 ratio': CPOMG6_3,
'Prolactin': CPPROLN,
'Prostate Specific Antigen': CPPROSAG,
'Vitamin E': CPQQVIT,
'Vitamin B1': CPQVITA,
'Red Blood Cell Count': CPRBCC,
'Vitamin B5': CPRFVIT,
'Season Games Played': CPSEASO,
'SEGS': CPSEGS,
'Selenium': CPSELEN,
'Selenium (erythrocytes) (HS)': CPSELENI,
'Serine': CPSERI,
'Serine - SpectraCell': CPSERIN,
'Ferritin (TH)': CPSFERR,
'Serum Folate': CPSFOLA,
'Serum Glucose - Fasting': CPSGLF,
'Serum Glucose - Random': CPSGLR,
'Sex Hormone Binding Globulin': CPSHBG,
'Serum Inorganic Phosphate': CPSINP,
'Serum Iron': CPSIRON,
'Serum Lutein': CPSLUTE,
'Sodium': CPSODI,
'Vitamin B1': CPSPECT,
'Selenium': CPSSELE,
'Serum Testosterone': CPSTTTE,
'Vitamin A': CPSVITA,
'Vitamin D': CPSVITAM,
'T3 Uptake': CPT3UPT,
'Taurine': CPTAUR,
'Taurine (HS)': CPTAURI,
'T-Cells %': CPTCELL,
'T-Cell Absolute': CPTCELLA,
'T:C Ratio': CPTCRAT,
'T:C Ratio (HS)': CPTCRATI,
'Total Daily Cortisol': CPTDCORT,
'Testosterone Total Female': CPTESTO,
'Free Testosterone Female': CPTESTOS,
'Threonine': CPTHREO,
'Total Cholesterol': CPTOTCHO,
'Total Testosterone': CPTOTTTE,
'Transferrin': CPTRANSF,
'Triglycerides': CPTRYG,
'Tryptophan (HS)': CPTRYPT,
'Thyroid Stimulating Hormone': CPTSHBR,
'Testosterone': CPTTTE,
'Total Vitamin D': CPTVD,
'Typical Cycle Length': CPTYPIC,
'Tryptophan': CPTYPTO,
'Tyrosine (HS)': CPTYROS,
'Vitamin A': CPVITAM,
'Vitamin B1': CPVITB1,
'Vitamin B12': CPVITB12,
'Vitamin B2': CPVITB2,
'Vitamin B6': CPVITB6,
'Vitamin C': CPVITC,
'VLDL Cholesterol Cal': CPVLDLC,
'Vitamin E': CPVVITA,
'White Blood Cell Count': CPWBCC,
'Vitamin D 1,25': CPWVITA,
'Zinc': CPZINC,
'LDL/HDL Ratio': LDL - HDL,
'Omega 3 Fatty Acids': O3FA,
'Omega 3 Index': O3INDEX
"
Rules to follow: "
- Use the "draw" tool to add or update a dataset in the chart.
- Use the "updateDateRange" tool to update the date range of the chart.
- Use the "removeDataset" tool to remove a dataset from the chart.
- Use the "clearChart" tool to clear the chart.
- If the user wants you to perform multiple actions, you should return multiple function calls. For example, if the user says "Draw A and B", you should return two draw function calls.
- You can use the same tool multiple times in a response.
- The only values supported for "chartType" are "line" and "bar".
- If the user provides a metric name, you should return the metric id. If the user does not provide a metric name, you should return the previous metric id.
- Assume metric names if they are misspelled or shortened. For example, "fourth data" should be assumed to be "fort data".
- Never return a metric id that is not in "Metrics" and return the metric id exactly as it is formatted in "Metrics".
- If the user says something like "for the last 2 months", you must calculate the start date based on today's date.
- Only use the "removeDataset", "removeHighlight" or "clearChart" tools if the user has explicitly asked for it.
- "removeDataset" will remove a particular dataset. 'clearChart' will completely clear the chart. If the user mentions an athlete or metric, you should assume they want to remove a dataset and not clear the chart.
"
"""


class EndpointHandler:
    """Inference Endpoints handler that turns chart requests into function calls.

    The model and tokenizer are loaded once, in ``__init__``. Each call to the
    handler does the following:

    1. Wraps the user's message in an Alpaca-style prompt that describes the
       chart tools and metric ids.
    2. Generates a completion.
    3. Returns the completion parsed as JSON.
    """

    MODEL_NAME = "ceiteach/chart-no-pretrain-llama31-unsloth"
    MAX_SEQ_LENGTH = 4096
    # Prompt context used when a request does not override it through
    # ``data["parameters"]``. These defaults reproduce the originally
    # hard-coded prompt exactly.
    DEFAULT_DATE = "2024-08-18"
    DEFAULT_CHART_TYPE = "line"
    DEFAULT_METRIC = "CPFORD"
    DEFAULT_MAX_NEW_TOKENS = 128
    FAILURE_RESPONSE = {"response": "Unable to generate response"}

    def __init__(self, path="ceiteach/chart-no-pretrain-llama31-unsloth"):
        """Load the 4-bit model and tokenizer and put the model in inference mode.

        Args:
            path: Accepted for the Inference Endpoints handler interface.
                NOTE(review): not used. As in the original handler, the model
                is always loaded from ``MODEL_NAME``. Confirm whether it should
                be loaded from ``path`` (the endpoint's local repository copy)
                instead.
        """
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.MODEL_NAME,
            max_seq_length=self.MAX_SEQ_LENGTH,
            dtype=None,  # None: let Unsloth auto-detect the dtype
            load_in_4bit=True,
        )
        # Switching to inference mode is a one-time setup step. It used to be
        # repeated on every request.
        FastLanguageModel.for_inference(model)
        self.model = model
        self.tokenizer = tokenizer

    @staticmethod
    def _build_instruction(date, previous_chart_type, previous_metric):
        """Return the system instruction for the given chart context.

        With the class defaults, the result is identical to the instruction
        string the handler originally hard-coded.
        """
        return (
            "\nYou are responding to an athlete who wants to view their data in a chart.\n"
            f"Today's date is {date} (YYYY-MM-DD).\n"
            + _INSTRUCTION_TOOLS_METRICS_AND_RULES
            + "If a new chart type or metric is not provided, use the previous values"
            " as we can assume the user wants to continue with the same options.\n"
            f"- The previous chart type is {previous_chart_type}.\n"
            f"- The previous metric is {previous_metric}.\n"
        )

    @staticmethod
    def _build_prompt(instruction, user_input):
        """Wrap the instruction and user input in the Alpaca prompt template."""
        return (
            "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n"
            "### Instruction:\n"
            f"{instruction}\n"
            "### Input:\n"
            f"{user_input}\n"
            "### Response:\n"
        )

    def __call__(self, data):
        """Generate chart function calls for one request.

        Args:
            data: Request payload dict. ``data["inputs"]`` is the user's
                message; if that key is missing, the whole payload is used as
                the input, as before. The payload is not modified. The
                optional ``data["parameters"]`` dict may override ``date``,
                ``previous_chart_type``, ``previous_metric`` and
                ``max_new_tokens``.

        Returns:
            ``{"response": <parsed JSON>}`` if the model output is valid JSON,
            otherwise ``{"response": "Unable to generate response"}``.
        """
        user_input = data.get("inputs", data)
        params = data.get("parameters") or {}

        instruction = self._build_instruction(
            params.get("date", self.DEFAULT_DATE),
            params.get("previous_chart_type", self.DEFAULT_CHART_TYPE),
            params.get("previous_metric", self.DEFAULT_METRIC),
        )
        prompt = self._build_prompt(instruction, user_input)

        # Send the inputs to wherever the model lives instead of hard-coding "cuda".
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            # NOTE(review): temperature only has an effect when sampling is
            # enabled (e.g. by the model's generation config). Kept as before.
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=params.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS),
                temperature=0.01,
            )

        # generate() returns prompt + continuation. Decode only the new tokens
        # rather than splitting the decoded text on the "### Response:" marker.
        prompt_length = inputs["input_ids"].shape[-1]
        response_content = self.tokenizer.decode(
            output_ids[0][prompt_length:], skip_special_tokens=True
        ).strip()

        try:
            function_calls = json.loads(response_content)
        except json.JSONDecodeError:
            logger.warning(
                "Failed to parse function calls from response: %r", response_content
            )
            return dict(self.FAILURE_RESPONSE)
        return {"response": function_calls}