File size: 7,954 Bytes
d6674c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173

import json
from io import StringIO
from tqdm import tqdm
import csv
from sqlalchemy import create_engine
from yallmf.utils import run_with_timeout
import pandas as pd
import numpy as np
import os
import openai
 
# Authenticate the OpenAI client from the environment; None if the var is unset.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Input/output CSV paths. NOTE: these are relative paths, so expanduser() is a
# no-op here; it only matters for the commented-out '~/...' alternatives below.
INPUT_FILE = os.path.expanduser('data/dataclean_input.csv')
OUTPUT_FILE = os.path.expanduser('data/dataclean_output.csv')
# OUTPUT_FILE = os.path.expanduser('~/data/aiclean/output.csv')
# CONFIGFILE = os.path.expanduser('~/config/cookies-dataclean.json')
# NOTE(review): CONFIGFILE above is commented out, but get_db_engine() still
# references it — calling that helper as-is raises NameError. Confirm whether
# the helper is dead code or the constant should be restored.

def get_db_engine(config_file=None):
    """Create a SQLAlchemy engine from a JSON config file.

    Args:
        config_file: Path to a JSON file containing a ``DB_CONN_STR`` key.
            When omitted, falls back to the module-level ``CONFIGFILE``
            constant. NOTE(review): ``CONFIGFILE`` is currently commented
            out at module level, so calling this with no argument raises
            ``NameError`` — restore the constant or always pass a path.

    Returns:
        A SQLAlchemy Engine configured for batched executemany inserts.
    """
    if config_file is None:
        # Preserves the original behavior of reading the module constant.
        config_file = CONFIGFILE
    with open(config_file) as f:
        dbconnstr = json.load(f)['DB_CONN_STR']
    # Batched executemany with a 1000-row page size keeps bulk inserts fast
    # without building oversized statements.
    return create_engine(dbconnstr,
        executemany_mode='batch',
        executemany_batch_page_size=1000)


def clean_data(
    input_product_names: pd.Series, 
    input_brands: pd.Series,
    input_product_categories: pd.Series,
    category_taxonomy: dict):
    """Normalize one batch of product rows using GPT-4.

    Sends the inputs to the model as CSV and asks it to return, per row:
    brand, product_category, sub_product_category, strain_name. The model
    output is then post-validated:

      * response rows without exactly 4 CSV columns are dropped (along with
        the corresponding input row),
      * rows whose category/subcategory pair is not in ``category_taxonomy``
        are dropped.

    Args:
        input_product_names: Raw product names, one per row.
        input_brands: Raw brand strings, aligned with the names.
        input_product_categories: Raw category strings, aligned with the names.
        category_taxonomy: Mapping of valid product_category -> list of valid
            sub_product_category values.

    Returns:
        DataFrame containing the surviving ``input__*`` columns alongside the
        model's four output columns.

    Raises:
        AssertionError: If, after filtering malformed rows, the model output
            row count still does not match the input row count.
    """
    output_cols = ['brand', 'product_category', 'sub_product_category', 'strain_name']
    ncols = len(output_cols)
    
    p1 = f'''
    I am going to provide a data set of marijuana products and their metadata. Using the information I provide, I want you to provide me with the following information about the products.

    - Brand (brand)
    - product category (product_category)
    - sub product category (sub_product_category)
    - strain name (strain_name)

    The following JSON shows all the acceptable Product Categories and their Sub Product Categories. Strictly adhere to the below mapping for valid product_category to sub_product_category relationships:

    {json.dumps(category_taxonomy)}
    
    Additional requirements: 

    - The input data set in CSV format, with commas as field delimiter and newline as row delimiter.
    - Do not automatically assume that the information in the data set I provide is accurate.
    - Leave the 'sub_product_category' field blank unless there's a clear and direct match with one of the categories provided in the list.If there is no explicit information to confidently assign a sub_product_category, default to leaving it blank.
    - Strain names are only applicable for the following product categories: concentrate, preroll, vape, flower
    - Look for clues in the product name to determine what brand/ product category/ sub product category/ and strain name the product should fall under. For Vape products, consider the words before 'Cartridge' or 'Cart' in the product name as potential strain names.
    - Every row of the Output CSV must have EXACTLY {ncols} columns.
    - When a field is left empty (e.g., 'sub_product_category' or 'strain_name'), simply leave it empty without placing an additional comma. Each row in the output CSV should always have only three commas separating the four fields regardless of whether some fields are empty. For instance, if 'sub_product_category' and 'strain_name' are empty, a row would look like this: "brand,product_category,,"
    - DO NOT EXPLAIN YOURSELF, ONLY RETURN A CSV WITH THESE COLUMNS: {', '.join(output_cols)}

    Input data set in CSV format:

    '''
    df = pd.DataFrame({'input__product_name':input_product_names,
                       'input__brand':input_brands,
                       'input__product_category':input_product_categories}).reset_index(drop=True)
    # Strip commas from every input string so the model's CSV framing cannot
    # be broken by the data itself. Keep the un-stripped `df` for the output.
    df2 = df.copy()
    for col in df2.columns:
        df2[col] = df2[col].str.replace(',', '')
    
    # Send the prompt + CSV payload to the LLM.
    p2 = df2.to_csv(index=False, quoting=csv.QUOTE_ALL)
    messages = [{'role':'system','content':'You are a helpful assistant. Return a properly-formatted CSV with the correct number of columns.'},
                 {'role':'user', 'content':p1+p2+'\n\nOutput CSV with header row:\n\n'}
                 ]
    comp = run_with_timeout(openai.ChatCompletion.create,
        model='gpt-4',
        messages=messages,
        max_tokens=2000,
        timeout=300,
        temperature=0.2
        )
    res = comp['choices'][0]['message']['content']

    # Drop response rows that don't have exactly `ncols` columns, and drop the
    # matching input rows so the two frames stay aligned. Response row i maps
    # to input row i-1 (row 0 is the header).
    keeprows = []
    for i,s in enumerate(res.split('\n')):
        if i==0:
            keeprows.append(s)
            continue
        if not s.strip():
            # BUGFIX: a blank line (e.g. the model's trailing newline) used to
            # be counted as a 1-column row, and the attempt to drop the
            # nonexistent input row i-1 could raise KeyError. Skip it instead.
            continue
        # NOTE: naive comma counting miscounts quoted fields containing
        # commas; input commas are stripped above, but the model could still
        # emit them.
        _ncols = len(s.split(','))
        if _ncols!=ncols:
            print(f'Got {_ncols} columns, skipping row {i-1} ({s})')
            # BUGFIX: guard against the model emitting more data rows than
            # inputs, which made the unconditional drop raise KeyError.
            if i-1 in df.index:
                df = df.drop(i-1)
        else:
            keeprows.append(s)
    df = df.reset_index(drop=True)

    resdf = pd.read_csv(StringIO('\n'.join(keeprows)))

    assert len(df)==len(resdf), 'Result CSV did not match input CSV in length'
    df = pd.concat([df.reset_index(drop=True),resdf.reset_index(drop=True)],axis=1)
    # Validate the model's category/subcategory pairs against the taxonomy;
    # collect offending row indexes and drop them in one pass at the end.
    dropidxs=[]
    for idx, row in df.iterrows():
        drop = False
        if pd.isna(row['product_category']) and not pd.isna(row['sub_product_category']):
            drop=True
            print('product_category is null while sub_product_category is not null, dropping')
        if not pd.isna(row['product_category']):
            if row['product_category'] not in category_taxonomy.keys():
                print(f'category "{row["product_category"]}" not in taxonomy, dropping row')
                drop =True
            elif not pd.isna(row['sub_product_category']):
                if row['sub_product_category'] not in category_taxonomy[row['product_category']]:
                    print(f'subcategory "{row["sub_product_category"]}" not valid for category {row["product_category"]}, dropping row')
                    drop =True
        if drop:
            dropidxs.append(idx)
    df = df.drop(dropidxs)

    return df

def get_key(df):
    """Return a Series keying each row: the concatenation of the three
    input__* columns. Used to diff already-processed rows between runs."""
    names = df['input__product_name']
    brands = df['input__brand']
    categories = df['input__product_category']
    return names + brands + categories

def main(input_file=INPUT_FILE, output_file=OUTPUT_FILE, chunksize=30):
    """Incrementally clean the input CSV with the LLM.

    Reads the input file, skips rows whose key already appears in the output
    file, and processes the remainder in chunks of ``chunksize`` rows,
    rewriting the output file after every chunk so progress survives an
    interruption.
    """
    category_taxonomy = {
        "Wellness": ["Mushroom Caps", "CBD Tincture/Caps/etc", "Promo/ Sample", "Capsule", "Liquid Flower", ""],
        "Concentrate": ["Diamonds", "Shatter", "Sugar", "Promo/ Sample", "Badder", "Diamonds and Sauce", "Rosin", "Cookies Dough", "Flan", "Cookie Dough", ""],
        "Preroll": ["Cubano", "Joint", "Promo/ Sample", "Blunt", "Infused Joint", "Packwoods Blunt", "Infused Blunt", "Napalm", ""],
        "Vape": ["Terp Sauce", "Gpen 0.5", "Cured Resin", "Solventless Rosin", "510", "Dry Flower Series", "Natural Terp Series", "Promo/ Sample", "Dart Pod 0.5", "Raw Garden", "Live Flower Series", "Rosin", "Disposable", ""],
        "Edible": ["Cookies", "Gummies", "Mint", "Promo/ Sample", "Beverage", "Chocolate", ""],
        "Grow Products": ["Promo/ Sample", ""],
        "Flower": ["Promo/ Sample", "Bud", ""],
        "Accessory": ["Promo/ Sample", ""]
    }

    # Input must provide: input__product_name, input__brand, input__product_category.
    df_in = pd.read_csv(input_file)

    # Output (if present) has the same columns plus the model's four outputs.
    try:
        df_out = pd.read_csv(output_file)
    except FileNotFoundError:
        df_out = None

    # Key both frames on the concatenated inputs and diff to find unprocessed rows.
    df_in['key'] = get_key(df_in)
    df_in = df_in.set_index('key')
    if df_out is not None:
        df_out['key'] = get_key(df_out)
        df_out = df_out.set_index('key')
        pending = df_in.loc[~df_in.index.isin(df_out.index)]
        done = len(df_out)
    else:
        pending = df_in
        done = 0

    print(f'''Input size {len(df_in)}, Output size {done}, still to process {len(pending)}, chunksize {chunksize}. Processing...''')
    chunk_labels = np.arange(len(pending)) // chunksize
    for _, chunk in tqdm(pending.groupby(chunk_labels)):
        result = clean_data(chunk['input__product_name'],
                            chunk['input__brand'],
                            chunk['input__product_category'],
                            category_taxonomy)
        result['key'] = get_key(result)
        result = result.set_index('key')
        df_out = result if df_out is None else pd.concat([df_out, result])
        # Persist after every chunk; the key index is recomputable, so it is
        # intentionally not written out.
        df_out.to_csv(output_file, index=False)

# Script entry point: run the incremental cleaning pipeline with defaults.
if __name__=='__main__':
    main()