"""This module includes helper functions for writing unit tests for FactoryTX
components.
"""
import contextlib
import io
from typing import Dict, Generator, List, Tuple
import dateutil.parser
import pandas as pd
import testfixtures
from factorytx.base import Transform, Transmit
from factorytx.dataflow import InputStreamId
from factorytx.config import get_component_class, TRANSFORM_INFO
from factorytx.validation import ValidationMessage, has_errors, Level
_MINUS_INFINITY = float('-inf')
def _normalize_record(record_dict: dict) -> List[Tuple]:
return sorted((k, v if v == v else _MINUS_INFINITY) for k, v in record_dict.items())
[docs]def normalize_frame(df: pd.DataFrame) -> List[List[tuple]]:
"""Converts a DataFrame to a normalized form for comparisons."""
return sorted(_normalize_record(d) for d in df.to_dict(orient='records'))
def dataframe_to_csv(df: pd.DataFrame, remove_newlines: bool = False) -> str:
    """Render ``df`` as CSV text (without the index) for use in assertions.

    Columns are emitted in lexicographic order so the output does not depend
    on the column order of the input frame. When ``remove_newlines`` is true,
    every carriage return and line feed is stripped, collapsing the CSV onto a
    single line.
    """
    ordered = df.sort_index(axis=1)
    with io.StringIO() as buffer:
        ordered.to_csv(buffer, index=False)
        text = buffer.getvalue()
    if not remove_newlines:
        return text
    # Drop both '\r' and '\n' in a single pass.
    return text.translate({ord('\r'): None, ord('\n'): None})
def csv_string_to_dataframe(multiline_string: str) -> pd.DataFrame:
    """Parse CSV text into a Pandas DataFrame.

    Args:
        multiline_string: CSV content whose first line is the header row.

    Returns:
        The parsed DataFrame. ``index_col=None`` is passed to ``read_csv``, so
        no column is explicitly selected as the index.
    """
    # read_csv needs a file-like object; wrap the string in an in-memory buffer.
    return pd.read_csv(io.StringIO(multiline_string), header=0, index_col=None)
def csv_file_to_dataframe(filename: str) -> pd.DataFrame:
    """Load the CSV file at ``filename`` into a Pandas DataFrame.

    The first row of the file is treated as the header.
    """
    return pd.read_csv(filepath_or_buffer=filename, index_col=None, header=0)
class MemoryTransmit(Transmit):
    """Transmit that keeps everything it is given in memory, for tests.

    Each call to :meth:`process` appends an ``(input_stream_id, frame)`` pair
    to ``processed``; the ``received`` property exposes only the frames.
    """

    @staticmethod
    def clean(config: dict, root_config: dict) -> List[ValidationMessage]:
        # No validation: this test double accepts any configuration.
        return []

    def __init__(self, config: dict, root_config: dict) -> None:
        super().__init__(config, root_config)
        self.processed: List[Tuple[InputStreamId, pd.DataFrame]] = []

    @property
    def received(self) -> List[pd.DataFrame]:
        """Frames passed to :meth:`process`, in the order they arrived."""
        return [frame for _stream_id, frame in self.processed]

    def clear(self) -> None:
        """Discard everything recorded so far."""
        # Rebinds to a fresh list rather than emptying in place, so a list
        # obtained earlier from ``processed`` keeps its contents.
        self.processed = []

    def process(self, input_stream_id: InputStreamId, input_frame: pd.DataFrame) -> None:
        """Record ``input_frame`` together with the stream it came from."""
        entry = (input_stream_id, input_frame)
        self.processed.append(entry)

    def purge(self, streams: List[InputStreamId]) -> None:
        """Intentionally does nothing; this transmit keeps no per-stream state to purge."""
        return None

    @staticmethod
    def create(filter_stream: List[str]) -> 'MemoryTransmit':
        """Build a MemoryTransmit with a fixed name and the given ``filter_stream``."""
        config = {
            'transmit_name': 'my fake transmit',
            'filter_stream': filter_stream,
        }
        root_config: dict = {}
        return MemoryTransmit(config, root_config)
@contextlib.contextmanager
def freeze_time(module_name: str,
                timestamp: str) -> Generator[None, None, None]:
    """Freeze time by mocking datetime, so tests can be executed in a specific
    time context.

    While the context is active, ``<module_name>.datetime`` is replaced by a
    testfixtures fake datetime class, and ``time.time`` by a fake time
    function. Both are pinned to ``timestamp``, and ``delta=0`` keeps the
    clock from advancing between calls. The ``Replacer`` context restores the
    originals on exit.

    Args:
        module_name: dotted name of the module whose ``datetime`` attribute is
            patched. Presumably that attribute is the ``datetime`` class
            (i.e. the module did ``from datetime import datetime``); confirm
            for the module under test.
        timestamp: a string parsed with ``dateutil.parser.parse``.

    Example::

        with freeze_time('my_module', '1985-10-26 01:21:00.000'):
            my_module.datetime.now()  # -> datetime(1985, 10, 26, 1, 21, 0)
    """
    from datetime import timezone

    d = dateutil.parser.parse(timestamp)
    fake_datetime = testfixtures.test_datetime(d.year, d.month, d.day, d.hour, d.minute,
                                               d.second, d.microsecond, d.tzinfo, delta=0)
    # The fake time.time() takes no tzinfo and interprets the components it
    # is given as UTC. Convert offset-aware timestamps to UTC first so the
    # offset is not silently dropped. Naive timestamps pass through unchanged.
    t = d.astimezone(timezone.utc) if d.utcoffset() is not None else d
    fake_time = testfixtures.test_time(t.year, t.month, t.day, t.hour, t.minute,
                                       t.second, t.microsecond, delta=0)
    with testfixtures.Replacer() as replacer:
        replacer.replace(f'{module_name}.datetime', fake_datetime)
        # Patched on the time module itself; code that bound the function
        # via ``from time import time`` keeps the real one.
        replacer.replace('time.time', fake_time)
        yield
def get_validation_errors(validation_results: List[ValidationMessage]) -> List[ValidationMessage]:
    """Return the subset of ``validation_results`` whose level is ``Level.ERROR``.

    The relative order of the messages is preserved.
    """
    return list(filter(lambda message: message.level == Level.ERROR, validation_results))