# Source code for featuretools.feature_base.feature_visualizer

import html

from featuretools.feature_base.feature_base import (
    AggregationFeature,
    DirectFeature,
    FeatureOutputSlice,
    IdentityFeature,
    TransformFeature,
)
from featuretools.feature_base.feature_descriptions import describe_feature
from featuretools.utils.plot_utils import (
    check_graphviz,
    get_graphviz_format,
    save_graph,
)

# Background color applied to the table rows of target features.
TARGET_COLOR = "#D9EAD3"

# HTML-like label for a dataframe node: a header row carrying the dataframe
# name, followed by the pre-rendered column rows.
TABLE_TEMPLATE = (
    "<\n"
    '<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="10">\n'
    "    <TR>\n"
    '        <TD colspan="1" bgcolor="#A9A9A9"><B>{dataframe_name}</B></TD>\n'
    "    </TR>{table_cols}\n"
    "</TABLE>>"
)

# One plain row; the two positional fields are the port name and the cell text.
COL_TEMPLATE = '<TR><TD ALIGN="LEFT" port="{}">{}</TD></TR>'

# Same two positional fields as COL_TEMPLATE, with the highlight color
# already baked in.
TARGET_TEMPLATE = (
    "\n"
    "    <TR>\n"
    '        <TD ALIGN="LEFT" port="{}" BGCOLOR="' + TARGET_COLOR + '">{}</TD>\n'
    "    </TR>"
)


def _primitive_node_label(prim_label, prim_type, step_num, num_primitives):
    """Return the HTML-like label shown inside one primitive node."""
    if num_primitives == 1:
        # A lone primitive needs no step number, only its (optional) type.
        header = ""
        if prim_type:
            header = f'<FONT POINT-SIZE="12"><B>{prim_type}</B><BR></BR></FONT>'
        return f"<{header}{prim_label}>"

    type_suffix = f" {prim_type}" if prim_type else ""
    return (
        f'<<FONT POINT-SIZE="12"><B>Step {step_num}:</B>{type_suffix}'
        f"<BR></BR></FONT>{prim_label}>"
    )


def graph_feature(feature, to_file=None, description=False, **kwargs):
    """Generates a feature lineage graph for the given feature.

    Args:
        feature (FeatureBase) : Feature to generate lineage graph for.
        to_file (str, optional) : Path to where the plot should be saved.
            If set to None (as by default), the plot will not be saved.
        description (bool or str, optional): The feature description to use
            as a caption for the graph. If False, no description is added.
            Set to True to use an auto-generated description; any other
            value is used as the caption itself. Defaults to False.
        kwargs (keywords): Additional keyword arguments passed on to
            ``describe_feature`` when ``description`` is True.

    Returns:
        graphviz.Digraph : Graph object that can directly be displayed in
            Jupyter notebooks.
    """
    graphviz = check_graphviz()
    format_ = get_graphviz_format(graphviz=graphviz, to_file=to_file)

    # Left-to-right directed graph named after the feature.
    graph = graphviz.Digraph(
        feature.get_name(),
        format=format_,
        graph_attr={"rankdir": "LR"},
    )

    # Containers filled in place by the recursive lineage walk.
    dataframes = {}
    edges = ([], [])
    primitives = []
    groupbys = []
    _, max_depth = get_feature_data(
        feature,
        dataframes,
        groupbys,
        edges,
        primitives,
        layer=0,
    )
    target_df = feature.dataframe_name
    dataframes[target_df]["targets"].add(feature.get_name())

    # Dataframe tables
    for df_name, df_info in dataframes.items():
        if df_name == target_df:
            title = f"\u2605 {df_name} (target)"
        else:
            title = df_name
        table = get_dataframe_table(title, df_info)
        graph.attr("node", shape="plaintext")
        graph.node(df_name, table)

    # Primitive nodes
    graph.attr("node", shape="diamond")
    num_primitives = len(primitives)
    for prim_name, prim_label, layer, prim_type in primitives:
        step_num = max_depth - layer
        label = _primitive_node_label(prim_label, prim_type, step_num, num_primitives)

        # When there are several primitives, a first-step transform is placed
        # in its own subgraph with rank "min".
        is_first_transform = (
            step_num == 1 and prim_type == "Transform" and num_primitives > 1
        )
        if is_first_transform:
            with graph.subgraph() as init_transform:
                init_transform.attr(rank="min")
                init_transform.node(name=prim_name, label=label)
        else:
            graph.node(name=prim_name, label=label)

    # Groupby nodes
    graph.attr("node", shape="box")
    for groupby_name, groupby_label in groupbys:
        graph.node(name=groupby_name, label=groupby_label)

    # Solid arrows come from edges[1], dotted arrow-less links from edges[0].
    dotted_edges, solid_edges = edges
    graph.attr("edge", style="solid", dir="forward")
    for tail_head in solid_edges:
        graph.edge(*tail_head)

    graph.attr("edge", style="dotted", arrowhead="none", dir="forward")
    for tail_head in dotted_edges:
        graph.edge(*tail_head)

    # Optional caption
    if description is True:
        graph.attr(label=describe_feature(feature, **kwargs))
    elif description is not False:
        graph.attr(label=description)

    if to_file:
        save_graph(graph, to_file, format_)
    return graph
def get_feature_data(feat, dataframes, groupbys, edges, primitives, layer=0):
    """Recursively collect the nodes and edges of a feature's lineage.

    All container arguments are filled in place.

    Args:
        feat: Feature to explore.
        dataframes (dict): Maps dataframe name to the per-dataframe dict
            created by ``add_dataframe``; the names of visited columns and
            features are added to its ``"columns"`` / ``"feats"`` sets.
        groupbys (list): Receives ``(node_name, label)`` tuples for groupby
            nodes.
        edges (tuple[list, list]): ``edges[0]`` receives the links between
            columns and join/groupby nodes, ``edges[1]`` receives the links
            along the computation path. Each edge is a ``[tail, head]`` list.
        primitives (list): Receives
            ``(node_name, label, layer, primitive_type)`` tuples.
        layer (int): Recursion depth of ``feat``; 0 for the root feature.

    Returns:
        tuple: ``(feat_node, max_depth)`` where ``feat_node`` is the
        ``"<dataframe>:<feature>"`` identifier of ``feat`` and ``max_depth``
        is the deepest layer reached below it (``layer`` itself if the
        feature had already been visited).
    """
    # 1) add feature to dataframes tables:
    feat_name = feat.get_name()
    if feat.dataframe_name not in dataframes:
        add_dataframe(feat.dataframe, dataframes)
    dataframe_dict = dataframes[feat.dataframe_name]

    # if we've already explored this feat, stop here
    feat_node = "{}:{}".format(feat.dataframe_name, feat_name)
    if feat_name in dataframe_dict["columns"] or feat_name in dataframe_dict["feats"]:
        return feat_node, layer

    if isinstance(feat, IdentityFeature):
        dataframe_dict["columns"].add(feat_name)
    else:
        dataframe_dict["feats"].add(feat_name)
    base_node = feat_node

    # 2) if multi-output, convert feature to generic base
    if isinstance(feat, FeatureOutputSlice):
        feat = feat.base_feature
        feat_name = feat.get_name()

    # 3) add primitive node
    if feat.primitive.name or isinstance(feat, DirectFeature):
        # Features without a primitive name that get here are direct features
        # and are labelled "join".
        prim_name = feat.primitive.name if feat.primitive.name else "join"
        prim_type = ""
        if isinstance(feat, AggregationFeature):
            prim_type = "Aggregation"
        elif isinstance(feat, TransformFeature):
            prim_type = "Transform"
        primitive_node = "{}_{}_{}".format(layer, feat_name, prim_name)
        primitives.append((primitive_node, prim_name.upper(), layer, prim_type))
        edges[1].append([primitive_node, base_node])
        base_node = primitive_node

    # 4) add groupby/join edges and nodes
    # Dependencies are paired with their hash so that the list removal below
    # compares the hashes first.
    dependencies = [(dep.hash(), dep) for dep in feat.get_dependencies()]
    for is_forward, r in feat.relationship_path:
        # Both directions register the relationship's child column first.
        child_df_name = r.child_dataframe.ww.name
        child_column = r._child_column_name
        if child_df_name not in dataframes:
            add_dataframe(r.child_dataframe, dataframes)
        dataframes[child_df_name]["columns"].add(child_column)
        child_node = "{}:{}".format(child_df_name, child_column)

        if is_forward:
            edges[0].append([base_node, child_node])
        else:
            child_name = child_node.replace(":", "--")
            groupby_node = "{}_groupby_{}".format(feat_name, child_name)
            groupby_name = "group by\n{}".format(child_column)
            groupbys.append((groupby_node, groupby_name))
            edges[0].append([child_node, groupby_node])
            edges[1].append([groupby_node, base_node])
            base_node = groupby_node

    if hasattr(feat, "groupby"):
        groupby = feat.groupby
        get_feature_data(groupby, dataframes, groupbys, edges, primitives, layer + 1)
        # The groupby feature was just explored; drop it so step 5 does not
        # treat it as an ordinary input.
        dependencies.remove((groupby.hash(), groupby))

        groupby_name = groupby.get_name()
        if isinstance(groupby, IdentityFeature):
            dataframes[groupby.dataframe_name]["columns"].add(groupby_name)
        else:
            dataframes[groupby.dataframe_name]["feats"].add(groupby_name)
        child_node = "{}:{}".format(groupby.dataframe_name, groupby_name)

        child_name = child_node.replace(":", "--")
        groupby_node = "{}_groupby_{}".format(feat_name, child_name)
        groupby_name = "group by\n{}".format(groupby_name)
        groupbys.append((groupby_node, groupby_name))
        edges[0].append([child_node, groupby_node])
        edges[1].append([groupby_node, base_node])
        base_node = groupby_node

    # 5) recurse over dependents
    max_depth = layer
    for _, f in dependencies:
        dependent_node, depth = get_feature_data(
            f, dataframes, groupbys, edges, primitives, layer + 1
        )
        edges[1].append([dependent_node, base_node])
        max_depth = max(depth, max_depth)

    return feat_node, max_depth


def add_dataframe(dataframe, dataframe_dict):
    """Register ``dataframe`` in ``dataframe_dict`` with empty name sets.

    The entry is keyed by ``dataframe.ww.name`` and stores
    ``dataframe.ww.index`` plus empty ``"targets"``, ``"columns"`` and
    ``"feats"`` sets. An existing entry with the same name is replaced.
    """
    dataframe_dict[dataframe.ww.name] = {
        "index": dataframe.ww.index,
        "targets": set(),
        "columns": set(),
        "feats": set(),
    }


def get_dataframe_table(dataframe_name, dataframe_dict):
    """Build the HTML-like table label for one dataframe node.

    Args:
        dataframe_name (str): Title shown in the table header. It is
            HTML-escaped before being inserted.
        dataframe_dict (dict): Entry with ``"index"``, ``"targets"``,
            ``"columns"`` and ``"feats"`` keys, as created by
            ``add_dataframe``. It is only read, never modified.

    Returns:
        str: ``TABLE_TEMPLATE`` filled with one row per name. The index row
        comes first when the index is displayed, then the columns, the
        features and the targets, each group in sorted order so the output
        is reproducible.
    """
    index = dataframe_dict["index"]
    # Copy the targets: the index is discarded from this set below, and the
    # caller's set must stay intact so repeated calls give the same table.
    targets = set(dataframe_dict["targets"])
    columns = dataframe_dict["columns"].difference(targets)
    feats = dataframe_dict["feats"].difference(targets)

    rows = []
    # If the index is used, make sure it's the first element in the table.
    # It is only escaped when displayed, so an unused (e.g. None) index is
    # never passed to html.escape.
    if index in columns or index in targets:
        index_template = TARGET_TEMPLATE if index in targets else COL_TEMPLATE
        clean_index = html.escape(index)
        rows.append(index_template.format(clean_index, clean_index + " (index)"))
        columns.discard(index)
        targets.discard(index)

    row_groups = (
        (COL_TEMPLATE, columns),
        (COL_TEMPLATE, feats),
        (TARGET_TEMPLATE, targets),
    )
    for template, names in row_groups:
        for name in sorted(names):
            clean_name = html.escape(name)
            rows.append(template.format(clean_name, clean_name))

    return TABLE_TEMPLATE.format(
        dataframe_name=html.escape(dataframe_name),
        table_cols="\n".join(rows),
    )