check_wholesaler_investment(dataloader)

Validates the shape, columns, and data of wholesaler investment records across the parquet, JSON, and HDF5 sources, for every wholesaler present in the HDF5 data.

Source code in wt_ml/dataset/data_validator/checks/check_wholesaler_investment.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def check_wholesaler_investment(dataloader: DataLoader) -> DataStatus:
    """Validate wholesaler investment data across the parquet, json and hdf5 sources.

    For every wholesaler present in the hdf5 data, three checks are run:

    1. The parquet rows have the same shape as the json rows whose float
       columns do not sum to zero.
    2. Every hdf5 column is also present in json.
    3. Restricted to the dates in ``dataloader.date_idx`` and sorted by
       ``date``, ``brand_code`` and ``product_code``, the int/float columns of
       hdf5 and json have the same shape and agree within ``ABS_TOLERANCE``.

    Check 3 is skipped for a wholesaler that fails check 2: the json frame
    cannot be projected onto the hdf5 columns when some are missing, and the
    missing columns are already reported as a failure.

    Args:
        dataloader: Provides ``wholesaler_investment`` (exposing ``parquet``,
            ``json`` and ``hdf5`` frames) and ``date_idx``.

    Returns:
        A ``DataStatus`` whose status is PASS only if every check passed for
        every wholesaler, and whose message holds one entry per failed check,
        joined by newlines.
    """
    data = dataloader.wholesaler_investment
    statuses: list[StatusType] = []
    messages: list[str] = []
    sort_keys = ["date", "brand_code", "product_code"]

    for wh in data.hdf5.wholesaler.unique():
        wh_parquet = data.parquet.loc[data.parquet["wholesaler"] == wh]
        wh_json = data.json.loc[data.json["wholesaler"] == wh]
        wh_hdf5 = data.hdf5.loc[data.hdf5["wholesaler"] == wh]

        # Check 1: parquet vs json shape, ignoring json rows whose float columns sum to zero.
        json_nonzero_shape = wh_json.loc[wh_json.select_dtypes(float).sum(1) != 0].shape
        if wh_parquet.shape == json_nonzero_shape:
            statuses.append(StatusType.PASS)
        else:
            statuses.append(StatusType.FAIL)
            messages.append(
                f"Shapes mismatch between parquet {wh_parquet.shape} and "
                f"json {json_nonzero_shape} in wholesaler Investment data for"
                f" wholesaler : {wh}"
            )

        # Check 2: every hdf5 column must exist in json. This has to run before
        # json is indexed by the hdf5 columns, otherwise a missing column
        # raises KeyError instead of being reported as a failed check.
        if all(wh_hdf5.columns.isin(wh_json.columns)):
            statuses.append(StatusType.PASS)
        else:
            statuses.append(StatusType.FAIL)
            messages.append(
                f"Following columns mismatch between hdf5 and json: {wh_hdf5.columns.difference(wh_json.columns)} for"
                f" wholesaler : {wh}"
            )
            # The value comparison below needs json projected onto the hdf5 columns.
            continue

        # Check 3: numeric values of hdf5 and json on the shared dates.
        wh_json_date = wh_json.loc[wh_json["date"].isin(dataloader.date_idx), wh_hdf5.columns]
        wh_json_date = wh_json_date.sort_values(by=sort_keys)
        wh_hdf5_date = wh_hdf5.loc[wh_hdf5["date"].isin(dataloader.date_idx)]
        wh_hdf5_date = wh_hdf5_date.sort_values(by=sort_keys)

        hdf5_numeric = wh_hdf5_date.select_dtypes([int, float])
        json_numeric = wh_json_date.select_dtypes([int, float])
        if hdf5_numeric.shape != json_numeric.shape:
            statuses.append(StatusType.FAIL)
            messages.append(
                f"Shape mismatch for hdf5 {hdf5_numeric.shape} and "
                f"json {json_numeric.shape} in "
                f"wholesaler Investment data for wholesaler : {wh}"
            )
        elif np.allclose(hdf5_numeric.values, json_numeric.values, atol=ABS_TOLERANCE):
            statuses.append(StatusType.PASS)
        else:
            statuses.append(StatusType.FAIL)
            messages.append(f"Data mismatch for hdf5 and json in wholesaler Investment data for wholesaler : {wh}")

    return DataStatus(
        status=StatusType.PASS if all(status == StatusType.PASS for status in statuses) else StatusType.FAIL,
        message="\n".join(messages),
    )