"""
Datasets group categories together. Method calls on datasets invoke the individual methods
on the individual categories. Cuts applied to datasets will act on each individual category.
"""
from __future__ import division
import pandas as pd
import numpy as np
from collections import OrderedDict
from copy import deepcopy as copy
from ..plotting import VariableDistributionPlot
from ..utils import isnotebook
from ..utils.logger import Logger
from dashi.tinytable import TinyTable
from builtins import object
from . import categories
[docs]def get_label(category):
"""
Get the label for labeling plots from a datasets plot_options dictionary.
Args:
category (HErmes.selection.categories.category): Query the category's plot_options dict, if not fall back to category.name
Returns:
string
"""
if category.plot_options:
if "label" in category.plot_options:
return category.plot_options["label"]
else:
return category.name
else:
return category.name
[docs]class Dataset(object):
"""
Holds different categories, relays calls to each
of them.
"""
def __init__(self, *args, **kwargs):
"""
Args:
*args: HErmes.selection.variables.categories.Category list
Keyword Args:
combined_categories:
"""
self.categories = []
self.combined_categories = []
# sort categories, do reweighted simulation last
# FIXME: if not, there will be problems
# FIXME: investigate!
reweighted_categories = []
self.default_plotstyles = {}
for cat in args:
self.__dict__[cat.name] = cat
if isinstance(cat,categories.ReweightedSimulation):
reweighted_categories.append(cat)
continue
self.categories.append(cat)
self.categories = self.categories + reweighted_categories
if 'combined_categories' in kwargs:
for name in list(kwargs['combined_categories'].keys()):
self.combined_categories.append(categories.CombinedCategory(name,kwargs['combined_categories'][name]))
[docs] def set_default_plotstyles(self, styledict):
"""
Define a standard for each category how
it should appear in plots
Args:
styledict (dict)
"""
self.default_plotstyles = styledict
for cat in self.categorynames:
self[cat].add_plotoptions(styledict[cat])
[docs] def add_variable(self, variable):
"""
Add a variable to this category
Args:
variable (HErmes.selection.variables.variables.Variable): A Variable instalce
"""
for cat in self.categories:
cat.add_variable(variable)
[docs] def delete_variable(self, varname):
"""
Delete a variable entirely from the dataset
Args:
varname (str): the name of the variable
Returns:
None
"""
for cat in self.categories:
cat.delete_variable(varname)
[docs] def load_vardefs(self, vardefs):
"""
Load the variable definitions from a module
Args:
vardefs (python module/dict): A module needs to contain variable definitions.
It can also be a dictionary of categoryname->module
"""
if isinstance(vardefs, dict):
for k in vardefs:
# FIXME: the way over self.__dict__ does not work
# maybe there is something more fishy...
if str(k) == "all":
for cat in self.categories:
cat.load_vardefs(vardefs[k])
for cat in self.categories:
if cat.name == k:
cat.load_vardefs(vardefs[k])
#self.__dict__[k].load_vardefs(vardefs)
else:
for cat in self.categories:
cat.load_vardefs(vardefs)
@property
def variablenames(self):
return {cat.name : cat.variablenames for cat in self.categories}
@property
def files(self):
return {cat.name : cat.files for cat in self.categories}
#@GetTiming
[docs] def read_variables(self, names=None, max_cpu_cores=categories.MAX_CORES):
"""
Read out the variable for all categories
Keyword Args:
names (str): Readout only these variables if given
max_cpu_cores (int): Maximum number of cpu cores which will be used
Returns:
None
"""
progbar = False
try:
import tqdm
n_it = len(self.categories)
loader_string = "Loading dataset"
if isnotebook():
bar = tqdm.tqdm_notebook(total=n_it, desc=loader_string, leave=True)
else:
bar = tqdm.tqdm(total=n_it, desc=loader_string, leave=True)
progbar = True
except ImportError:
pass
for cat in self.categories:
Logger.debug("Reading variables for {}".format(cat))
cat.read_variables(names=names, max_cpu_cores=max_cpu_cores)
if progbar: bar.update()
[docs] def drop_empty_variables(self):
"""
Delete variables which have no len
Returns:
None
"""
for cat in self.categories:
cat.drop_empty_variables()
[docs] def set_weightfunction(self, weightfunction=lambda x:x):
"""
Defines a function which is used for weighting
Args:
weightfunction (func or dict): if func is provided, set this to all categories
if needed, provide dict, cat.name -> func for individula setting
Returns:
None
"""
if isinstance(weightfunction, dict):
for cat in self.categories:
cat.set_weightfunction(weightfunction[cat.name])
else:
for cat in self.categories:
cat.set_weightfunction(weightfunction)
[docs] def calculate_weights(self, model=None, model_args=None):
"""
Calculate the weights for all categories
Keyword Args:
model (dict/func) : Either a dict catname -> func or a single func
If it is a single funct it will be applied to all categories
model_args (dict/list): variable names as arguments for the function
"""
if isinstance(model, dict):
if not isinstance(model_args, dict):
raise ValueError("if model is a dict, model_args has to be a dict too!")
for catname in model:
self.get_category(catname).calculate_weights(model=model[catname], model_args=model_args[catname])
else:
for cat in self.categories:
cat.calculate_weights(model=model, model_args=model_args)
# def get_weights(self, models):
# """
# Calculate the weights for all categories
#
# Args:
# models (dict or callable): A dictionary of categoryname -> model or a single clbl
# """
# if isinstance(models, dict):
# for catname in models:
# self.get_category(catname).get_weights(models[catname])
# if callable(models):
# for cat in self.categories:
# cat.get_weights(models)
[docs] def add_category(self,category):
"""
Add another category to the dataset
Args:
category (HErmes.selection.categories.Category): add this category
"""
self.categories.append(category)
def __getitem__(self, item):
"""
Shortcut for self.get_category/get_variable
Args:
item:
Returns:
HErmes.selection.variables.Variable/HErmes.selection.categories.Category
"""
try:
return self.get_category(item)
except KeyError:
pass
try:
return self.get_variable(item)
except KeyError:
pass
if ":" in item:
cat, var = item.split(":")
return self.get_category(cat).get(var)
else:
raise KeyError("{} can not be found".format(item))
[docs] def get_category(self, categoryname):
"""
Get a reference to a category.
Args:
category: A name which has to be associated to a category
Returns:
HErmes.selection.categories.Category
"""
for cat in self.categories:
if cat.name == categoryname:
return cat
raise KeyError("Can not find category {}.".format(categoryname))
[docs] def get_variable(self, varname):
"""
Get a pandas dataframe for all categories
Args:
varname (str): A name of a variable
Returns:
pandas.DataFrame: A 2d dataframe category -> variable
"""
var = dict()
for cat in self.categories:
var[cat.name] = cat.get(varname)
return pd.DataFrame.from_dict(var, orient="index")
[docs] def set_livetime(self, livetime):
"""
Define a livetime for this dataset.
Args:
livetime (float): Time interval the data was taken in. (Used for rate calculation)
Returns:
None
"""
for cat in self.categories:
if hasattr(cat, "set_livetime"):
cat.set_livetime(livetime)
@property
def weights(self):
"""
Get the weights for all categories in this dataset
"""
w = dict()
for cat in self.categories:
w[cat.name] = cat.weights
print (w)
return pd.DataFrame.from_dict(w,orient='index')
def __repr__(self):
"""
String representation
"""
rep = """ <Dataset: """
for cat in self.categories:
rep += "{} ".format(cat.name)
rep += ">"
return rep
[docs] def add_cut(self,cut):
"""
Add a cut without applying it yet
Args:
cut (HErmes.selection.variables.cut.Cut): Append this cut to the internal cutlist
"""
for cat in self.categories:
cat.add_cut(cut)
[docs] def apply_cuts(self,inplace=False):
"""
Apply them all!
"""
for cat in self.categories:
cat.apply_cuts(inplace=inplace)
[docs] def undo_cuts(self):
"""
Undo previously done cuts, but keep them so that
they can be re-applied
"""
for cat in self.categories:
cat.undo_cuts()
[docs] def delete_cuts(self):
"""
Completely purge all cuts from this
dataset
"""
for cat in self.categories:
cat.delete_cuts()
@property
def categorynames(self):
return [cat.name for cat in self.categories]
@property
def combined_categorynames(self):
return [cat.name for cat in self.combined_categories]
[docs] def get_sparsest_category(self, omit_empty_cat=True):
"""
Find out which category of the dataset has the least statistical power
Keyword Args:
omit_empty_cat (bool): if a category has no entries at all, omit
Returns:
str: category name
"""
name = self.categories[0].name
count = self.categories[0].raw_count
for cat in self.categories:
if cat.raw_count < count:
if (cat.raw_count == 0) and omit_empty_cat:
continue
count = cat.raw_count
name = cat.name
return name
[docs] def distribution(self,name,\
ratio=([],[]),
cumulative=True,
log=False,
transform=None,
color_palette='dark',
normalized = False,
styles = dict(),
style="classic",
ylabel="rate/bin [1/s]",
axis_properties=None,
ratiolabel="data/$\Sigma$ bg",
bins=None,
external_weights=None,
figure_factory=None):
"""
One shot short-cut for one of the most used
plots in eventselections
Args:
name (string): The name of the variable to plot
Keyword Args:
path (str): The path under which the plot will be saved.
ratio (list): A ratio plot of these categories will be crated
color_palette (str): A predifined color palette (from seaborn or HErmes.plotting.colors)
normalized (bool): Normalize the histogram by number of events
transform (callable): Apply this transformation before plotting
styles (dict): plot styling options
ylabel (str): general label for y-axis
ratiolabel (str): different label for the ratio part of the plot
bins (np.ndarray): binning, if None binning will be deduced from the variable definition
figure_factory (func): factory function which return a matplotlib.Figure
style (string): TODO "modern" || "classic" || "modern-cumul" || "classic-cumul"
external_weights (dict): supply external weights - this will OVERIDE ANY INTERNALLY CALCULATED WEIGHTS and use the supplied weights instead.
must be in the form { "categoryname" : weights}
axis_properties (dict): Manually define a plot layout with up to three axes.
For example, it can look like this:
{
"top": {"type": "h", # histogram
"height": 0.4, # height in percent
"index": 2}, # used internally
"center": {"type": "r", # ratio plot
"height": 0.2,
"index": 1},
"bottom": { "type": "c", # cumulative histogram
"height": 0.2,
"index": 0}
}
Returns:
HErmes.selection.variables.VariableDistributionPlot
"""
# if (not cumulative) or ratio == ([],[]):
#
# # assuming a single cumulative axis
# tmp_axis_properties = dict()
# unassigned_height = 0
#
# for key in axis_properties:
# if ("c" == axis_properties[key]["type"]) and (not cumulative):
# unassigned_height += axis_properties[key]["height"]
# continue
# if ("r" == axis_properties[key]["type"]) and (ratio == ([],[])):
# unassigned_height += axis_properties[key]["height"]
# continue
#
# tmpdict = copy(axis_properties[key])
# tmpdict["index"] = tmpdict["index"] -1 - bool(ratio == ([],[]))
# tmp_axis_properties.update({key : tmpdict})
#
# n_plots = len(tmp_axis_properties.keys())
# extra_height = unassigned_height/float(n_plots)
# for key in tmp_axis_properties:
# tmp_axis_properties[key]["height"] += extra_height
#
# else:
# tmp_axis_properties = copy(axis_properties)
if axis_properties is not None:
tmp_axis_properties = copy(axis_properties)
else:
# always have the histogram, but add
# cumulative or ratio plot
if cumulative and ratio != ([],[]):
tmp_axis_properties = {\
"top": {"type": "h", \
"height": 0.4, \
"index": 2},\
"center": {"type": "r",\
"height": 0.2,\
"index": 1},\
"bottom": {"type": "c", \
"height": 0.2,\
"index": 0}\
}
elif cumulative:
tmp_axis_properties = { \
"top": {"type": "h", \
"height": 0.6, \
"index": 1}, \
"bottom": {"type": "c", \
"height": 0.4, \
"index": 0} \
}
elif ratio != ([],[]):
tmp_axis_properties = { \
"top": {"type": "h", \
"height": 0.6, \
"index": 1}, \
"bottom": {"type": "r", \
"height": 0.4, \
"index": 0} \
}
else:
tmp_axis_properties = { \
"top": {"type": "h", \
"height": 0.95, \
"index": 0}, \
}
axes_locator = [(tmp_axis_properties[k]["index"], tmp_axis_properties[k]["type"], tmp_axis_properties[k]["height"])\
for k in tmp_axis_properties]
#print (axes_locator)
#heights = [axis_properties[k]["height"] for k in axis_properties]
cuts = self.categories[0].cuts
sparsest = self.get_sparsest_category()
# check if there are user-defined bins for that variable
if bins is None:
bins = self.get_category(sparsest).vardict[name].bins
# calculate the best possible binning
if bins is None:
bins = self.get_category(sparsest).vardict[name].calculate_fd_bins()
label = self.get_category(sparsest).vardict[name].label
plot = VariableDistributionPlot(cuts=cuts, bins=bins,\
xlabel=label,\
color_palette=color_palette)
if styles:
plot.plot_options = styles
else:
plot.plot_options = self.default_plotstyles
plotcategories = self.categories + self.combined_categories
Logger.warn("For variables with different lengths the weighting is broken. If weights, it will fail")
for cat in [x for x in plotcategories if x.plot]:
if external_weights is None:
weights = None
else:
weights = external_weights[cat.name]
plot.add_variable(cat, name, transform=transform, external_weights=weights)
Logger.debug("Adding variable data {}".format(name))
if cumulative:
Logger.debug("Adding variable data {} for cumulative plot".format(name))
plot.add_cumul(cat.name)
if len(ratio[0]) and len(ratio[1]):
tratio,tratio_err = self.calc_ratio(nominator=ratio[0],\
denominator=ratio[1])
plot.add_ratio(ratio[0],ratio[1],\
total_ratio=tratio,\
label=ratiolabel,
total_ratio_errors=tratio_err)
plot.plot(axes_locator=axes_locator,\
normalized=normalized,\
figure_factory=figure_factory,\
log=log,
ylabel=ylabel)
#plot.add_legend()
#plot.canvas.save(savepath,savename,dpi=350)
return plot
@property
def integrated_rate(self):
"""
Integrated rate for each category
Returns:
pandas.Panel: rate with error
"""
rdata,edata,index = [],[],[]
for cat in self.categories + self.combined_categories:
rate,error = cat.integrated_rate
rdata.append(rate)
index.append(cat.name)
edata.append(error)
rate = pd.Series(rdata,index)
err = pd.Series(edata,index)
return (rate,err)
#FIXME static method!
[docs] def sum_rate(self,categories=None):
"""
Sum up the integrated rates for categories
Args:
categories: categories considerred background
Returns:
tuple: rate with error
"""
if categories is None:
return 0,0
categories = [self.get_category(i) if isinstance(i, str) else i for i in categories]
rate,error = categories[0].integrated_rate
error = error**2
for cat in categories[1:]:
tmprate,tmperror = cat.integrated_rate
rate += tmprate # categories should be independent
error += tmperror**2
return (rate,np.sqrt(error))
[docs] def calc_ratio(self,nominator=None,denominator=None):
"""
Calculate a ratio of the given categories
Args:
nominator (list):
denominator (list):
Returns:
tuple
"""
nominator = [self.get_category(i) if isinstance(i, str) else i for i in nominator]
denominator = [self.get_category(i) if isinstance(i, str) else i for i in denominator]
a,a_err = self.sum_rate(categories=nominator)
b,b_err = self.sum_rate(categories=denominator)
if b == 0:
return np.nan, np.nan
sum_err = np.sqrt((a_err/ b) ** 2 + ((-a * b_err)/ (b ** 2)) ** 2)
return a/b, sum_err
def _setup_table_data(self,signal=None,background=None):
"""
Setup data for a table
If signal and background are given, also summed values
will be in the list
Keyword Args:
signal (list): category names which are considered signal
background (list): category names which are considered background
Returns
dict: table dictionary
"""
rates, errors = self.integrated_rate
sgrate, sgerrors = self.sum_rate(signal)
bgrate, bgerrors = self.sum_rate(background)
allrate, allerrors = self.sum_rate(self.categories)
tmprates = pd.Series([sgrate,bgrate,allrate],index=["signal","background","all"])
tmperrors = pd.Series([sgerrors,bgerrors,allerrors],index=["signal","background","all"])
rates = rates.append(tmprates)
errors = errors.append(tmperrors)
datacats = []
for cat in self.categories + self.combined_categories:
if isinstance(cat,categories.Data):
datacats.append(cat)
if datacats:
simcats = [cat for cat in self.categories if cat.name not in [kitty.name for kitty in datacats]]
simrate, simerror = self.sum_rate(simcats)
fudges = dict()
for cat in datacats:
rate,error = cat.integrated_rate
try:
fudges[cat.name] = (rate/simrate),(error/simerror)
except ZeroDivisionError:
fudges[cat.name] = np.NaN
rate_dict = OrderedDict()
all_fudge_dict = OrderedDict()
#for catname in sorted(self.categorynames) + sorted(self.combined_categorynames):
for cat in datacats:
label = get_label(cat)
#cfg = GetCategoryConfig(cat.name)
#label = cfg["label"]
rate_dict[label] = (rates[cat.name], errors[cat.name])
if cat.name in fudges:
all_fudge_dict[label] = fudges[cat.name]
else:
all_fudge_dict[label] = None
rate_dict["Sig."] = (rates["signal"],errors["signal"] )
rate_dict["Bg."] = (rates["background"],errors["background"])
rate_dict["Gr. Tot."] = (rates["all"],errors["all"])
all_fudge_dict["Sig."] = None
all_fudge_dict["Bg."] = None
all_fudge_dict["Gr. Tot."] = None
return rate_dict,all_fudge_dict
[docs] def tinytable(self,signal=None,\
background=None,\
layout="v",\
format="html",\
order_by=lambda x:x,
livetime=1.):
"""
Use dashi.tinytable.TinyTable to render a nice
html representation of a rate table
Args:
signal (list) : summing up signal categories to calculate total signal rate
background (list): summing up background categories to calculate total background rate
layout (str) : "v" for vertical, "h" for horizontal
format (str) : "html","latex","wiki"
Returns:
str: formatted table in desired markup
"""
def cellformatter(input):
#print input
if input is None:
return "-"
if isinstance(input[0],pd.Series):
input = (input[1][0],input[1][0])
return "{:4.2e} +- {:4.2e}".format(input[0],input[1])
#FIXME: sort the table columns
rates,fudges = self._setup_table_data(signal=signal,background=background)
events = dict()
for k in rates:
events[k] = rates[k][0] * livetime, rates[k][1] * livetime
showcats = [get_label(cat) for cat in self.categories if cat.show_in_table]
showcats += [get_label(cat) for cat in self.combined_categories if cat.show_in_table]
showcats.extend(['Sig.',"Bg.","Gr. Tot."])
orates = OrderedDict()
ofudges = OrderedDict()
oevents = OrderedDict()
for k in list(rates.keys()):
if k in showcats:
orates[k] = rates[k]
ofudges[k] = fudges[k]
oevents[k] = events[k]
#rates = {k : rates[k] for k in rates if k in showcats}
#fudges = {k : fudges[k] for k in fudges if k in showcats}
#events = {k : events[k] for k in events if k in showcats}
tt = TinyTable()
#bypass the add function ot add an ordered dict
for label,data in [('Rate (1/s)', orates),("Ratio", ofudges),("Events",oevents)]:
tt.x_labels.append(label)
tt.label_data[label] = data
#tt.add("Rate (1/s)", **rates)
#tt.add("Ratio",**fudges)
#tt.add("Events",**events)
return tt.render(layout=layout,format=format,\
format_cell=cellformatter,\
order_by=order_by)
#def cut_progression_table(self,cuts,\
# signal=None,\
# background=None,\
# layout="v",\
# format="html",\
# order_by=lambda x:x,
# livetime=1.):
# self.delete_cuts()
# self.undo_cuts()
# for cut in cuts:
# self.add_cut(cut)
# self.apply_cuts()
def __len__(self):
#FIXME: to be implemented
raise NotImplementedError