nevi1's picture
Upload 244 files
73f4c20
raw
history blame
2.82 kB
# stuff specifically for the sklearn logic
from typing import Mapping
from functools import partial, reduce
import operator
from itertools import product
import argparse
###############################################################################
# A grid search convenience class
###############################################################################
class ParameterGrid:
    """Cartesian-product grid over one or more parameter dictionaries.

    Logic adapted from scikit-learn's ``ParameterGrid`` with all input
    validation removed, so callers must pass well-formed grids: each dict
    maps a parameter name to a sized collection of candidate values.  Using
    sklearn itself (or a smarter sampler) may be worthwhile in future.

    ``params`` may be a single mapping or an iterable of mappings; each
    mapping is expanded independently and the resulting points are
    concatenated in order.
    """

    def __init__(self, params):
        # A lone mapping is wrapped so "one grid" and "several independent
        # grids" are handled uniformly.  Any other iterable is materialised
        # into a list so the grid can be iterated, and measured with len(),
        # more than once -- even if a generator was passed in.
        if isinstance(params, Mapping):
            self.params = [params]
        else:
            self.params = list(params)
        # NOTE: no validation is performed -- make sure each dict maps names
        # to sized collections of values (len() is taken in __len__).

    def __iter__(self):
        """Iterate over the points in the grid.

        Yields
        ------
        dict
            One dictionary per grid point, mapping each parameter name to one
            of its candidate values.  A grid given as an empty dict yields a
            single empty dict.
        """
        for p in self.params:
            # Sort by key so the iteration order is reproducible.
            items = sorted(p.items())
            if not items:
                yield {}
                continue
            keys, values = zip(*items)
            for combo in product(*values):
                yield dict(zip(keys, combo))

    def __len__(self):
        """Return the number of points on the grid."""
        # Each dict contributes the product of its option counts.  The
        # initializer 1 makes an empty dict count as one (empty) point,
        # matching what __iter__ yields for it.
        return sum(
            reduce(operator.mul, (len(v) for v in p.values()), 1)
            for p in self.params
        )
###############################################################################
def flatten_dict(dict, to_string=False, sep=" "):
    """Flatten a shallow dict into alternating keys and values.

    ``{k1: v1, k2: v2}`` becomes ``[k1, v1, k2, v2]`` (in the dict's
    iteration order).  With ``to_string=True`` every element is passed
    through ``str`` and the results are joined with ``sep``, giving e.g.
    ``"k1 v1 k2 v2"`` for the default separator.

    The parameter is named ``dict`` (shadowing the builtin); the name is
    kept so keyword callers continue to work.

    Raises:
        ValueError: if ``to_string`` is true and an element cannot be
            converted or joined; the underlying exception is chained as
            ``__cause__``.
    """
    # operator.iconcat on a list performs an in-place extend, so each
    # (key, value) tuple is appended onto the fresh accumulator list.
    flat_dict = reduce(operator.iconcat, dict.items(), [])
    if not to_string:
        return flat_dict
    try:
        return sep.join([str(elm) for elm in flat_dict])
    except Exception as err:
        # Narrower than a bare except (lets KeyboardInterrupt/SystemExit
        # through) and keeps the original error for debugging.
        raise ValueError(
            f'Error converting dict={flat_dict} to whitespace joined string'
        ) from err
def str2bool(v):
    """Parse a command-line style boolean value.

    Suitable as an ``argparse`` ``type=`` converter.  ``bool`` inputs are
    returned unchanged; strings are lower-cased and matched against the
    accepted spellings below.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is a string that matches none
            of the accepted spellings.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    truthy_spellings = ('yes', 'true', 't', 'y', '1')
    falsy_spellings = ('no', 'false', 'f', 'n', '0')
    if normalized in truthy_spellings:
        return True
    if normalized in falsy_spellings:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')