Advanced Custom Primitives Guide
Functions With Additional Arguments
One caveat with the make_primitive functions is that every required argument of function must be an input feature. Here we create a function for StringCount, a primitive that counts the number of occurrences of a string in a Text input. Since string is not a feature, it needs to be a keyword argument to string_count.
In [1]: def string_count(column, string=None):
   ...:     '''Count the number of times the value string occurs'''
   ...:     assert string is not None, "string to count needs to be defined"
   ...:     counts = [element.lower().count(string) for element in column]
   ...:     return counts
   ...:
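Because string_count is an ordinary Python function, it can be checked on a plain list before it is wrapped in a primitive. The sketch below repeats the definition and calls it on made-up sample strings (the sample data is an assumption for illustration only). Two details are worth noticing: str.count matches substrings rather than whole words, and each element is lowercased before counting, so the search string should itself be lowercase.

```python
def string_count(column, string=None):
    '''Count the number of times the value string occurs'''
    assert string is not None, "string to count needs to be defined"
    counts = [element.lower().count(string) for element in column]
    return counts


# Hypothetical sample data, not taken from the ecommerce entityset
comments = ["The cat and the hat", "no match here", "Other"]

counts = string_count(comments, string="the")
print(counts)  # [2, 0, 1]: "Other" matches because "other" contains "the"

# Elements are lowercased first, so an uppercase search string never matches
upper_counts = string_count(comments, string="The")
print(upper_counts)  # [0, 0, 0]
```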
So that features defined with the primitive reflect which string is being counted, we define a custom generate_name function.
In [2]: def string_count_generate_name(self, base_feature_names):
   ...:     return u'STRING_COUNT(%s, "%s")' % (base_feature_names[0], self.kwargs['string'])
   ...:
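To see what this function produces without building a primitive, it can be called with a stand-in object. In the sketch below, types.SimpleNamespace is only a hypothetical substitute for a primitive instance. The one assumption is that the instance exposes its keyword arguments as self.kwargs, which is what the function above relies on.

```python
from types import SimpleNamespace


def string_count_generate_name(self, base_feature_names):
    return u'STRING_COUNT(%s, "%s")' % (base_feature_names[0], self.kwargs['string'])


# Hypothetical stand-in for a primitive instance created with string="the"
fake_primitive = SimpleNamespace(kwargs={"string": "the"})

name = string_count_generate_name(fake_primitive, ["comments"])
print(name)  # STRING_COUNT(comments, "the")
```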
Now that we have the function, we create the primitive using the make_trans_primitive function.
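In [3]: # Imports this example relies on (paths as in Featuretools releases that provide make_trans_primitive)
   ...: from featuretools.primitives import make_trans_primitive
   ...: from featuretools.variable_types import Text, Numeric
   ...:
   ...: StringCount = make_trans_primitive(function=string_count,
   ...:                                    input_types=[Text],
   ...:                                    return_type=Numeric,
   ...:                                    cls_attributes={"generate_name": string_count_generate_name})
   ...: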
Passing in string="test" as a keyword argument when initializing the StringCount primitive will make "test" the value used for string when string_count is called to calculate the feature values. Now we use this primitive to define features and calculate the feature values.
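The flow of the keyword argument from initialization to the function call can be pictured with a small stand-alone class. This is a minimal sketch under stated assumptions, not Featuretools' actual implementation: KeywordBoundPrimitive is a hypothetical name, and the sample strings are invented for illustration.

```python
class KeywordBoundPrimitive:
    """Hypothetical stand-in for a generated primitive class (not Featuretools code)."""

    def __init__(self, function, **kwargs):
        self.function = function
        self.kwargs = kwargs  # e.g. {"string": "test"}

    def __call__(self, *columns):
        # Input features arrive positionally; the stored keywords are passed along
        return self.function(*columns, **self.kwargs)


def string_count(column, string=None):
    assert string is not None, "string to count needs to be defined"
    return [element.lower().count(string) for element in column]


count_test = KeywordBoundPrimitive(string_count, string="test")
result = count_test(["Test one", "testing tests", "none"])
print(result)  # [1, 2, 0]
```

Storing the keywords on the instance is also what lets a naming function such as string_count_generate_name read self.kwargs['string'] later.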
In [4]: import featuretools as ft
   ...: from featuretools.tests.testing_utils import make_ecommerce_entityset
In [5]: es = make_ecommerce_entityset()
In [6]: feature_matrix, features = ft.dfs(entityset=es,
   ...:                                   target_entity="sessions",
   ...:                                   agg_primitives=["sum", "mean", "std"],
   ...:                                   trans_primitives=[StringCount(string="the")])
   ...:
In [7]: feature_matrix.columns
Out[7]: Index(['device_name', 'customer_id', 'device_type', 'SUM(log.value_many_nans)', 'SUM(log.value_2)', 'SUM(log.value)', 'MEAN(log.value_many_nans)', 'MEAN(log.value_2)', 'MEAN(log.value)', 'STD(log.value_many_nans)', 'STD(log.value_2)', 'STD(log.value)', 'customers.cohort', 'customers.age', 'customers.région_id', 'customers.loves_ice_cream', 'customers.cancel_reason', 'customers.engagement_level', 'SUM(log.STRING_COUNT(comments, "the"))', 'SUM(log.products.rating)', 'MEAN(log.STRING_COUNT(comments, "the"))', 'MEAN(log.products.rating)', 'STD(log.STRING_COUNT(comments, "the"))', 'STD(log.products.rating)', 'customers.SUM(log.value)', 'customers.SUM(log.value_many_nans)', 'customers.SUM(log.value_2)', 'customers.MEAN(log.value)', 'customers.MEAN(log.value_many_nans)', 'customers.MEAN(log.value_2)', 'customers.STD(log.value)', 'customers.STD(log.value_many_nans)', 'customers.STD(log.value_2)', 'customers.STRING_COUNT(favorite_quote, "the")', 'customers.cohorts.cohort_name', 'customers.régions.language'], dtype='object')
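Since the generated names embed both the primitive name and the counted string, the derived columns can be picked out by name rather than typed by hand. The sketch below works on a short list of names copied from the listing above, so it runs without Featuretools. The variable names are chosen here for illustration.

```python
# A few names copied from the Out[7] listing above
columns = [
    'SUM(log.value)',
    'SUM(log.STRING_COUNT(comments, "the"))',
    'MEAN(log.STRING_COUNT(comments, "the"))',
    'STD(log.STRING_COUNT(comments, "the"))',
    'customers.STRING_COUNT(favorite_quote, "the")',
]

# Keep only the names produced by the custom primitive
string_count_columns = [name for name in columns if 'STRING_COUNT(' in name]
print(len(string_count_columns))  # 4
```

Applied to feature_matrix.columns, the same comprehension would give a list of names that can be used to index the feature matrix.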
In [8]: feature_matrix[['STD(log.STRING_COUNT(comments, "the"))', 'SUM(log.STRING_COUNT(comments, "the"))', 'MEAN(log.STRING_COUNT(comments, "the"))']]