NOTICE

The upcoming release of Featuretools 1.0.0 contains several breaking changes. Users are encouraged to test this version prior to release by installing from GitHub:

pip install https://github.com/alteryx/featuretools/archive/woodwork-integration.zip

For details on migrating to the new version, refer to the "Transitioning to Featuretools Version 1.0" guide. Please report any issues in the Featuretools GitHub repo or by posting a message in the Alteryx Open Source Slack.


Source code for featuretools.primitives.standard.transform_primitive

import warnings

import numpy as np
import pandas as pd
from woodwork.column_schema import ColumnSchema
from woodwork.logical_types import (
    URL,
    AgeFractional,
    Boolean,
    BooleanNullable,
    Categorical,
    Datetime,
    EmailAddress,
    LatLong,
    NaturalLanguage,
    Ordinal
)

from featuretools.primitives.base.transform_primitive_base import (
    TransformPrimitive
)
from featuretools.utils import convert_time_units
from featuretools.utils.common_tld_utils import COMMON_TLDS
from featuretools.utils.gen_utils import Library


class IsNull(TransformPrimitive):
    """Determines if a value is null.

    Each position of the input maps to ``True`` when the value is missing,
    as reported by the column's own ``isnull`` method, and ``False``
    otherwise.

    Examples:
        >>> is_null = IsNull()
        >>> is_null([1, None, 3]).tolist()
        [False, True, False]
    """
    name = "is_null"
    description_template = "whether {} is null"
    input_types = [ColumnSchema()]
    return_type = ColumnSchema(logical_type=Boolean)
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        # Delegate to the column's own null check so each supported
        # backend evaluates it natively.
        return lambda column: column.isnull()


class Absolute(TransformPrimitive):
    """Computes the absolute value of a number.

    Negative inputs become their positive counterpart; other values are
    returned unchanged.

    Examples:
        >>> absolute = Absolute()
        >>> absolute([3.0, -5.0, -2.4]).tolist()
        [3.0, 5.0, 2.4]
    """
    name = "absolute"
    description_template = "the absolute value of {}"
    input_types = [ColumnSchema(semantic_tags={'numeric'})]
    return_type = ColumnSchema(semantic_tags={'numeric'})
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        # The NumPy ufunc itself is the primitive function; it is applied
        # element-wise to the input column.
        absolute_value = np.absolute
        return absolute_value
class TimeSincePrevious(TransformPrimitive):
    """Compute the time since the previous entry in a list.

    Args:
        unit (str): Defines the unit of time to count from.
            Defaults to Seconds. Acceptable values:
            years, months, days, hours, minutes, seconds, milliseconds, nanoseconds

    Description:
        Given a list of datetimes, compute the time elapsed since the
        previous item in the list, converted into ``unit``. The result for
        the first item in the list will always be `NaN`.

    Examples:
        >>> from datetime import datetime
        >>> time_since_previous = TimeSincePrevious()
        >>> dates = [datetime(2019, 3, 1, 0, 0, 0),
        ...          datetime(2019, 3, 1, 0, 2, 0),
        ...          datetime(2019, 3, 1, 0, 3, 0),
        ...          datetime(2019, 3, 1, 0, 2, 30),
        ...          datetime(2019, 3, 1, 0, 10, 0)]
        >>> time_since_previous(dates).tolist()
        [nan, 120.0, 60.0, -30.0, 450.0]
    """
    name = "time_since_previous"
    description_template = "the time since the previous instance of {}"
    input_types = [ColumnSchema(logical_type=Datetime, semantic_tags={'time_index'})]
    return_type = ColumnSchema(semantic_tags={'numeric'})

    def __init__(self, unit="seconds"):
        # Stored lower-cased so the unit name is case-insensitive.
        self.unit = unit.lower()

    def get_function(self):
        def to_seconds(delta):
            return delta.total_seconds()

        def pd_diff(values):
            # Consecutive differences, first expressed in seconds and then
            # converted into the configured unit.
            gaps = values.diff()
            elapsed_seconds = gaps.apply(to_seconds)
            return convert_time_units(elapsed_seconds, self.unit)
        return pd_diff
class Day(TransformPrimitive):
    """Determines the day of the month from a datetime.

    The result is the calendar day number, from 1 through 31.

    Examples:
        >>> from datetime import datetime
        >>> dates = [datetime(2019, 3, 1),
        ...          datetime(2019, 3, 3),
        ...          datetime(2019, 3, 31)]
        >>> day = Day()
        >>> day(dates).tolist()
        [1, 3, 31]
    """
    name = "day"
    description_template = "the day of the month of {}"
    input_types = [ColumnSchema(logical_type=Datetime)]
    return_type = ColumnSchema(logical_type=Ordinal(order=list(range(1, 32))), semantic_tags={'category'})
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        return lambda vals: vals.dt.day
class Hour(TransformPrimitive):
    """Determines the hour value of a datetime.

    Hours are reported on a 24-hour clock, from 0 through 23.

    Examples:
        >>> from datetime import datetime
        >>> dates = [datetime(2019, 3, 1),
        ...          datetime(2019, 3, 3, 11, 10, 50),
        ...          datetime(2019, 3, 31, 19, 45, 15)]
        >>> hour = Hour()
        >>> hour(dates).tolist()
        [0, 11, 19]
    """
    name = "hour"
    description_template = 'the hour value of {}'
    input_types = [ColumnSchema(logical_type=Datetime)]
    return_type = ColumnSchema(logical_type=Ordinal(order=list(range(24))), semantic_tags={'category'})
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        return lambda vals: vals.dt.hour
class Second(TransformPrimitive):
    """Determines the seconds value of a datetime.

    Only the seconds component (0 through 59) is returned.

    Examples:
        >>> from datetime import datetime
        >>> dates = [datetime(2019, 3, 1),
        ...          datetime(2019, 3, 3, 11, 10, 50),
        ...          datetime(2019, 3, 31, 19, 45, 15)]
        >>> second = Second()
        >>> second(dates).tolist()
        [0, 50, 15]
    """
    name = "second"
    description_template = "the seconds value of {}"
    input_types = [ColumnSchema(logical_type=Datetime)]
    return_type = ColumnSchema(logical_type=Ordinal(order=list(range(60))), semantic_tags={'category'})
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        return lambda vals: vals.dt.second
class Minute(TransformPrimitive):
    """Determines the minutes value of a datetime.

    Only the minutes component (0 through 59) is returned.

    Examples:
        >>> from datetime import datetime
        >>> dates = [datetime(2019, 3, 1),
        ...          datetime(2019, 3, 3, 11, 10, 50),
        ...          datetime(2019, 3, 31, 19, 45, 15)]
        >>> minute = Minute()
        >>> minute(dates).tolist()
        [0, 10, 45]
    """
    name = "minute"
    description_template = "the minutes value of {}"
    input_types = [ColumnSchema(logical_type=Datetime)]
    return_type = ColumnSchema(logical_type=Ordinal(order=list(range(60))), semantic_tags={'category'})
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        return lambda vals: vals.dt.minute
class Week(TransformPrimitive):
    """Determines the week of the year from a datetime.

    Description:
        Returns the ISO 8601 week number of a datetime value. ISO weeks
        start on Monday, and week 1 is the week that contains the year's
        first Thursday. Dates in early January can therefore fall in week
        52 or 53 of the previous year, and dates in late December can fall
        in week 1 of the next year.

    Examples:
        >>> from datetime import datetime
        >>> dates = [datetime(2019, 1, 3),
        ...          datetime(2019, 6, 17, 11, 10, 50),
        ...          datetime(2019, 11, 30, 19, 45, 15)]
        >>> week = Week()
        >>> week(dates).tolist()
        [1, 25, 48]
    """
    name = "week"
    input_types = [ColumnSchema(logical_type=Datetime)]
    return_type = ColumnSchema(logical_type=Ordinal(order=list(range(1, 54))), semantic_tags={'category'})
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]
    description_template = "the week of the year of {}"

    def get_function(self):
        def week(vals):
            # Plain pandas input: ``Series.dt.week`` is deprecated (and removed
            # in pandas 2.0) in favour of ``dt.isocalendar().week``, which
            # emits no warning, so no warning filters need to be touched.
            if isinstance(vals, pd.Series) and hasattr(vals.dt, "isocalendar"):
                weeks = vals.dt.isocalendar().week
                # isocalendar yields a nullable UInt32 column; cast to the
                # int64 / float64 (when nulls are present) dtypes that
                # ``dt.week`` returned.
                if weeks.isna().any():
                    return weeks.astype("float64")
                return weeks.astype("int64")
            # Other backends (e.g. Dask, Koalas) keep the original accessor.
            # The filter stays process-wide because those backends may
            # evaluate lazily, after this function has returned.
            warnings.filterwarnings("ignore",
                                    message=("Series.dt.weekofyear and Series.dt.week "
                                             "have been deprecated."),
                                    module="featuretools"
                                    )
            return vals.dt.week
        return week
class Month(TransformPrimitive):
    """Determines the month value of a datetime.

    Months are numbered 1 (January) through 12 (December).

    Examples:
        >>> from datetime import datetime
        >>> dates = [datetime(2019, 3, 1),
        ...          datetime(2019, 6, 17, 11, 10, 50),
        ...          datetime(2019, 11, 30, 19, 45, 15)]
        >>> month = Month()
        >>> month(dates).tolist()
        [3, 6, 11]
    """
    name = "month"
    description_template = "the month of {}"
    input_types = [ColumnSchema(logical_type=Datetime)]
    return_type = ColumnSchema(logical_type=Ordinal(order=list(range(1, 13))), semantic_tags={'category'})
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        return lambda vals: vals.dt.month
class Year(TransformPrimitive):
    """Determines the year value of a datetime.

    The declared ordinal return type covers years 1 through 2999.

    Examples:
        >>> from datetime import datetime
        >>> dates = [datetime(2019, 3, 1),
        ...          datetime(2048, 6, 17, 11, 10, 50),
        ...          datetime(1950, 11, 30, 19, 45, 15)]
        >>> year = Year()
        >>> year(dates).tolist()
        [2019, 2048, 1950]
    """
    name = "year"
    description_template = "the year of {}"
    input_types = [ColumnSchema(logical_type=Datetime)]
    return_type = ColumnSchema(logical_type=Ordinal(order=list(range(1, 3000))), semantic_tags={'category'})
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        return lambda vals: vals.dt.year
class IsWeekend(TransformPrimitive):
    """Determines if a date falls on a weekend.

    Saturday and Sunday (weekday numbers 5 and 6, with Monday as 0) count
    as the weekend.

    Examples:
        >>> from datetime import datetime
        >>> dates = [datetime(2019, 3, 1),
        ...          datetime(2019, 6, 17, 11, 10, 50),
        ...          datetime(2019, 11, 30, 19, 45, 15)]
        >>> is_weekend = IsWeekend()
        >>> is_weekend(dates).tolist()
        [False, False, True]
    """
    name = "is_weekend"
    description_template = "whether {} occurred on a weekend"
    input_types = [ColumnSchema(logical_type=Datetime)]
    return_type = ColumnSchema(logical_type=BooleanNullable)
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        def is_weekend(vals):
            # NOTE(review): a null datetime most likely compares as False
            # here rather than producing a null, despite the BooleanNullable
            # return type — confirm whether that is intended.
            return vals.dt.weekday >= 5
        return is_weekend
class Weekday(TransformPrimitive):
    """Determines the day of the week from a datetime.

    Description:
        Returns the day of the week from a datetime value. Monday is
        day 0 and Sunday is day 6.

    Examples:
        >>> from datetime import datetime
        >>> dates = [datetime(2019, 3, 1),
        ...          datetime(2019, 6, 17, 11, 10, 50),
        ...          datetime(2019, 11, 30, 19, 45, 15)]
        >>> weekday = Weekday()
        >>> weekday(dates).tolist()
        [4, 0, 5]
    """
    name = "weekday"
    description_template = "the day of the week of {}"
    input_types = [ColumnSchema(logical_type=Datetime)]
    return_type = ColumnSchema(logical_type=Ordinal(order=list(range(7))), semantic_tags={'category'})
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        return lambda vals: vals.dt.weekday
class NumCharacters(TransformPrimitive):
    """Calculates the number of characters in a string.

    Missing values are treated as the empty string, so they count as 0.

    Examples:
        >>> num_characters = NumCharacters()
        >>> num_characters(['This is a string',
        ...                 'second item',
        ...                 'final1']).tolist()
        [16, 11, 6]
    """
    name = 'num_characters'
    description_template = "the number of characters in {}"
    input_types = [ColumnSchema(logical_type=NaturalLanguage)]
    return_type = ColumnSchema(semantic_tags={'numeric'})
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        def character_counter(array):
            filled = array.fillna('')
            return filled.str.len()
        return character_counter
class NumWords(TransformPrimitive):
    """Determines the number of words in a string by counting the spaces.

    The count is the number of space characters plus one. Missing values
    become the empty string and therefore count as one word.

    Examples:
        >>> num_words = NumWords()
        >>> num_words(['This is a string',
        ...            'Two words',
        ...            'no-spaces',
        ...            'Also works with sentences. Second sentence!']).tolist()
        [4, 2, 1, 6]
    """
    name = 'num_words'
    description_template = "the number of words in {}"
    input_types = [ColumnSchema(logical_type=NaturalLanguage)]
    return_type = ColumnSchema(semantic_tags={'numeric'})
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        def word_counter(array):
            space_counts = array.fillna('').str.count(' ')
            # n separating spaces delimit n + 1 words.
            return space_counts + 1
        return word_counter
class TimeSince(TransformPrimitive):
    """Calculates time from a value to a specified cutoff datetime.

    The result is ``cutoff - value``, so values after the cutoff give
    negative numbers.

    Args:
        unit (str): Defines the unit of time to count from.
            Defaults to Seconds. Acceptable values:
            years, months, days, hours, minutes, seconds, milliseconds, nanoseconds

    Examples:
        >>> from datetime import datetime
        >>> time_since = TimeSince()
        >>> times = [datetime(2019, 3, 1, 0, 0, 0, 1),
        ...          datetime(2019, 3, 1, 0, 0, 1, 0),
        ...          datetime(2019, 3, 1, 0, 2, 0, 0)]
        >>> cutoff_time = datetime(2019, 3, 1, 0, 0, 0, 0)
        >>> values = time_since(times, time=cutoff_time)
        >>> list(map(int, values))
        [0, -1, -120]

        Change output to nanoseconds

        >>> from datetime import datetime
        >>> time_since_nano = TimeSince(unit='nanoseconds')
        >>> times = [datetime(2019, 3, 1, 0, 0, 0, 1),
        ...          datetime(2019, 3, 1, 0, 0, 1, 0),
        ...          datetime(2019, 3, 1, 0, 2, 0, 0)]
        >>> cutoff_time = datetime(2019, 3, 1, 0, 0, 0, 0)
        >>> values = time_since_nano(times, time=cutoff_time)
        >>> list(map(lambda x: int(round(x)), values))
        [-1000, -1000000000, -120000000000]
    """
    name = 'time_since'
    description_template = "the time from {} to the cutoff time"
    input_types = [ColumnSchema(logical_type=Datetime)]
    return_type = ColumnSchema(semantic_tags={'numeric'})
    uses_calc_time = True
    compatibility = [Library.PANDAS, Library.DASK]

    def __init__(self, unit="seconds"):
        # Stored lower-cased so the unit name is case-insensitive.
        self.unit = unit.lower()

    def get_function(self):
        def pd_time_since(array, time):
            elapsed = time - array
            return convert_time_units(elapsed.dt.total_seconds(), self.unit)
        return pd_time_since
class IsIn(TransformPrimitive):
    """Determines whether a value is present in a provided list.

    When no list (or an empty one) is given, every value is reported as not
    present.

    Examples:
        >>> items = ['string', 10.3, False]
        >>> is_in = IsIn(list_of_outputs=items)
        >>> is_in(['string', 10.5, False]).tolist()
        [True, False, True]
    """
    name = "isin"
    input_types = [ColumnSchema()]
    return_type = ColumnSchema(logical_type=Boolean)
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def __init__(self, list_of_outputs=None):
        self.list_of_outputs = list_of_outputs
        if list_of_outputs:
            stringified_output_list = ', '.join(str(item) for item in list_of_outputs)
        else:
            stringified_output_list = '[]'
        self.description_template = "whether {{}} is in {}".format(stringified_output_list)

    def get_function(self):
        def pd_is_in(array):
            # A missing list behaves like an empty one: nothing matches.
            candidates = self.list_of_outputs or []
            return array.isin(candidates)
        return pd_is_in

    def generate_name(self, base_feature_names):
        base_name = base_feature_names[0]
        return u"%s.isin(%s)" % (base_name, str(self.list_of_outputs))
class Diff(TransformPrimitive):
    """Compute the difference between the value in a list and the
    previous value in that list.

    Description:
        Given a list of values, compute the difference from the previous
        item in the list. The result for the first element of the list will
        always be `NaN`. Although the declared input type is numeric, calling
        the primitive directly on datetimes yields timedeltas.

    Examples:
        >>> diff = Diff()
        >>> values = [1, 10, 3, 4, 15]
        >>> diff(values).tolist()
        [nan, 9.0, -7.0, 1.0, 11.0]
    """
    name = "diff"
    description_template = "the difference from the previous value of {}"
    input_types = [ColumnSchema(semantic_tags={'numeric'})]
    return_type = ColumnSchema(semantic_tags={'numeric'})
    uses_full_dataframe = True

    def get_function(self):
        def pd_diff(values):
            # Lag-one difference: each element minus its predecessor.
            return values.diff(periods=1)
        return pd_diff
class Negate(TransformPrimitive):
    """Negates a numeric value.

    Every value is multiplied by -1.

    Examples:
        >>> negate = Negate()
        >>> negate([1.0, 23.2, -7.0]).tolist()
        [-1.0, -23.2, 7.0]
    """
    name = "negate"
    description_template = "the negation of {}"
    input_types = [ColumnSchema(semantic_tags={'numeric'})]
    return_type = ColumnSchema(semantic_tags={'numeric'})
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        def negate(values):
            return values * -1
        return negate

    def generate_name(self, base_feature_names):
        base_name = base_feature_names[0]
        return "-(%s)" % (base_name)
class Not(TransformPrimitive):
    """Negates a boolean value.

    Accepts either a Boolean or a BooleanNullable column and applies a
    logical NOT to each element.

    Examples:
        >>> not_func = Not()
        >>> not_func([True, True, False]).tolist()
        [False, False, True]
    """
    name = "not"
    description_template = "the negation of {}"
    input_types = [[ColumnSchema(logical_type=Boolean)], [ColumnSchema(logical_type=BooleanNullable)]]
    return_type = ColumnSchema(logical_type=BooleanNullable)
    compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    def get_function(self):
        logical_not = np.logical_not
        return logical_not

    def generate_name(self, base_feature_names):
        base_name = base_feature_names[0]
        return u"NOT({})".format(base_name)
class Percentile(TransformPrimitive):
    """Determines the percentile rank for each value in a list.

    Ranks are scaled into fractions of the number of non-null values.

    Examples:
        >>> percentile = Percentile()
        >>> percentile([10, 15, 1, 20]).tolist()
        [0.5, 0.75, 0.25, 1.0]

        Nan values are ignored when determining rank

        >>> percentile([10, 15, 1, None, 20]).tolist()
        [0.5, 0.75, 0.25, nan, 1.0]
    """
    name = 'percentile'
    description_template = "the percentile rank of {}"
    input_types = [ColumnSchema(semantic_tags={'numeric'})]
    return_type = ColumnSchema(semantic_tags={'numeric'})
    uses_full_dataframe = True

    def get_function(self):
        def percentile(array):
            # pct=True turns ordinal ranks into fractional ranks; null inputs
            # keep a null rank (see the second example above).
            return array.rank(pct=True)
        return percentile
class Latitude(TransformPrimitive):
    """Returns the first tuple value in a list of LatLong tuples.
    For use with the LatLong logical type.

    Entries that are not tuples map to NaN.

    Examples:
        >>> latitude = Latitude()
        >>> latitude([(42.4, -71.1),
        ...            (40.0, -122.4),
        ...            (41.2, -96.75)]).tolist()
        [42.4, 40.0, 41.2]
    """
    name = 'latitude'
    description_template = "the latitude of {}"
    input_types = [ColumnSchema(logical_type=LatLong)]
    return_type = ColumnSchema(semantic_tags={'numeric'})

    def get_function(self):
        def latitude(latlong):
            def first_coordinate(point):
                if isinstance(point, tuple):
                    return point[0]
                return np.nan
            return latlong.map(first_coordinate)
        return latitude
class Longitude(TransformPrimitive):
    """Returns the second tuple value in a list of LatLong tuples.
    For use with the LatLong logical type.

    Entries that are not tuples map to NaN.

    Examples:
        >>> longitude = Longitude()
        >>> longitude([(42.4, -71.1),
        ...            (40.0, -122.4),
        ...            (41.2, -96.75)]).tolist()
        [-71.1, -122.4, -96.75]
    """
    name = 'longitude'
    description_template = "the longitude of {}"
    input_types = [ColumnSchema(logical_type=LatLong)]
    return_type = ColumnSchema(semantic_tags={'numeric'})

    def get_function(self):
        def longitude(latlong):
            def second_coordinate(point):
                if isinstance(point, tuple):
                    return point[1]
                return np.nan
            return latlong.map(second_coordinate)
        return longitude
class Haversine(TransformPrimitive):
    """Calculates the approximate haversine distance between two LatLong
    columns.

    Coordinates that are not tuples are treated as NaN, so the matching
    distance is NaN.

    Args:
        unit (str): Determines the unit value to output. Could
            be `miles` or `kilometers`. Default is `miles`.

    Examples:
        >>> haversine = Haversine()
        >>> distances = haversine([(42.4, -71.1), (40.0, -122.4)],
        ...                       [(40.0, -122.4), (41.2, -96.75)])
        >>> np.round(distances, 3).tolist()
        [2631.231, 1343.289]

        Output units can be specified

        >>> haversine_km = Haversine(unit='kilometers')
        >>> distances_km = haversine_km([(42.4, -71.1), (40.0, -122.4)],
        ...                             [(40.0, -122.4), (41.2, -96.75)])
        >>> np.round(distances_km, 3).tolist()
        [4234.555, 2161.814]
    """
    name = 'haversine'
    input_types = [ColumnSchema(logical_type=LatLong), ColumnSchema(logical_type=LatLong)]
    return_type = ColumnSchema(semantic_tags={'numeric'})
    commutative = True

    def __init__(self, unit='miles'):
        valid_units = ['miles', 'kilometers']
        if unit not in valid_units:
            raise ValueError('Invalid unit %s provided. Must be one of %s' % (unit, valid_units))
        self.unit = unit
        self.description_template = "the haversine distance in {} between {{}} and {{}}".format(self.unit)

    def get_function(self):
        def coordinate(points, position):
            # Pull one component out of each (lat, lon) tuple, NaN otherwise.
            return np.array([p[position] if isinstance(p, tuple) else np.nan for p in points])

        def haversine(latlong1, latlong2):
            lat1 = np.radians(coordinate(latlong1, 0))
            lon1 = np.radians(coordinate(latlong1, 1))
            lat2 = np.radians(coordinate(latlong2, 0))
            lon2 = np.radians(coordinate(latlong2, 1))
            half_dlat = (lat2 - lat1) / 2.0
            half_dlon = (lon2 - lon1) / 2.0
            a = np.sin(half_dlat) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(half_dlon) ** 2
            # Mean Earth radius in the requested unit.
            radius_earth = 6371.0088 if self.unit == 'kilometers' else 3958.7613
            return radius_earth * 2 * np.arcsin(np.sqrt(a))
        return haversine

    def generate_name(self, base_feature_names):
        parts = [u"{}(".format(self.name.upper()), u", ".join(base_feature_names)]
        # The default unit is left implicit in the feature name.
        if self.unit != 'miles':
            parts.append(u", unit={}".format(self.unit))
        parts.append(u")")
        return "".join(parts)
class Age(TransformPrimitive):
    """Calculates the age in years as a floating point number given a
    date of birth.

    Description:
        Age in years is computed by calculating the number of days between
        the date of birth and the reference time and dividing the result
        by 365.

    Examples:
        Determine the age of three people as of Jan 1, 2019

        >>> import pandas as pd
        >>> reference_date = pd.to_datetime("01-01-2019")
        >>> age = Age()
        >>> input_ages = [pd.to_datetime("01-01-2000"),
        ...               pd.to_datetime("05-30-1983"),
        ...               pd.to_datetime("10-17-1997")]
        >>> age(input_ages, time=reference_date).tolist()
        [19.013698630136986, 35.61643835616438, 21.221917808219178]
    """
    name = "age"
    input_types = [ColumnSchema(logical_type=Datetime, semantic_tags={'date_of_birth'})]
    return_type = ColumnSchema(logical_type=AgeFractional, semantic_tags={'numeric'})
    uses_calc_time = True
    compatibility = [Library.PANDAS, Library.DASK]
    description_template = "the age from {}"

    def get_function(self):
        def age(x, time=None):
            # Whole elapsed days divided by 365; leap days are not accounted for.
            return (time - x).dt.days / 365
        return age


# Captures the host part of a URL: an optional http(s) scheme, optional
# user-info ("user@") and a leading "www." are skipped, and the match stops
# at the first ':', '/', '?' or newline. Shared by URLToDomain and URLToTLD.
_URL_DOMAIN_PATTERN = r'^(?:https?:\/\/)?(?:[^@\/\n]+@)?(?:www\.)?([^:\/?\n]+)'


class URLToDomain(TransformPrimitive):
    """Determines the domain of a url.

    Description:
        Calculates the label to identify the network domain of a URL. Supports
        urls with or without protocol as well as international country domains.

    Examples:
        >>> url_to_domain = URLToDomain()
        >>> urls =  ['https://play.google.com',
        ...          'http://www.google.co.in',
        ...          'www.facebook.com']
        >>> url_to_domain(urls).tolist()
        ['play.google.com', 'google.co.in', 'facebook.com']
    """
    name = "url_to_domain"
    input_types = [ColumnSchema(logical_type=URL)]
    return_type = ColumnSchema(logical_type=Categorical, semantic_tags={'category'})

    def get_function(self):
        def url_to_domain(x):
            return x.str.extract(_URL_DOMAIN_PATTERN, expand=False)
        return url_to_domain


class URLToProtocol(TransformPrimitive):
    """Determines the protocol (http or https) of a url.

    Description:
        Extract the protocol of a url using regex.
        It will be either https or http.
        Returns nan if the url doesn't contain a protocol.

    Examples:
        >>> url_to_protocol = URLToProtocol()
        >>> urls =  ['https://play.google.com',
        ...          'http://www.google.co.in',
        ...          'www.facebook.com']
        >>> url_to_protocol(urls).to_list()
        ['https', 'http', nan]
    """
    name = "url_to_protocol"
    input_types = [ColumnSchema(logical_type=URL)]
    return_type = ColumnSchema(logical_type=Categorical, semantic_tags={'category'})

    def get_function(self):
        def url_to_protocol(x):
            p = r'^(https|http)(?:\:)'
            return x.str.extract(p, expand=False)
        return url_to_protocol


class URLToTLD(TransformPrimitive):
    """Determines the top level domain of a url.

    Description:
        Extract the top level domain of a url,
        using regex, and a list of common top level domains.
        Returns nan if the url is invalid or null.
        Common top level domains were pulled from this list:
        https://www.hayksaakian.com/most-popular-tlds/

    Examples:
        >>> url_to_tld = URLToTLD()
        >>> urls = ['https://www.google.com', 'http://www.google.co.in',
        ...         'www.facebook.com']
        >>> url_to_tld(urls).to_list()
        ['com', 'in', 'com']
    """
    name = "url_to_tld"
    input_types = [ColumnSchema(logical_type=URL)]
    return_type = ColumnSchema(logical_type=Categorical, semantic_tags={'category'})

    def get_function(self):
        # Kept as an instance attribute, as before, in case callers inspect it.
        self.tlds_pattern = r'(?:\.({}))'.format('|'.join(COMMON_TLDS))

        def url_to_tld(x):
            domains = x.str.extract(_URL_DOMAIN_PATTERN, expand=False)
            # Every ".<tld>" occurrence per row; the last one wins
            # (e.g. "google.co.in" -> "in").
            df = domains.str.extractall(self.tlds_pattern)
            matches = df.groupby(level=0).last()[0]
            # Rows without any match are absent from extractall's result;
            # reindexing restores them as NaN.
            return matches.reindex(x.index)
        return url_to_tld


def _extract_email_domains(emails):
    """Return the domain part of each email in the Series ``emails``.

    The domain is the second '@'-separated field of the stripped value.
    Null values, non-string values and values without an '@' map to NaN.
    """
    # An all-null column may not be string-typed, so the .str accessor
    # cannot be used on it; short-circuit to an all-NaN object column.
    if emails.isnull().all():
        return pd.Series(np.nan, index=emails.index, dtype=object)
    # ``.str[1]`` yields NaN for values with no '@'. The previous
    # ``split(expand=True)[1]`` raised KeyError when no value had one.
    return emails.str.strip().str.split('@').str[1]


class IsFreeEmailDomain(TransformPrimitive):
    """Determines if an email address is from a free email domain.

    Description:
        EmailAddress input should be a string. Will return Nan
        if an invalid email address is provided, or if the input is
        not a string. The list of free email domains used in this primitive
        was obtained from https://github.com/willwhite/freemail/blob/master/data/free.txt.

    Examples:
        >>> is_free_email_domain = IsFreeEmailDomain()
        >>> is_free_email_domain(['name@gmail.com', 'name@featuretools.com']).tolist()
        [True, False]
    """
    name = "is_free_email_domain"
    input_types = [ColumnSchema(logical_type=EmailAddress)]
    return_type = ColumnSchema(logical_type=BooleanNullable)

    filename = "free_email_provider_domains.txt"

    def get_function(self):
        file_path = self.get_filepath(self.filename)

        free_domains = pd.read_csv(file_path, header=None, names=['domain'])
        free_domains['domain'] = free_domains.domain.str.strip()

        def is_free_email_domain(emails):
            # if the input is empty return an empty Series (explicit dtype
            # avoids pandas' "default dtype for empty Series" warning)
            if len(emails) == 0:
                return pd.Series([], dtype=object)

            emails_df = pd.DataFrame({'email': emails})
            emails_df['domain'] = _extract_email_domains(emails_df['email'])
            emails_df['is_free'] = emails_df['domain'].isin(free_domains['domain'])

            # if there are any NaN domain values, change the series type to allow for
            # both bools and NaN values and set is_free to NaN for the NaN domains
            # (builtin ``object``: the ``np.object`` alias was removed in NumPy 1.24)
            if emails_df['domain'].isnull().values.any():
                emails_df['is_free'] = emails_df['is_free'].astype(object)
                emails_df.loc[emails_df['domain'].isnull(), 'is_free'] = np.nan
            return emails_df.is_free.values
        return is_free_email_domain


class EmailAddressToDomain(TransformPrimitive):
    """Determines the domain of an email

    Description:
        EmailAddress input should be a string. Will return Nan
        if an invalid email address is provided, or if the input is
        not a string.

    Examples:
        >>> email_address_to_domain = EmailAddressToDomain()
        >>> email_address_to_domain(['name@gmail.com', 'name@featuretools.com']).tolist()
        ['gmail.com', 'featuretools.com']
    """
    name = "email_address_to_domain"
    input_types = [ColumnSchema(logical_type=EmailAddress)]
    return_type = ColumnSchema(logical_type=Categorical, semantic_tags={'category'})

    def get_function(self):
        def email_address_to_domain(emails):
            # if the input is empty return an empty Series (explicit dtype
            # avoids pandas' "default dtype for empty Series" warning)
            if len(emails) == 0:
                return pd.Series([], dtype=object)

            emails_df = pd.DataFrame({'email': emails})
            emails_df['domain'] = _extract_email_domains(emails_df['email'])
            return emails_df.domain.values
        return email_address_to_domain