test_sales_data_processor.py

import csv
import json
import os

import pytest
from py4j.protocol import Py4JJavaError
from pyspark.sql import SparkSession

from sales_data_processor import normalize_product_name, process_sales_data

tsv_headers = [
    "product_id",
    "store_id",
    "product_name",
    "units",
    "transaction_id",
    "price",
    "timestamp",
]


def write_tsv(data: list, output_file: str) -> None:
    """Write the given rows to a tab-separated file with the expected sales-data header."""
    with open(output_file, "w", newline='\n') as f:
        writer = csv.DictWriter(f, fieldnames=tsv_headers, delimiter="\t")
        writer.writeheader()
        for row in data:
            writer.writerow(row)


class PipelineRunner:
    """Context manager that writes test data to a TSV input file, runs the
    pipeline against it, and removes the input and output files on exit."""

    def __init__(self, data):
        self.input_file = 'test_sales_data.tsv'
        self.output_file = 'test_sales_profiles.json'
        self.data = data

    def __enter__(self):
        write_tsv(self.data, self.input_file)
        return self

    def read_sales_report(self):
        with open(self.output_file) as f:
            return json.load(f)

    def run_pipeline(self, spark, store_ids):
        process_sales_data(spark, store_ids, self.input_file, self.output_file)

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            os.remove(self.input_file)
        except FileNotFoundError:
            pass
        try:
            os.remove(self.output_file)
        except FileNotFoundError:
            pass
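

# For reference, the JSON report read back by read_sales_report() has the shape
# sketched below. This is inferred from the assertions in the tests that follow,
# not from process_sales_data itself, so treat it as an assumption; the values
# shown are illustrative only.
#
# {
#     "1": {                      # store_id, serialized as a string key
#         "coffee large": 0.75,   # normalized product name -> share of sales
#         "coffee small": 0.25,
#     },
# }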


@pytest.fixture()
def spark():
    return SparkSession.builder \
        .appName("SalesDataProcessorTest") \
        .getOrCreate()
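

# Alternative fixture (a sketch, not used by the tests below): because getOrCreate()
# reuses a single underlying SparkSession across tests, a session-scoped fixture with
# an explicit teardown could be swapped in instead. The name "spark_session" and the
# scope are assumptions, not part of the original suite.
@pytest.fixture(scope="session")
def spark_session():
    session = SparkSession.builder \
        .appName("SalesDataProcessorTest") \
        .getOrCreate()
    yield session
    session.stop()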


def test_normalize_product_name():
    """Hyphens and underscores in product names are normalized to spaces."""
    assert normalize_product_name("coffee-large") == "coffee large"
    assert normalize_product_name("coffee_large") == "coffee large"
    assert normalize_product_name("coffee large") == "coffee large"


def test_filter_negative_and_0_units(spark):
    """Rows with zero or negative units must be filtered out of the sales profile."""
    store_ids = {1}
    test_data = [
        {"product_id": 0,
         "store_id": 1,
         "product_name": "coffee_large",
         "units": -3,
         "transaction_id": 1,
         "price": 1.0,
         "timestamp": "2021-12-01 17:48:41.569057"},
        {"product_id": 1,
         "store_id": 1,
         "product_name": "coffee_small",
         "units": 0,
         "transaction_id": 2,
         "price": 1.0,
         "timestamp": "2021-12-01 17:48:41.569057"},
        {"product_id": 2,
         "store_id": 1,
         "product_name": "coffee_medium",
         "units": 1,
         "transaction_id": 3,
         "price": 1.0,
         "timestamp": "2021-12-01 17:48:41.569057"}
    ]
    with PipelineRunner(test_data) as runner:
        runner.run_pipeline(spark, store_ids)
        sales_profiles = runner.read_sales_report()
        assert "coffee large" not in sales_profiles["1"]
        assert "coffee small" not in sales_profiles["1"]
        assert "coffee medium" in sales_profiles["1"]


def test_empty_sales_data_is_valid(spark):
    """If the sales data is empty, the sales profile should be empty as well."""
    store_ids = {1}
    test_data = []
    with PipelineRunner(test_data) as runner:
        runner.run_pipeline(spark, store_ids)
        sales_profiles = runner.read_sales_report()
        assert len(sales_profiles) == 0


def test_malformed_data_error(spark):
    """Malformed input (a non-integer store_id) should surface as an error from the Spark job."""
    store_ids = {1}
    test_data = [
        {"product_id": 0,
         "store_id": "what",
         "product_name": "coffee_large",
         "units": 3,
         "transaction_id": 1,
         "price": 1.0,
         "timestamp": "2021-12-01 17:48:41.569057"}
    ]
    with PipelineRunner(test_data) as runner:
        with pytest.raises(Py4JJavaError):
            runner.run_pipeline(spark, store_ids)


def test_rows_with_missing_data_are_dropped(spark, caplog):
    """Rows with null or missing values are dropped, and the number dropped is logged."""
    store_ids = {1}
    test_data = [
        {"product_id": 0,
         "store_id": 1,
         "product_name": "coffee_large",
         "units": None,
         "transaction_id": 1,
         "price": 1.0,
         "timestamp": "2021-12-01 17:48:41.569057"}
    ]
    with PipelineRunner(test_data) as runner:
        runner.run_pipeline(spark, store_ids)
        sales_profiles = runner.read_sales_report()
        assert len(sales_profiles) == 0
        assert "Number of rows with null or missing values: 1" in caplog.text


def test_process_sales_data_results_calculation(spark):
    """Each store's profile maps normalized product names to their fractional share of sales."""
    store_ids = {1}
    test_data = [
        {"product_id": 0,
         "store_id": 1,
         "product_name": "coffee_large",
         "units": 3,
         "transaction_id": 1,
         "price": 1.0,
         "timestamp": "2021-12-01 17:48:41.569057"},
        {"product_id": 1,
         "store_id": 1,
         "product_name": "coffee_small",
         "units": 1,
         "transaction_id": 2,
         "price": 1.0,
         "timestamp": "2021-12-01 17:48:41.569057"}
    ]
    with PipelineRunner(test_data) as runner:
        runner.run_pipeline(spark, store_ids)
        sales_profiles = runner.read_sales_report()
        assert "1" in sales_profiles  # keys should be strings, not ints
        assert sales_profiles["1"]["coffee large"] == 0.75
        assert sales_profiles["1"]["coffee small"] == 0.25
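

# How to run (assumed local setup): with pyspark and pytest installed and the
# sales_data_processor module importable, invoke `pytest test_sales_data_processor.py`
# from this directory. py4j ships with pyspark, so no separate install should be needed.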