"""This module contains various helper functions that can be used when creating
custom FactoryTX components.
"""
from typing import (
Any, Callable, Dict, Iterable, Iterator, List, Set, Tuple, TypeVar, Union,
Type, Generator
)
from base64 import urlsafe_b64encode
from collections import defaultdict
import contextlib
from datetime import datetime, date
import enum
import io
import os
import re
import tempfile
import time
import uuid
from pandas import DataFrame
import pandas as pd
import pytz
import simplejson
from factorytx import const
from factorytx import markers
# A precise recursive JSON type is disabled until mypy adds support for
# recursive types:
#Json = Union[Dict[str, 'Json'], List['Json'], str, int, float, bool]
# Type alias for a JSON object: a dict with string keys and arbitrary values.
JsonDict = Dict[str, Any]
def _json_default(o: Any) -> Any:
    """Fallback serializer for values the JSON encoder can't handle natively.

    * `datetime` values are formatted as `YYYY-MM-DD HH:MM:SS.ffffff`.
    * `date` values are formatted as `YYYY-MM-DD`.
    * `enum.Enum` members are replaced by their `value`.

    Raises `TypeError` for any other type.
    """
    # `datetime` is a subclass of `date`, so it has to be tested first.
    if isinstance(o, datetime):
        return o.strftime('%Y-%m-%d %H:%M:%S.%f')
    if isinstance(o, date):
        return o.strftime('%Y-%m-%d')
    if isinstance(o, enum.Enum):
        return o.value
    raise TypeError(f"{o!r} is not JSON serializable.")
# Shared encoder instance used by `jsonencode`. Values that simplejson cannot
# serialize on its own are handed to `_json_default`.
# NOTE(review): per simplejson's documentation `use_decimal` encodes Decimal
# values natively and `ignore_nan` emits NaN/Infinity as null -- confirm
# against the installed simplejson version.
ENCODER = simplejson.JSONEncoder(default=_json_default, use_decimal=True, ignore_nan=True)
# Shared decoder instance used by `jsondecode`, built with simplejson's defaults.
DECODER = simplejson.JSONDecoder()
def jsonencode(obj: Any) -> str:
    """Serializes a tree of objects to a JSON string using the module-level
    `ENCODER`. Dict key order is preserved.

    >>> jsonencode({"a": 1, "c": 2, "b": date(2018, 1, 1)})
    '{"a": 1, "c": 2, "b": "2018-01-01"}'

    """
    encoded = ENCODER.encode(obj)
    return encoded
def jsondecode(data: Union[bytes, str]) -> Any:
    """Converts JSON to a tree of objects. Object key order is preserved.

    :param data: the JSON document, given either as text or as encoded bytes.
    :return: the decoded Python value.

    >>> jsondecode('{"a": 1, "c": 2, "b": "2018-01-01"}')
    {'a': 1, 'c': 2, 'b': '2018-01-01'}

    """
    # The annotation accepts both `str` and `bytes`: the example above passes
    # text, while other callers hand over raw bytes.
    return DECODER.decode(data)  # type: ignore  # typeshed thinks data must be a string.
def make_guid() -> bytes:
    """Returns a short globally-unique identifier as ASCII bytes.

    The identifier is a random UUID4 encoded as URL-safe base64 with the `=`
    padding characters stripped.
    """
    encoded = urlsafe_b64encode(uuid.uuid4().bytes)
    return encoded.strip(b'=')
def overwrite_atomic(path: str, data: bytes, mode: int = 0o600) -> None:
    """Atomically replaces the file at `path` with a new one containing `data`.

    If power is lost during the write or the system crashes then the file will
    contain either the old contents, or the new contents. It will never contain
    corrupt or truncated data.

    The data is first written to a temporary file in the same directory as
    `path` (so the final move never crosses a filesystem boundary), flushed to
    disk, given the permissions `mode`, and then moved over `path`.

    :param path: location of the file to create or replace.
    :param data: the new contents of the file.
    :param mode: permission bits applied to the new file.
    :raises OSError: if the temporary file can't be created, written or moved.
        The temporary file is removed before the error propagates.
    """
    dirname = os.path.dirname(path)
    tmp = tempfile.NamedTemporaryFile(delete=False, dir=dirname, suffix='.tmp')
    filename: str = tmp.name
    try:
        with contextlib.closing(tmp):
            fp: io.BytesIO = tmp.file  # type: ignore  # tempfile stubs are incomplete.
            fp.write(data)
            fp.flush()
            # Force the contents to disk before the file becomes visible at `path`.
            os.fsync(fp.fileno())
        os.chmod(filename, mode)
        # Unlike os.rename, os.replace also overwrites an existing destination
        # on Windows.
        os.replace(filename, path)
    except BaseException:
        # The temporary file was created with delete=False, so it has to be
        # cleaned up by hand or a failed write would leave a stray `.tmp` file.
        with contextlib.suppress(OSError):
            os.unlink(filename)
        raise
def ensure_dir_exists(dirname: str, exist_ok: bool = True) -> None:
    """Creates the directory `dirname` (and any missing parents) if it doesn't
    already exist.

    `exist_ok` is forwarded to `os.makedirs`. Any `FileExistsError` raised by
    that call is suppressed, so this function does not raise when `dirname` is
    already present, whatever the value of `exist_ok`.
    """
    with contextlib.suppress(FileExistsError):
        os.makedirs(dirname, exist_ok=exist_ok)
# Every character the current platform treats as a path separator.
_PATH_SEPARATORS = os.sep + (os.altsep or '')


def safe_path_join(base: str, *parts: str) -> str:
    """Builds a path by appending `parts` to `base`. In contrast to
    `os.path.join`, a component that starts with a separator does not throw
    away the components that precede it. Components that are empty once their
    leading separators are stripped are skipped.

    >>> os.path.join('/a', 'b', '/c')
    '/c'
    >>> safe_path_join('/a', 'b', '/c')
    '/a/b/c'

    """
    stripped = (part.lstrip(_PATH_SEPARATORS) for part in parts)
    return os.sep.join([base, *filter(None, stripped)])
# TODO: Move this into dataflow.py?
def add_columns_if_absent(frame: DataFrame, values: Dict[str, Any]) -> None:
    """Adds each column named in `values` to `frame`, in place, unless the
    frame already has a column with that name. Existing columns are left
    untouched.

    >>> df = DataFrame({"a": [1, 2, 3], "d": [3, 5, 8]})
    >>> add_columns_if_absent(df, {"a": 0, "b": 1, "c": 2})
    >>> df
       a  d  b  c
    0  1  3  1  2
    1  2  5  1  2
    2  3  8  1  2

    """
    existing = set(frame.columns)
    missing = {name: fill for name, fill in values.items() if name not in existing}
    for name, fill in missing.items():
        frame[name] = fill
# TODO: Move this into paths.py?
def get_component_path(component_name: str) -> str:
    """Returns the filesystem location that holds data for the named
    component: `component_name` joined onto `const.COMPONENT_DATA_PATH`.
    """
    return safe_path_join(const.COMPONENT_DATA_PATH, component_name)
def group_frame_by_stream(frame: DataFrame) -> List[Tuple[Tuple[str, str], DataFrame]]:
    """Splits a DataFrame into a list of per-stream frames.

    Returns data in the format `[((asset, stream type), frame), ...]`.
    """
    asset_col = const.ASSET_COLNAME
    stream_type_col = const.STREAM_TYPE_COLNAME

    # A groupby on a 100 column DataFrame is 10x slower than finding the
    # set of unique assets and stream types. Few frames will have more than one
    # stream, so we can speed things up by avoiding redundant groupbys.

    # Fast path 1: a single row can only belong to a single stream.
    if frame.shape[0] == 1:
        label = frame.index[0]
        key = (frame.at[label, asset_col], frame.at[label, stream_type_col])
        return [(key, frame)]

    unique_assets = frame[asset_col].unique()
    unique_stream_types = frame[stream_type_col].unique()

    # Slow path: more (or fewer) than one distinct asset or stream type.
    if unique_assets.size != 1 or unique_stream_types.size != 1:
        return list(frame.groupby(by=[asset_col, stream_type_col]))

    # Fast path 2: every row shares the same asset and stream type.
    return [((unique_assets[0], unique_stream_types[0]), frame)]
A = TypeVar('A')
B = TypeVar('B')


def grouped(values: Iterable[A], key: Callable[[A], B]) -> Iterable[Tuple[B, Iterable[A]]]:
    """Collects `values` into groups that share the same `key(value)`.

    Groups are returned in the order in which their key first appears, and
    within each group the ordering of the original argument is preserved. This
    function is more analogous to the GROUP BY operator in SQL than
    `itertools.groupby`.

    >>> grouped([1.0, 1.3, 1.9, 0.9, 2.1, 1.1], key=round)
    [(1, [1.0, 1.3, 0.9, 1.1]), (2, [1.9, 2.1])]

    """
    buckets: Dict[B, List[A]] = {}
    for item in values:
        buckets.setdefault(key(item), []).append(item)
    return list(buckets.items())
def sanitize_column_names(frame: DataFrame) -> DataFrame:
    """Sanitizes column names into valid mongo field names: a leading `$` and
    every `.` are replaced with `_`.

    A column whose sanitized name would clash with a name that is already in
    use is dropped, and a warning marker is emitted for it. The result is
    returned as a new DataFrame object.
    """
    taken: Set[str] = set(frame.columns)  # names in use so far
    to_drop: List[str] = []               # [column name, ...]
    to_rename: Dict[str, str] = {}        # {from -> to, ...} mapping

    for original in list(frame.columns):
        candidate = '_' + original[1:] if original.startswith('$') else original
        candidate = candidate.replace('.', '_')
        if candidate == original:
            continue  # Already a valid name.

        if candidate in taken:
            msg = (f'Dropped the column named "{original}" after sanitizing '
                   f'it since it would conflict with "{candidate}".')
            markers.warning(f'columns.{original}.sanitize', msg)
            to_drop.append(original)
        else:
            to_rename[original] = candidate
            taken.add(candidate)
        # Either way the old name is no longer in use.
        taken.remove(original)

    sanitized_frame = frame.rename(columns=to_rename, copy=False)
    sanitized_frame.drop(to_drop, axis=1, inplace=True)
    return sanitized_frame
# Timezone-aware (UTC) datetime for 1970-01-01 00:00:00, the reference point
# that `to_unix_epoch` measures from.
UNIX_EPOCH_BASE = pytz.utc.localize(datetime(1970, 1, 1))
def to_unix_epoch(dt: datetime) -> float:
    """Converts a timezone-aware datetime to a unix epoch value, i.e. the
    number of seconds between `UNIX_EPOCH_BASE` and `dt`.

    >>> tz = pytz.timezone('America/Los_Angeles')
    >>> dt = tz.localize(datetime(2018, 10, 19, 10, 38, 2, 208000))
    >>> to_unix_epoch(dt)
    1539970682.208

    """
    elapsed = dt - UNIX_EPOCH_BASE
    return elapsed.total_seconds()
# Sentinel marking "the iterator is exhausted"; compared by identity only.
NOTHING = object()


def lookahead(iterable: Iterable[A], default: Union[A, B] = None) -> Iterator[Tuple[A, Union[A, B]]]:
    """Yields a pair `(elt, next_elt)` for every `elt` in `iterable`.

    `next_elt` is the element that follows `elt` in the iterable, or `default`
    when `elt` is the last element.

    #>>> list(lookahead([1, 2, 3]))
    [(1, 2), (2, 3), (3, None)]
    #>>> list(lookahead([1, 2], default=0))
    [(1, 2), (2, 0)]
    #>>> list(lookahead([]))
    []

    """
    iterator = iter(iterable)
    current = next(iterator, NOTHING)
    while current is not NOTHING:
        # Fetch one element ahead so it can be paired with the current one.
        upcoming = next(iterator, NOTHING)
        if upcoming is NOTHING:
            yield current, default  # type: ignore
        else:
            yield current, upcoming  # type: ignore
        current = upcoming
# Matches each run of characters other than ASCII letters, digits, `.`, `_`
# and `-`.
_INVALID_KAFKA_TOPIC_CHARS = re.compile('[^a-zA-Z0-9._-]+')
_TOPIC_DELIMITER = "."


def kafka_topic_with_suffix(tenant: str, suffix: str) -> str:
    """Builds a Kafka topic name of the form `<tenant>.<suffix>`.

    Every run of characters in `suffix` matched by
    `_INVALID_KAFKA_TOPIC_CHARS` is collapsed into a single `.` first;
    `tenant` is used as given.

    :param tenant: initial namespace for topic,
    :param suffix: a string that will serve as the topic suffix
    :return: topic name
    """
    cleaned_suffix = _INVALID_KAFKA_TOPIC_CHARS.sub(_TOPIC_DELIMITER, suffix)
    return _TOPIC_DELIMITER.join((tenant, cleaned_suffix))
def create_kafka_topic(row: Dict, topic_prefix: str, topic_suffix_field: Union[str, None], default_topic: str) -> str:
    """Picks the Kafka topic for `row`.

    Returns `default_topic` when no suffix field is configured, or when the
    row's value for that field is missing, falsy (e.g. `''` or `0`) or NA.
    Otherwise the topic is built from `topic_prefix` and the stringified
    field value via `kafka_topic_with_suffix`.

    :param row: mapping of field names to values.
    :param topic_prefix: prefix (namespace) of the generated topic.
    :param topic_suffix_field: name of the field holding the topic suffix.
    :param default_topic: topic to fall back to.
    :return: topic name
    """
    if topic_suffix_field:
        row_value = row.get(topic_suffix_field)
        if row_value and not pd.isna(row_value):
            return kafka_topic_with_suffix(topic_prefix, str(row_value))
    return default_topic
def get_device_uuid() -> str:
    """Returns the value of the `DEVICE_UUID` environment variable.

    Raises `KeyError` if the variable is not set.
    """
    device_uuid = os.environ['DEVICE_UUID']
    return device_uuid
def class_name(typ: Type) -> str:
    """
    Given a Class/Type, returns its dotted reference `<module>.<name>`.

    The module is read from the `__module__` entry of the type's own
    `__dict__`; when there is no such entry only the bare `__name__` is
    returned.
    """
    candidates = (typ.__dict__.get('__module__'), typ.__name__)
    return '.'.join([part for part in candidates if part is not None])
def chunk_iter(iterable: Union[list, str], chunk_len: int) -> Generator[Any, None, None]:
    """Yields consecutive slices of `iterable`, each `chunk_len` items long.

    The last slice is shorter when the length of `iterable` isn't a multiple
    of `chunk_len`.
    """
    starts = range(0, len(iterable), chunk_len)
    yield from (iterable[start:start + chunk_len] for start in starts)
class Timer:
    """Context manager that measures how long its `with` block runs.

    On exit the elapsed time in seconds (from `time.monotonic`) is stored in
    the field `duration`; the raw readings are kept in `start` and `end`.
    """

    def __enter__(self) -> 'Timer':
        self.start = time.monotonic()
        return self

    def __exit__(self, *args: Any) -> Any:
        # Returns None, so an exception raised in the block still propagates.
        finished = time.monotonic()
        self.end = finished
        self.duration = finished - self.start