ceiteach commited on
Commit
fa1167f
1 Parent(s): 481fcd8

Created custom Inference Handler

Browse files
Files changed (2) hide show
  1. handler.py +329 -0
  2. requirements.txt +8 -0
handler.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unsloth import FastLanguageModel
2
+ import torch
3
+ import json
4
+
5
+ class EndpointHandler():
6
+ def __init__(self, path="ceiteach/chart-no-pretrain-llama31-unsloth"):
7
+ max_seq_length = 4096
8
+ dtype = None
9
+ load_in_4bit = True
10
+ model_name = "ceiteach/chart-no-pretrain-llama31-unsloth"
11
+ model, tokenizer = FastLanguageModel.from_pretrained(
12
+ model_name = model_name,
13
+ max_seq_length = max_seq_length,
14
+ dtype = dtype,
15
+ load_in_4bit = load_in_4bit,
16
+ )
17
+ self.model = model
18
+ self.tokenizer = tokenizer
19
+
20
+ def __call__(self, data):
21
+ model = self.model
22
+ tokenizer = self.tokenizer
23
+ FastLanguageModel.for_inference(model)
24
+
25
+ date = "2024-08-18"
26
+ previous_metric = "CPFORD"
27
+ previous_chart_type = "line"
28
+ instruction = """
29
+ You are responding to an athlete who wants to view their data in a chart.
30
+ Today's date is 2024-08-18 (YYYY-MM-DD).
31
+
32
+ You have access to the following functions. Your response must use at least one of these functions.
33
+
34
+ Functions: "
35
+ - {"name":"draw","description":"Use this tool to add or update a dataset in the chart.","parameters":{"type":"OBJECT","description":"The parameters for the draw function.","properties":{"chartType":{"type":"STRING","description":"The type of chart to draw. It can be either 'line' or 'bar'. If unspecified use the previous chart type which will be provided in the prompt."},"metric":{"type":"STRING","description":"The metric to be charted. If unspecified use the previous metric which will be provided in the prompt."}},"required":["chartType","metric"]}}
36
+ - {"name":"updateDateRange","description":"Use this tool to update the date range of the chart.","parameters":{"type":"OBJECT","description":"The parameters for the updateDateRange function.","properties":{"startDate":{"type":"STRING","description":"The start date of the chart. If unspecified use the previous start date which will be provided in the prompt."}},"required":["startDate"]}}
37
+ - {"name":"removeDataset","description":"Use this tool to remove a dataset from the chart.","parameters":{"type":"OBJECT","description":"The parameters for the removeDataset function.","properties":{"metric":{"type":"STRING","description":"The metric to be removed. If unspecified use the previous metric which will be provided in the prompt."}},"required":["metric"]}}
38
+ - {"name":"clearChart","description":"Use this tool to clear the chart.","parameters":{"type":"OBJECT","description":"The parameters for the clearChart function.","properties":{}}}
39
+ "
40
+
41
+ "Metrics" is a map of metric names to metric ids. The format is '<metric name>': '<metric id>'.
42
+ Metrics: "
43
+ 'Accelerations': accelerationCount,
44
+ 'Decelerations': decelerationCount,
45
+ 'Total Distance': totalDistance,
46
+ 'High Speed Distance': highSpeedDistance,
47
+ 'Total Sprints': totalSprints,
48
+ 'Max Heart Rate': maxHeartRate,
49
+ 'Acceleration Distance': accelerationDistance,
50
+ 'Deceleration Distance': decelerationDistance,
51
+ 'Total Jumps': totalJumps,
52
+ 'Muscle Soreness': muscleSoreness,
53
+ 'Fatigue': fatigue,
54
+ 'Stress': stress,
55
+ 'Sleep Duration': sleepDuration,
56
+ 'Deep Sleep Duration': deepSleepDuration,
57
+ 'REM Sleep Duration': remSleepDuration,
58
+ 'Light Sleep Duration': lightSleepDuration,
59
+ 'Awake Duration': awakeDuration,
60
+ 'Minutes Played': minutesPlayed,
61
+ 'Goals': goals,
62
+ 'Assists': assists,
63
+ 'Shots': shots,
64
+ 'Anti Mullerian Hormone': AMH2,
65
+ 'Arachidic (20:0)': ARA20,
66
+ 'Behenic (22:0)': BEHE,
67
+ 'Total Carotene': CAROTENE,
68
+ 'cis-Monounsaturated Fatty Acids': CISMONO,
69
+ 'Cortisol': CORT,
70
+ '17 Hydroxyprogesterone': CP17HYD,
71
+ 'a-Linolenic (ALA) 18:3 n3': CP183AL,
72
+ 'Linoleic (LA) 18:2n6': CP186LI,
73
+ 'Ferritin (HS)': CP1FERR,
74
+ 'Omega-3 Index (HS)': CP1OMEG,
75
+ 'Testosterone (HS)': CP1TEST,
76
+ 'Vitamin A': CP1VITA,
77
+ 'Vitamin E ': CP1VITAM,
78
+ 'Eicosapentaenoic (EPA) 20:5 n3': CP203EPT,
79
+ 'Arachidonic (AA) 20:4n6': CP206ARA,
80
+ 'Docosahexaen (DHA) 22:6 n3': CP223DHC,
81
+ '25-Hydroxy Vitamin D3': CP25HDV3,
82
+ '25-Hydroxy Vitamin D': CP25HVD,
83
+ '25-Hydroxy Vitamin D2': CP25HVD2,
84
+ 'Homocysteine v2': CP2HOMO,
85
+ 'Vitamin B2 - SpectraCell': CP2VITA,
86
+ 'Vitamin B3 - SpectraCell': CP3VITA,
87
+ 'Vitamin B12 - SpectraCell': CP4VITA,
88
+ 'Vitamin B6 - SpectraCell': CP6VITA,
89
+ 'AA:EPA Ratio': CPAAEPA,
90
+ 'Active B12': CPACTIV,
91
+ 'Active Vitamin B12 (HS)': CPACTIVE,
92
+ 'Albumin': CPALBUM,
93
+ 'Albumin (SD1)': CPALBUMI,
94
+ 'Alpha-Carotene': CPALCAR,
95
+ 'Albumin/Globulin ratio': CPALGLR,
96
+ 'Alkaline Phosphatase': CPALPHO,
97
+ 'Alanine Transaminase': CPALTRA,
98
+ 'Amylase': CPAMYLA,
99
+ 'Anti Inflammatory Index': CPANTII,
100
+ 'Anti-mullerian hormone': CPANTIM,
101
+ 'Apollpoprotein B, P': CPAPOLL,
102
+ 'Activated Partial Thromboplastin Time Ratio': CPAPTTR,
103
+ 'Arginine': CPARGIN,
104
+ 'Arginine (HS)': CPARGINI,
105
+ 'Aspartate Aminotransferase': CPASAM,
106
+ 'Asparagine - SpectraCell': CPASPAR,
107
+ 'Asparagine (HS)': CPASPARA,
108
+ 'AST (SD1)': CPASTSD,
109
+ 'Atypical Lymphocyte': CPATLYMP,
110
+ 'Alpha-Tocopherol (Vit. E)': CPATVTE,
111
+ 'Avg Mins/Game': CPAVGMI,
112
+ 'Active B12 (TH)': CPB12,
113
+ 'Basophil, Absolute': CPBASOA,
114
+ 'Basophil, %': CPBASOP,
115
+ 'Basophil, % only': CPBASOPH,
116
+ 'Bioavailable Testosterone': CPBATT,
117
+ 'B Cells, Absolute': CPBCEAB,
118
+ 'B-Cell %': CPBCELL,
119
+ 'B-Cell Absolute Count US': CPBCELLA,
120
+ 'Coenzyme Q10 - SpectraCell': CPBCOEN,
121
+ 'Bermuda grass': CPBERGR,
122
+ 'Beta Globulin': CPBETAG,
123
+ 'Blood Glucose - Random': CPBGLR,
124
+ 'Serum Bicarbonate': CPBICAR,
125
+ 'Bilirubin, Indirect': CPBILIRU,
126
+ 'Biotin - SpectraCell': CPBIOTI,
127
+ 'Birch pollen': CPBIRPOL,
128
+ 'Blood Glucose - Fasting': CPBLGF,
129
+ 'Diastolic Blood Pressure': CPBPDIA,
130
+ 'Systolic Blood Pressure': CPBPSYS,
131
+ 'Branched-chain amino acids BCAA (HS)': CPBRANC,
132
+ 'Beta-Carotene': CPBTCAR,
133
+ 'Blood Urea Nitrogen/Creatinine Ratio': CPBUNCRR,
134
+ 'BUN/Urea': CPBUREA,
135
+ 'Calcium, ionized': CPCALCI,
136
+ 'Calcium - SpectraCell': CPCALCIU,
137
+ 'Calcium': CPCALCM,
138
+ 'Calcium Osmolality': CPCALCO,
139
+ 'Carnitine - SpectraCell': CPCARNI,
140
+ 'Carnitine (HS)': CPCARNIT,
141
+ 'Cat dander': CPCATDAN,
142
+ 'Calcium for PTH, Intact': CPCCALC,
143
+ 'Adjusted Calcium': CPCCALCM,
144
+ 'Copper - SpectraCell': CPCCOPP,
145
+ 'Choline - SpectraCell': CPCHOLI,
146
+ 'Chloride': CPCHOLR,
147
+ 'Chromium': CPCHROM,
148
+ 'Citrulline (HS)': CPCITRU,
149
+ 'CK (SD1)': CPCKSD1,
150
+ 'Copper': CPCOPP,
151
+ 'Copper (serum) (HS)': CPCOPPE,
152
+ 'Cortisol (HS)': CPCORTI,
153
+ 'Creatinine': CPCREAT,
154
+ 'Creatine Kinase': CPCREATK,
155
+ 'C Reactive Protein ': CPCRREPR,
156
+ 'DHEA-Sulfate Serum': CPDHEAS,
157
+ 'Direct Bilirubin': CPDIRBIL,
158
+ 'Vitamin D3 - SpectraCell': CPDVITA,
159
+ 'Vitamin E (Alpha Tocopherol)': CPEALPH,
160
+ 'Vitamin E (Gamma Tocopherol)': CPEGAMM,
161
+ 'Erythropoietin (EPO)': CPERYTH,
162
+ 'Oestradiol': CPESTR,
163
+ 'Vitamin C': CPEVITA,
164
+ 'Folic Acid Red Blood Cell': CPFACRBC,
165
+ 'Ferritin': CPFERRI,
166
+ 'Fibrinogen': CPFIBRI,
167
+ 'Follicle-Stimulating Hormone': CPFLSH,
168
+ 'Folic Acid': CPFOLAC,
169
+ 'Free Triiodothyronine': CPFRTRII,
170
+ 'Free Testosterone': CPFRTTTE,
171
+ 'Fructose Sensitivity': CPFRUCT,
172
+ 'FSH': CPFSH,
173
+ 'Free Thyroxine': CPFTHYR,
174
+ 'Gamma Globulin': CPGAMMAG,
175
+ 'EGFR Non-African American': CPGFRNAA,
176
+ 'GRA': CPGGGGG,
177
+ 'Gamma-Glutamyl Transpeptidase': CPGGLTRA,
178
+ 'Glomerular Filtration Rate': CPGLFR,
179
+ 'EGFR African American': CPGLFRAA,
180
+ 'Globulins': CPGLOBU,
181
+ 'Glutathione, Total': CPGLU,
182
+ 'Glucose': CPGLUC,
183
+ 'Glutathione - Red Cell': CPGLURC,
184
+ 'Glutamine': CPGLUTA,
185
+ 'Glutamic acid (HS)': CPHGLUT,
186
+ 'Histidine (HS)': CPHISTI,
187
+ 'Homocysteine': CPHOMOL,
188
+ 'hs-CRP': CPHSCRP,
189
+ 'HS-Omega 3': CPHSOM3,
190
+ 'Omega-6 Fatty Acids': CPHSOME,
191
+ 'Insulin': CPINSUL,
192
+ 'Iodine': CPIODIN,
193
+ 'Inorganic Phosphorus': CPIPHOSP,
194
+ 'Iron (TH)': CPIRONT,
195
+ 'Iron Saturation': CPIRSAT,
196
+ 'Isoleucine': CPISOLE,
197
+ 'Glutamine (HS)': CPKGLUT,
198
+ 'Leucine': CPLEUCI,
199
+ 'Magnesium': CPMAGM,
200
+ 'Magnesium (erythrocytes) (HS)': CPMAGNE,
201
+ 'Magnesium, RBC (BRF)': CPMAGNES,
202
+ 'Manganese': CPMANGA,
203
+ 'Manganese - SpectraCell': CPMANGAN,
204
+ 'Mean Corpuscular Haemoglobin': CPMCH,
205
+ 'Mercury': CPMERCU,
206
+ 'Methionine': CPMETHI,
207
+ 'Magnesium': CPMMAGN,
208
+ 'Neutrophil': CPNEUTR,
209
+ 'Non High-Density Lipoprotein Cholesterol': CPNHDLCH,
210
+ 'NKCA %': CPNKCA,
211
+ 'NKCA Per Cell': CPNKCAP,
212
+ 'NK Cell %': CPNKCEL,
213
+ 'NK Cell Absolute': CPNKCELL,
214
+ 'Natural Killer Cells, Absolute': CPNKNCA,
215
+ 'Nucleated RBC': CPNRBC,
216
+ 'Zinc': CPNZINC,
217
+ 'Oleic Acid': CPOLEICA,
218
+ 'Omega 6:3': CPOMEGA,
219
+ 'Omega6 : Omega3 ratio': CPOMG6_3,
220
+ 'Prolactin': CPPROLN,
221
+ 'Prostate Specific Antigen': CPPROSAG,
222
+ 'Vitamin E': CPQQVIT,
223
+ 'Vitamin B1': CPQVITA,
224
+ 'Red Blood Cell Count': CPRBCC,
225
+ 'Vitamin B5': CPRFVIT,
226
+ 'Season Games Played': CPSEASO,
227
+ 'SEGS': CPSEGS,
228
+ 'Selenium': CPSELEN,
229
+ 'Selenium (erythrocytes) (HS)': CPSELENI,
230
+ 'Serine': CPSERI,
231
+ 'Serine - SpectraCell': CPSERIN,
232
+ 'Ferritin (TH)': CPSFERR,
233
+ 'Serum Folate': CPSFOLA,
234
+ 'Serum Glucose - Fasting': CPSGLF,
235
+ 'Serum Glucose - Random': CPSGLR,
236
+ 'Sex Hormone Binding Globulin': CPSHBG,
237
+ 'Serum Inorganic Phosphate': CPSINP,
238
+ 'Serum Iron': CPSIRON,
239
+ 'Serum Lutein': CPSLUTE,
240
+ 'Sodium': CPSODI,
241
+ 'Vitamin B1': CPSPECT,
242
+ 'Selenium': CPSSELE,
243
+ 'Serum Testosterone': CPSTTTE,
244
+ 'Vitamin A': CPSVITA,
245
+ 'Vitamin D': CPSVITAM,
246
+ 'T3 Uptake': CPT3UPT,
247
+ 'Taurine': CPTAUR,
248
+ 'Taurine (HS)': CPTAURI,
249
+ 'T-Cells %': CPTCELL,
250
+ 'T-Cell Absolute': CPTCELLA,
251
+ 'T:C Ratio': CPTCRAT,
252
+ 'T:C Ratio (HS)': CPTCRATI,
253
+ 'Total Daily Cortisol': CPTDCORT,
254
+ 'Testosterone Total Female': CPTESTO,
255
+ 'Free Testosterone Female': CPTESTOS,
256
+ 'Threonine': CPTHREO,
257
+ 'Total Cholesterol': CPTOTCHO,
258
+ 'Total Testosterone': CPTOTTTE,
259
+ 'Transferrin': CPTRANSF,
260
+ 'Triglycerides': CPTRYG,
261
+ 'Tryptophan (HS)': CPTRYPT,
262
+ 'Thyroid Stimulating Hormone': CPTSHBR,
263
+ 'Testosterone': CPTTTE,
264
+ 'Total Vitamin D': CPTVD,
265
+ 'Typical Cycle Length': CPTYPIC,
266
+ 'Tryptophan': CPTYPTO,
267
+ 'Tyrosine (HS)': CPTYROS,
268
+ 'Vitamin A': CPVITAM,
269
+ 'Vitamin B1': CPVITB1,
270
+ 'Vitamin B12': CPVITB12,
271
+ 'Vitamin B2': CPVITB2,
272
+ 'Vitamin B6': CPVITB6,
273
+ 'Vitamin C': CPVITC,
274
+ 'VLDL Cholesterol Cal': CPVLDLC,
275
+ 'Vitamin E': CPVVITA,
276
+ 'White Blood Cell Count': CPWBCC,
277
+ 'Vitamin D 1,25': CPWVITA,
278
+ 'Zinc': CPZINC,
279
+ 'LDL/HDL Ratio': LDL - HDL,
280
+ 'Omega 3 Fatty Acids': O3FA,
281
+ 'Omega 3 Index': O3INDEX
282
+ "
283
+
284
+ Rules to follow: "
285
+ - Use the "draw" tool to add or update a dataset in the chart.
286
+ - Use the "updateDateRange" tool to update the date range of the chart.
287
+ - Use the "removeDataset" tool to remove a dataset from the chart.
288
+ - Use the "clearChart" tool to clear the chart.
289
+ - If the user wants you to perform multiple actions, you should return multiple function calls. For example, if the user says "Draw A and B", you should return two draw function calls.
290
+ - You can use the same tool multiple times in a response.
291
+ - The only values supported for "chartType" are "line" and "bar".
292
+ - If the user provides a metric name, you should return the metric id. If the user does not provide a metric name, you should return the previous metric id.
293
+ - Assume metric names if they are misspelled or shortened. For example, "fourth data" should be assumed to be "fort data".
294
+ - Never return a metric id that is not in "Metrics" and return the metric id exactly as it is formatted in "Metrics".
295
+ - If the user says something like "for the last 2 months", you must calculate the start date based on today's date.
296
+ - Only use the "removeDataset", "removeHighlight" or "clearChart" tools if the user has explicitly asked for it.
297
+ - "removeDataset" will remove a particular dataset. 'clearChart' will completely clear the chart. If the user mentions an athlete or metric, you should assume they want to remove a dataset and not clear the chart.
298
+ "
299
+
300
+ If a new chart type or metric is not provided, use the previous values as we can assume the user wants to continue with the same options.
301
+ - The previous chart type is line.
302
+ - The previous metric is CPFORD.
303
+ """
304
+
305
+ user_input = data.pop("inputs", data)
306
+ alpaca_prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
307
+
308
+ ### Instruction:
309
+ {instruction}
310
+
311
+ ### Input:
312
+ {user_input}
313
+
314
+ ### Response:
315
+ """
316
+ prompt_tokenized = tokenizer(alpaca_prompt, return_tensors="pt").to("cuda")
317
+ with torch.no_grad():
318
+ response = tokenizer.decode(model.generate(**prompt_tokenized, max_new_tokens=128, temperature=0.01)[0], skip_special_tokens=True)
319
+ response_content = response.split("### Response:")[1].strip()
320
+ try:
321
+ function_calls = json.loads(response_content);
322
+ return { "response": function_calls }
323
+ except json.JSONDecodeError as e:
324
+ print("Failed to parse function calls from response")
325
+ return { "response": "Unable to generate response"}
326
+
327
+
328
+
329
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git
2
+ xformers<0.0.27
3
+ trl<0.9.0
4
+ peft
5
+ accelerate
6
+ bitsandbytes
7
+ torch==2.3.0
8
+ json