# Source code for featuretools.primitives.standard.aggregation.count_above_mean
import numpy as np
import pandas as pd
from woodwork.column_schema import ColumnSchema
from woodwork.logical_types import IntegerNullable

from featuretools.primitives.base.aggregation_primitive_base import AggregationPrimitive
class CountAboveMean(AggregationPrimitive):
    """Calculates the number of values that are above the mean.

    Args:
        skipna (bool): Determines whether NA/null values are ignored when
            computing the mean. Defaults to True, which skips NA/null values.
            When False, a NA/null value in the input makes the mean missing
            and the result is NaN.

    Examples:
        >>> count_above_mean = CountAboveMean()
        >>> count_above_mean([1, 2, 3, 4, 5])
        2

        The way NaNs are treated can be controlled.

        >>> count_above_mean_skipna = CountAboveMean(skipna=False)
        >>> count_above_mean_skipna([1, 2, 3, 4, 5, None])
        nan
    """

    name = "count_above_mean"
    input_types = [ColumnSchema(semantic_tags={"numeric"})]
    return_type = ColumnSchema(logical_type=IntegerNullable, semantic_tags={"numeric"})
    stack_on_self = False

    def __init__(self, skipna=True):
        self.skipna = skipna

    def get_function(self):
        """Return the aggregation function used by this primitive.

        Returns:
            callable: A function ``count_above_mean(x)`` that returns the number
            of values in ``x`` strictly greater than the mean of ``x``, as a
            Python ``int``. It returns ``np.nan`` when the mean is missing,
            for example for empty input, all-missing input, or a missing
            value with ``skipna=False``.
        """

        def count_above_mean(x):
            # x.mean(skipna=...) assumes x is a pandas Series.
            # self.skipna is read on each call, not captured once here.
            mean = x.mean(skipna=self.skipna)
            # pd.isna (rather than np.isnan) also recognises pd.NA, which
            # nullable dtypes such as Int64 return from mean(). np.isnan(pd.NA)
            # yields pd.NA, and using that in an `if` raises TypeError.
            if pd.isna(mean):
                return np.nan
            # Comparisons against missing values are False/NA and are not
            # counted by sum(). int() keeps the result a plain Python int.
            return int((x > mean).sum())

        return count_above_mean