from featuretools.primitives import TransformPrimitive
from featuretools.utils.gen_utils import import_or_raise
from woodwork.column_schema import ColumnSchema
from woodwork.logical_types import Double, NaturalLanguage
class UniversalSentenceEncoder(TransformPrimitive):
    """Transforms a sentence or short paragraph to a vector using the
    [tfhub model](https://tfhub.dev/google/universal-sentence-encoder/2).

    Every input string is embedded into ``number_output_features`` (512)
    values. The primitive function returns the embeddings transposed, so the
    result is indexed by output feature first and by input row second.

    Args:
        None

    Examples:
        >>> sentences = ["I like to eat pizza", "The roller coaster was built in 1885.", ""]
        >>> # universal_sentence_encoder = UniversalSentenceEncoder() # normal syntax
        >>> output = universal_sentence_encoder(sentences) # defined in test file
        >>> len(output)
        512
        >>> len(output[0])
        3
        >>> values = output[:3, 0]
        >>> [round(x, 4) for x in values]
        [0.0178, 0.0616, -0.0089]
    """

    name = "universal_sentence_encoder"
    input_types = [ColumnSchema(logical_type=NaturalLanguage)]
    return_type = ColumnSchema(logical_type=Double, semantic_tags={"numeric"})

    def __init__(self):
        """Load the optional TensorFlow dependencies and the encoder module.

        ``tensorflow`` and ``tensorflow_hub`` are resolved through
        ``import_or_raise`` together with an install hint, so they are only
        needed when this primitive is actually instantiated.
        """
        message = "In order to use the UniversalSentenceEncoder primitive install 'nlp_primitives[complete]'"
        self.tf = import_or_raise("tensorflow", message)
        hub = import_or_raise("tensorflow_hub", message)

        # hub.Module is evaluated through a tf.compat.v1.Session below, so
        # eager execution is switched off before the module is created.
        # NOTE(review): this toggles a TensorFlow-wide setting, not just this
        # instance -- confirm that is acceptable for other TF code in-process.
        self.tf.compat.v1.disable_eager_execution()

        self.module_url = "https://tfhub.dev/google/universal-sentence-encoder/2"
        self.embed = hub.Module(self.module_url)

        # One output value per embedding dimension.
        self.number_output_features = 512
        # Kept as a separate attribute for backward compatibility; always
        # equal to number_output_features.
        self.n = self.number_output_features

    def get_function(self):
        """Return the callable that embeds a column of strings.

        Returns:
            A function taking a column (anything exposing ``tolist()``) and
            returning the session output transposed via ``.transpose()``.
        """

        def universal_sentence_encoder(col):
            # A fresh session is opened for every call, so variables and
            # lookup tables are (re)initialised before running the module.
            with self.tf.compat.v1.Session() as session:
                session.run(
                    [
                        self.tf.compat.v1.global_variables_initializer(),
                        self.tf.compat.v1.tables_initializer(),
                    ]
                )
                # NOTE(review): self.embed(...) is invoked on every call,
                # which presumably adds new ops to the graph each time --
                # confirm, and consider building the op once if this
                # function is called repeatedly.
                embeddings = session.run(self.embed(col.tolist()))
            return embeddings.transpose()

        return universal_sentence_encoder