Advanced Custom Primitives Guide

Functions With Additional Arguments

One caveat with the make_primitive functions is that every required argument of the wrapped function must be an input feature. Here we create the function behind StringCount, a primitive that counts the occurrences of a string in a Text input. Because string is not a feature, it must be a keyword argument to string_count.

In [1]: def string_count(column, string=None):
   ...:     '''Count the number of times the value string occurs'''
   ...:     assert string is not None, "string to count needs to be defined"
   ...:     counts = [element.lower().count(string) for element in column]
   ...:     return counts
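
Because string_count is an ordinary Python function, it can be checked on a plain list before it is wrapped in a primitive. The sketch below repeats the function so that it runs on its own; the sample sentences are made up for illustration. It also shows two properties that follow from str.lower() and str.count(): only the column values are lower-cased, so the search string should itself be lower case, and matches are substrings, so "the" is also counted inside "there".

```python
def string_count(column, string=None):
    '''Count the number of times the value string occurs'''
    assert string is not None, "string to count needs to be defined"
    counts = [element.lower().count(string) for element in column]
    return counts

# Made-up sample values standing in for a Text column.
sample = ["The cat sat on the mat", "Is there more?", "Nothing here"]

# Lower-case search string: matches "The", "the", and the "the" in "there".
lower_counts = string_count(sample, string="the")
print(lower_counts)  # [2, 1, 0]

# Capitalized search string: never matches, because the text is lower-cased first.
upper_counts = string_count(sample, string="The")
print(upper_counts)  # [0, 0, 0]
```

If whole-word or case-preserving matching is needed, the function body would have to change; the primitive simply reports whatever this function returns.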

So that the names of features built with this primitive show which string is being counted, we define a custom generate_name function.

In [2]: def string_count_generate_name(self, base_feature_names):
   ...:     return u'STRING_COUNT(%s, "%s")' % (base_feature_names[0], self.kwargs['string'])
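
The naming function can be exercised without building a primitive, since it only reads self.kwargs and the list of base feature names. In this sketch, SimpleNamespace is a stand-in for a primitive instance (an assumption for illustration, not how featuretools constructs one), and the function is repeated so that the snippet runs on its own.

```python
from types import SimpleNamespace

def string_count_generate_name(self, base_feature_names):
    return u'STRING_COUNT(%s, "%s")' % (base_feature_names[0], self.kwargs['string'])

# Hypothetical stand-in for a StringCount instance: only the kwargs
# attribute that the naming function reads is provided.
fake_primitive = SimpleNamespace(kwargs={'string': 'the'})

name = string_count_generate_name(fake_primitive, ['comments'])
print(name)  # STRING_COUNT(comments, "the")
```

Including the keyword value in the name keeps two features that count different strings in the same column from sharing a name.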

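With both functions defined, we create the primitive using make_trans_primitive. The call below assumes that make_trans_primitive has been imported from featuretools.primitives, and Text and Numeric from featuretools.variable_types.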

In [3]: StringCount = make_trans_primitive(function=string_count,
   ...:                                    input_types=[Text],
   ...:                                    return_type=Numeric,
   ...:                                    cls_attributes={"generate_name": string_count_generate_name})

Passing string="test" as a keyword argument when initializing the StringCount primitive makes "test" the value of string whenever string_count is called to calculate feature values. Now we use this primitive to define features and calculate the feature matrix (ft refers to the featuretools package, imported as ft).
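
One way to picture how a keyword argument given at initialization reaches the function is a small wrapper class that stores the keyword arguments and forwards them on every call. The class below, KeywordBoundPrimitive, is hypothetical and written only for this illustration; it is not the featuretools implementation of make_trans_primitive.

```python
def string_count(column, string=None):
    '''Count the number of times the value string occurs'''
    assert string is not None, "string to count needs to be defined"
    return [element.lower().count(string) for element in column]

class KeywordBoundPrimitive:
    """Hypothetical stand-in that stores keyword arguments at initialization."""

    def __init__(self, function, **kwargs):
        self.function = function
        self.kwargs = kwargs  # for example {'string': 'test'}

    def __call__(self, *columns):
        # Input features arrive as positional columns; the stored keyword
        # arguments are passed along on every call.
        return self.function(*columns, **self.kwargs)

count_test = KeywordBoundPrimitive(string_count, string="test")
result = count_test(["A test of the test suite", "no match"])
print(result)  # [2, 0]
```

The stored kwargs dictionary in this sketch plays the same role as the self.kwargs that the custom generate_name function reads.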

In [4]: from featuretools.tests.testing_utils import make_ecommerce_entityset

In [5]: es = make_ecommerce_entityset()

In [6]: feature_matrix, features = ft.dfs(entityset=es,
   ...:                                   target_entity="sessions",
   ...:                                   agg_primitives=["sum", "mean", "std"],
   ...:                                   trans_primitives=[StringCount(string="the")])

In [7]: feature_matrix.columns
Out[7]:
Index(['device_name', 'customer_id', 'device_type', 'SUM(log.value)',
       'SUM(log.value_2)', 'SUM(log.value_many_nans)', 'MEAN(log.value)',
       'MEAN(log.value_2)', 'MEAN(log.value_many_nans)', 'STD(log.value)',
       'STD(log.value_2)', 'STD(log.value_many_nans)', 'customers.cohort',
       'customers.age', 'customers.région_id', 'customers.loves_ice_cream',
       'customers.cancel_reason', 'customers.engagement_level',
       'SUM(log.STRING_COUNT(comments, "the"))', 'SUM(log.products.rating)',
       'MEAN(log.STRING_COUNT(comments, "the"))', 'MEAN(log.products.rating)',
       'STD(log.STRING_COUNT(comments, "the"))', 'STD(log.products.rating)',
       'customers.SUM(log.value)', 'customers.SUM(log.value_2)',
       'customers.SUM(log.value_many_nans)', 'customers.MEAN(log.value)',
       'customers.MEAN(log.value_2)', 'customers.MEAN(log.value_many_nans)',
       'customers.STD(log.value)', 'customers.STD(log.value_2)',
       'customers.STD(log.value_many_nans)',
       'customers.STRING_COUNT(favorite_quote, "the")',
       'customers.cohorts.cohort_name', 'customers.régions.language'],
      dtype='object')

In [8]: feature_matrix[['STD(log.STRING_COUNT(comments, "the"))', 'SUM(log.STRING_COUNT(comments, "the"))', 'MEAN(log.STRING_COUNT(comments, "the"))']]
    STD(log.STRING_COUNT(comments, "the"))  SUM(log.STRING_COUNT(comments, "the"))  MEAN(log.STRING_COUNT(comments, "the"))
0                                47.124304                                     209                                    41.80
1                                36.509131                                     109                                    27.25
2                                      NaN                                      29                                    29.00
3                                49.497475                                      70                                    35.00
4                                 0.000000                                       0                                     0.00
5                                 1.414214                                       4                                     2.00
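
The NaN and 0.000000 entries in the STD column follow from how a sample standard deviation behaves; the values shown are consistent with the sample (n - 1) form. The sketch below uses Python's statistics module on made-up per-log counts chosen to reproduce the rows labelled 5, 2, and 4. The individual counts are assumptions, since the table shows only the aggregates. A session with a single log entry has no defined sample standard deviation, which is represented here as NaN.

```python
import math
import statistics

def summarize(counts):
    """Return (sum, mean, sample standard deviation) for a list of counts."""
    total = sum(counts)
    mean = total / len(counts)
    # The sample standard deviation needs at least two values.
    std = statistics.stdev(counts) if len(counts) > 1 else math.nan
    return total, mean, std

# Assumed per-log counts; only their aggregates appear in the table.
two_logs = summarize([1, 3])   # row 5: sum 4, mean 2.0, std about 1.414214
one_log = summarize([29])      # row 2: sum 29, mean 29.0, std NaN
all_zero = summarize([0, 0])   # row 4: sum 0, mean 0.0, std 0.0
```

A NaN in an STD feature therefore signals too few child rows rather than a failure of the primitive.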