import html
from featuretools.feature_base.feature_base import (
AggregationFeature,
DirectFeature,
FeatureOutputSlice,
IdentityFeature,
TransformFeature
)
from featuretools.feature_base.feature_descriptions import describe_feature
from featuretools.utils.plot_utils import (
check_graphviz,
get_graphviz_format,
save_graph
)
# Background colour used for the row(s) of the target feature in a dataframe table.
TARGET_COLOR = '#D9EAD3'
# Label for one dataframe node. The enclosing ``<`` ... ``>`` delimit a graphviz
# HTML-like label. Formatted with ``dataframe_name`` (header row text) and
# ``table_cols`` (the already-rendered row strings joined by newlines).
TABLE_TEMPLATE = '''<
<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="10">
<TR>
<TD colspan="1" bgcolor="#A9A9A9"><B>{dataframe_name}</B></TD>
</TR>{table_cols}
</TABLE>>'''
# One table row for a non-target column/feature, formatted positionally as
# (port, text). The port lets edges attach to this specific row using the
# "<dataframe name>:<port>" node names built elsewhere in this module.
COL_TEMPLATE = '''<TR><TD ALIGN="LEFT" port="{}">{}</TD></TR>'''
# Same as COL_TEMPLATE but highlighted with TARGET_COLOR. The .format() call
# below only bakes in the colour: the two '{}' arguments re-emit literal '{}'
# placeholders, so the result still takes (port, text) positionally.
TARGET_TEMPLATE = '''
<TR>
<TD ALIGN="LEFT" port="{}" BGCOLOR="{target_color}">{}</TD>
</TR>'''.format('{}', '{}', target_color=TARGET_COLOR)
def graph_feature(feature, to_file=None, description=False, **kwargs):
    '''Generates a feature lineage graph for the given feature

    The graph contains one table node per dataframe touched by the feature's
    lineage (the target dataframe is starred), one diamond node per primitive
    and one box node per group-by. Solid edges show data flow; dotted edges
    link relationship/group-by columns to the nodes that use them.

    Args:
        feature (FeatureBase) : Feature to generate lineage graph for
        to_file (str, optional) : Path to where the plot should be saved.
            If set to None (as by default), the plot will not be saved.
        description (bool or str, optional): The feature description to use as a caption
            for the graph. If False, no description is added. Set to True
            to use an auto-generated description. Any other value is used
            verbatim as the caption. Defaults to False.
        kwargs (keywords): Additional keyword arguments to pass as keyword arguments
            to the ft.describe_feature function.

    Returns:
        graphviz.Digraph : Graph object that can directly be displayed in Jupyter notebooks.
    '''
    graphviz = check_graphviz()
    format_ = get_graphviz_format(graphviz=graphviz, to_file=to_file)

    # Initialize a new directed graph, laid out left to right
    graph = graphviz.Digraph(feature.get_name(), format=format_,
                             graph_attr={'rankdir': 'LR'})

    dataframes = {}
    edges = ([], [])  # (edges drawn dotted, edges drawn solid)
    primitives = []
    groupbys = []
    _, max_depth = get_feature_data(feature, dataframes, groupbys, edges, primitives, layer=0)
    dataframes[feature.dataframe_name]['targets'].add(feature.get_name())

    # Dataframe tables: the node shape is the same for every table, so set it once
    graph.attr('node', shape='plaintext')
    for df_name in dataframes:
        if df_name == feature.dataframe_name:
            dataframe_name = '\u2605 {} (target)'.format(df_name)
        else:
            dataframe_name = df_name
        dataframe_table = get_dataframe_table(dataframe_name, dataframes[df_name])
        graph.node(df_name, dataframe_table)

    # Primitive nodes
    graph.attr('node', shape='diamond')
    num_primitives = len(primitives)
    for prim_name, prim_label, layer, prim_type in primitives:
        # deeper layers run earlier, so they get smaller step numbers
        step_num = max_depth - layer
        if num_primitives == 1:
            type_str = '<FONT POINT-SIZE="12"><B>{}</B><BR></BR></FONT>'.format(prim_type) if prim_type else ''
            prim_label = '<{}{}>'.format(type_str, prim_label)
        else:
            step = 'Step {}'.format(step_num)
            type_str = ' ' + prim_type if prim_type else ''
            prim_label = '<<FONT POINT-SIZE="12"><B>{}:</B>{}<BR></BR></FONT>{}>'.format(step, type_str, prim_label)

        # sink first layer transform primitive if multiple primitives
        if step_num == 1 and prim_type == 'Transform' and num_primitives > 1:
            with graph.subgraph() as init_transform:
                init_transform.attr(rank='min')
                init_transform.node(name=prim_name, label=prim_label)
        else:
            graph.node(name=prim_name, label=prim_label)

    # Group-by nodes
    graph.attr('node', shape='box')
    for groupby_name, groupby_label in groupbys:
        graph.node(name=groupby_name, label=groupby_label)

    graph.attr('edge', style='solid', dir='forward')
    for edge in edges[1]:
        graph.edge(*edge)

    graph.attr('edge', style='dotted', arrowhead='none', dir='forward')
    for edge in edges[0]:
        graph.edge(*edge)

    if description is True:
        graph.attr(label=describe_feature(feature, **kwargs))
    elif description is not False:
        graph.attr(label=description)

    if to_file:
        save_graph(graph, to_file, format_)
    return graph
def get_feature_data(feat, dataframes, groupbys, edges, primitives, layer=0):
    '''
    Recursively collect the graph elements needed to draw ``feat``'s lineage.

    The collections passed in are mutated in place:

    * ``dataframes`` -- dict keyed by dataframe name (entries created with
      ``add_dataframe``). Each explored feature's name is added to
      ``'columns'`` for identity features, otherwise to ``'feats'``.
    * ``groupbys`` -- list of ``(node_name, label)`` tuples for group-by nodes.
    * ``edges`` -- pair of lists of ``[tail, head]`` edges. ``graph_feature``
      draws ``edges[0]`` dotted and ``edges[1]`` solid.
    * ``primitives`` -- list of ``(node_name, label, layer, type)`` tuples.

    Args:
        feat (FeatureBase): Feature to explore.
        dataframes, groupbys, edges, primitives: Accumulators described above.
        layer (int): Recursion depth of ``feat``; the target feature is 0.

    Returns:
        tuple(str, int): The ``"<dataframe name>:<feature name>"`` node/port
        identifying ``feat``, and the deepest layer reached while exploring it.
        The depth includes the lineage of its group-by feature, if it has one.
    '''
    # 1) add feature to dataframes tables
    feat_name = feat.get_name()
    if feat.dataframe_name not in dataframes:
        add_dataframe(feat.dataframe, dataframes)
    dataframe_dict = dataframes[feat.dataframe_name]

    # if we've already explored this feat, don't explore it again
    feat_node = "{}:{}".format(feat.dataframe_name, feat_name)
    if feat_name in dataframe_dict['columns'] or feat_name in dataframe_dict['feats']:
        return feat_node, layer

    if isinstance(feat, IdentityFeature):
        dataframe_dict['columns'].add(feat_name)
    else:
        dataframe_dict['feats'].add(feat_name)

    base_node = feat_node

    # 2) if multi-output, convert feature to generic base
    if isinstance(feat, FeatureOutputSlice):
        feat = feat.base_feature
        feat_name = feat.get_name()

    # 3) add primitive node
    if feat.primitive.name or isinstance(feat, DirectFeature):
        prim_name = feat.primitive.name if feat.primitive.name else 'join'
        prim_type = ''
        if isinstance(feat, AggregationFeature):
            prim_type = 'Aggregation'
        elif isinstance(feat, TransformFeature):
            prim_type = 'Transform'
        primitive_node = "{}_{}_{}".format(layer, feat_name, prim_name)
        primitives.append((primitive_node, prim_name.upper(), layer, prim_type))
        edges[1].append([primitive_node, base_node])
        base_node = primitive_node

    # 4) add groupby/join edges and nodes
    dependencies = [(dep.hash(), dep) for dep in feat.get_dependencies()]
    for is_forward, r in feat.relationship_path:
        # both directions record the relationship's child column in its table
        child_df_name = r.child_dataframe.ww.name
        if child_df_name not in dataframes:
            add_dataframe(r.child_dataframe, dataframes)
        dataframes[child_df_name]['columns'].add(r._child_column_name)
        child_node = '{}:{}'.format(child_df_name, r._child_column_name)

        if is_forward:
            edges[0].append([base_node, child_node])
        else:
            child_name = child_node.replace(':', '--')
            groupby_node = "{}_groupby_{}".format(feat_name, child_name)
            groupby_name = 'group by\n{}'.format(r._child_column_name)
            groupbys.append((groupby_node, groupby_name))
            edges[0].append([child_node, groupby_node])
            edges[1].append([groupby_node, base_node])
            base_node = groupby_node

    max_depth = layer
    if hasattr(feat, 'groupby'):
        groupby = feat.groupby
        # The recursive call records the groupby feature in its dataframe's
        # table and returns its "<dataframe>:<name>" node.
        child_node, groupby_depth = get_feature_data(groupby, dataframes, groupbys, edges, primitives, layer + 1)
        # The groupby feature's own lineage (e.g. a transform feature) adds
        # primitives at deeper layers; count it so step numbers stay correct.
        max_depth = max(max_depth, groupby_depth)
        # hash is compared first so features are only compared when hashes match
        dependencies.remove((groupby.hash(), groupby))

        child_name = child_node.replace(':', '--')
        groupby_node = "{}_groupby_{}".format(feat_name, child_name)
        groupby_name = 'group by\n{}'.format(groupby.get_name())
        groupbys.append((groupby_node, groupby_name))
        edges[0].append([child_node, groupby_node])
        edges[1].append([groupby_node, base_node])
        base_node = groupby_node

    # 5) recurse over dependents
    for _, f in dependencies:
        dependent_node, depth = get_feature_data(f, dataframes, groupbys, edges, primitives, layer + 1)
        edges[1].append([dependent_node, base_node])
        max_depth = max(depth, max_depth)

    return feat_node, max_depth
def add_dataframe(dataframe, dataframe_dict):
    '''
    Register ``dataframe`` in ``dataframe_dict`` with empty bookkeeping sets.

    The entry is keyed by the dataframe's Woodwork name and replaces any
    entry already stored under that name.

    Args:
        dataframe: Object exposing ``ww.name`` and ``ww.index``.
        dataframe_dict (dict): Mapping updated in place. The new value holds
            the index name plus fresh ``targets``, ``columns`` and ``feats`` sets.
    '''
    schema = dataframe.ww
    dataframe_dict[schema.name] = dict(
        index=schema.index,
        targets=set(),
        columns=set(),
        feats=set(),
    )
def get_dataframe_table(dataframe_name, dataframe_dict):
    '''
    Construct the graphviz HTML-like table label for a single dataframe.

    The index (if it is used) is listed first, followed by columns, then
    features, then target features (highlighted). Each group is sorted, so
    the output is the same from run to run.

    Args:
        dataframe_name (str): Text for the table's header row. It is
            HTML-escaped before being inserted into the label.
        dataframe_dict (dict): Entry created by ``add_dataframe``, with keys
            ``'index'``, ``'targets'``, ``'columns'`` and ``'feats'``.
            It is not modified.

    Returns:
        str: The HTML-like label string for the dataframe node.
    '''
    index = dataframe_dict['index']
    # copy so the caller's 'targets' set is not altered by discard() below
    targets = set(dataframe_dict['targets'])
    columns = dataframe_dict['columns'].difference(targets)
    feats = dataframe_dict['feats'].difference(targets)

    rows = []
    # If the index is used, make sure it's the first element in the table
    if index is not None:
        clean_index = html.escape(index)
        if index in columns:
            rows.append(COL_TEMPLATE.format(clean_index, clean_index + " (index)"))
            columns.discard(index)
        elif index in targets:
            rows.append(TARGET_TEMPLATE.format(clean_index, clean_index + " (index)"))
            targets.discard(index)

    groups = (
        (columns, COL_TEMPLATE),
        (feats, COL_TEMPLATE),
        (targets, TARGET_TEMPLATE),
    )
    for names, template in groups:
        for name in sorted(names):
            clean_name = html.escape(name)
            rows.append(template.format(clean_name, clean_name))

    table = TABLE_TEMPLATE.format(dataframe_name=html.escape(dataframe_name),
                                  table_cols="\n".join(rows))
    return table