diff options
author | Rutger Broekhoff | 2024-05-02 20:27:40 +0200 |
---|---|---|
committer | Rutger Broekhoff | 2024-05-02 20:27:40 +0200 |
commit | 17a3ea880402338420699e03bcb24181e4ff3924 (patch) | |
tree | da666ef91e0b60d20aa0b01529644c136fd1f4ab /src/augmentkv6 | |
download | oeuf-17a3ea880402338420699e03bcb24181e4ff3924.tar.gz oeuf-17a3ea880402338420699e03bcb24181e4ff3924.zip |
Initial commit
Based on dc4ba6a
Diffstat (limited to 'src/augmentkv6')
-rw-r--r-- | src/augmentkv6/.envrc | 2 | ||||
-rw-r--r-- | src/augmentkv6/Makefile | 21 | ||||
-rw-r--r-- | src/augmentkv6/main.cpp | 510 |
3 files changed, 533 insertions, 0 deletions
diff --git a/src/augmentkv6/.envrc b/src/augmentkv6/.envrc new file mode 100644 index 0000000..694e74f --- /dev/null +++ b/src/augmentkv6/.envrc | |||
@@ -0,0 +1,2 @@ | |||
1 | source_env ../../ | ||
2 | export DEVMODE=1 | ||
diff --git a/src/augmentkv6/Makefile b/src/augmentkv6/Makefile new file mode 100644 index 0000000..cebb291 --- /dev/null +++ b/src/augmentkv6/Makefile | |||
@@ -0,0 +1,21 @@ | |||
# Build rules for the augmentkv6 tool.
#
# Setting DEVMODE to a non-empty value (the accompanying .envrc does this)
# turns compiler warnings into errors.
#
# The hardening flags below are taken from:
# Open Source Security Foundation (OpenSSF), “Compiler Options Hardening Guide
# for C and C++,” OpenSSF Best Practices Working Group. Accessed: Dec. 01,
# 2023. [Online]. Available:
# https://best.openssf.org/Compiler-Hardening-Guides/Compiler-Options-Hardening-Guide-for-C-and-C++.html
CXXFLAGS=-std=c++2b -g -fno-omit-frame-pointer $(if $(DEVMODE),-Werror,)\
  -O2 -Wall -Wformat=2 -Wconversion -Wtrampolines -Wimplicit-fallthrough \
  -U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=3 \
  -D_GLIBCXX_ASSERTIONS \
  -fstrict-flex-arrays=3 \
  -fstack-clash-protection -fstack-protector-strong
# Libraries are listed here as well; the link rule below places $(LDFLAGS)
# after the source file so that the -l options resolve its symbols.
LDFLAGS=-larrow -larrow_acero -larrow_dataset -lparquet -ltmi8 -Wl,-z,defs \
  -Wl,-z,nodlopen -Wl,-z,noexecstack \
  -Wl,-z,relro -Wl,-z,now

augmentkv6: main.cpp
	$(CXX) -fPIE -pie -o $@ $^ $(CXXFLAGS) $(LDFLAGS)

.PHONY: clean
clean:
	# -f keeps 'make clean' from failing when the binary was never built.
	rm -f augmentkv6
diff --git a/src/augmentkv6/main.cpp b/src/augmentkv6/main.cpp new file mode 100644 index 0000000..81a54d3 --- /dev/null +++ b/src/augmentkv6/main.cpp | |||
@@ -0,0 +1,510 @@ | |||
1 | // vim:set sw=2 ts=2 sts et: | ||
2 | |||
3 | #include <chrono> | ||
4 | #include <cstdio> | ||
5 | #include <deque> | ||
6 | #include <filesystem> | ||
7 | #include <format> | ||
8 | #include <fstream> | ||
9 | #include <iostream> | ||
10 | #include <string> | ||
11 | #include <string_view> | ||
12 | #include <vector> | ||
13 | |||
14 | #include <arrow/acero/exec_plan.h> | ||
15 | #include <arrow/api.h> | ||
16 | #include <arrow/compute/api.h> | ||
17 | #include <arrow/dataset/api.h> | ||
18 | #include <arrow/filesystem/api.h> | ||
19 | #include <arrow/io/api.h> | ||
20 | #include <parquet/arrow/reader.h> | ||
21 | |||
22 | #include <tmi8/kv1_index.hpp> | ||
23 | #include <tmi8/kv1_lexer.hpp> | ||
24 | #include <tmi8/kv1_parser.hpp> | ||
25 | #include <tmi8/kv1_types.hpp> | ||
26 | #include <tmi8/kv6_parquet.hpp> | ||
27 | |||
28 | using namespace std::string_view_literals; | ||
29 | |||
30 | namespace ac = arrow::acero; | ||
31 | namespace ds = arrow::dataset; | ||
32 | namespace cp = arrow::compute; | ||
33 | using namespace arrow; | ||
34 | |||
// Clock used for the timing measurements in this file.
//
// high_resolution_clock is chosen only when the implementation declares it
// steady; otherwise steady_clock is used, so the selected clock is always a
// steady one.
using TimingClock = std::conditional<
  std::chrono::high_resolution_clock::is_steady,
  std::chrono::high_resolution_clock,
  std::chrono::steady_clock>::type;
39 | |||
// Reads everything available on standard input and returns it as one string.
//
// A start message and the number of bytes read are reported on stderr. If the
// stream reports a read error, a message is printed and the process exits
// with status 1.
std::string readKv1() {
  fputs("Reading KV1 from standard input\n", stderr);

  char buf[4096];
  std::string data;
  // fread() returns a short count both at end-of-file and on error, so keep
  // going until one of the two conditions is flagged on the stream. Whatever
  // was read in the final, partial call is still appended.
  while (!feof(stdin) && !ferror(stdin)) {
    size_t read = fread(buf, sizeof(char), sizeof(buf), stdin);
    data.append(buf, read);
  }
  if (ferror(stdin)) {
    fputs("Error when reading from stdin\n", stderr);
    exit(1);
  }
  // %zu is the conversion for size_t; %lu is only right on platforms where
  // size_t happens to be unsigned long.
  fprintf(stderr, "Read %zu bytes\n", data.size());

  return data;
}
57 | |||
58 | std::vector<Kv1Token> lex() { | ||
59 | std::string data = readKv1(); | ||
60 | |||
61 | auto start = TimingClock::now(); | ||
62 | Kv1Lexer lexer(data); | ||
63 | lexer.lex(); | ||
64 | auto end = TimingClock::now(); | ||
65 | |||
66 | std::chrono::duration<double> elapsed{end - start}; | ||
67 | double bytes = static_cast<double>(data.size()) / 1'000'000; | ||
68 | double speed = bytes / elapsed.count(); | ||
69 | |||
70 | if (!lexer.errors.empty()) { | ||
71 | fputs("Lexer reported errors:\n", stderr); | ||
72 | for (const auto &error : lexer.errors) | ||
73 | fprintf(stderr, "- %s\n", error.c_str()); | ||
74 | exit(1); | ||
75 | } | ||
76 | |||
77 | fprintf(stderr, "Got %lu tokens\n", lexer.tokens.size()); | ||
78 | fprintf(stderr, "Duration: %f s\n", elapsed.count()); | ||
79 | fprintf(stderr, "Speed: %f MB/s\n", speed); | ||
80 | |||
81 | return std::move(lexer.tokens); | ||
82 | } | ||
83 | |||
84 | bool parse(Kv1Records &into) { | ||
85 | std::vector<Kv1Token> tokens = lex(); | ||
86 | |||
87 | Kv1Parser parser(tokens, into); | ||
88 | parser.parse(); | ||
89 | |||
90 | bool ok = true; | ||
91 | if (!parser.gerrors.empty()) { | ||
92 | ok = false; | ||
93 | fputs("Parser reported errors:\n", stderr); | ||
94 | for (const auto &error : parser.gerrors) | ||
95 | fprintf(stderr, "- %s\n", error.c_str()); | ||
96 | } | ||
97 | if (!parser.warns.empty()) { | ||
98 | fputs("Parser reported warnings:\n", stderr); | ||
99 | for (const auto &warn : parser.warns) | ||
100 | fprintf(stderr, "- %s\n", warn.c_str()); | ||
101 | } | ||
102 | |||
103 | fprintf(stderr, "Parsed %lu records\n", into.size()); | ||
104 | |||
105 | return ok; | ||
106 | } | ||
107 | |||
108 | void printParsedRecords(const Kv1Records &records) { | ||
109 | fputs("Parsed records:\n", stderr); | ||
110 | fprintf(stderr, " organizational_units: %lu\n", records.organizational_units.size()); | ||
111 | fprintf(stderr, " higher_organizational_units: %lu\n", records.higher_organizational_units.size()); | ||
112 | fprintf(stderr, " user_stop_points: %lu\n", records.user_stop_points.size()); | ||
113 | fprintf(stderr, " user_stop_areas: %lu\n", records.user_stop_areas.size()); | ||
114 | fprintf(stderr, " timing_links: %lu\n", records.timing_links.size()); | ||
115 | fprintf(stderr, " links: %lu\n", records.links.size()); | ||
116 | fprintf(stderr, " lines: %lu\n", records.lines.size()); | ||
117 | fprintf(stderr, " destinations: %lu\n", records.destinations.size()); | ||
118 | fprintf(stderr, " journey_patterns: %lu\n", records.journey_patterns.size()); | ||
119 | fprintf(stderr, " concession_financer_relations: %lu\n", records.concession_financer_relations.size()); | ||
120 | fprintf(stderr, " concession_areas: %lu\n", records.concession_areas.size()); | ||
121 | fprintf(stderr, " financers: %lu\n", records.financers.size()); | ||
122 | fprintf(stderr, " journey_pattern_timing_links: %lu\n", records.journey_pattern_timing_links.size()); | ||
123 | fprintf(stderr, " points: %lu\n", records.points.size()); | ||
124 | fprintf(stderr, " point_on_links: %lu\n", records.point_on_links.size()); | ||
125 | fprintf(stderr, " icons: %lu\n", records.icons.size()); | ||
126 | fprintf(stderr, " notices: %lu\n", records.notices.size()); | ||
127 | fprintf(stderr, " notice_assignments: %lu\n", records.notice_assignments.size()); | ||
128 | fprintf(stderr, " time_demand_groups: %lu\n", records.time_demand_groups.size()); | ||
129 | fprintf(stderr, " time_demand_group_run_times: %lu\n", records.time_demand_group_run_times.size()); | ||
130 | fprintf(stderr, " period_groups: %lu\n", records.period_groups.size()); | ||
131 | fprintf(stderr, " specific_days: %lu\n", records.specific_days.size()); | ||
132 | fprintf(stderr, " timetable_versions: %lu\n", records.timetable_versions.size()); | ||
133 | fprintf(stderr, " public_journeys: %lu\n", records.public_journeys.size()); | ||
134 | fprintf(stderr, " period_group_validities: %lu\n", records.period_group_validities.size()); | ||
135 | fprintf(stderr, " exceptional_operating_days: %lu\n", records.exceptional_operating_days.size()); | ||
136 | fprintf(stderr, " schedule_versions: %lu\n", records.schedule_versions.size()); | ||
137 | fprintf(stderr, " public_journey_passing_times: %lu\n", records.public_journey_passing_times.size()); | ||
138 | fprintf(stderr, " operating_days: %lu\n", records.operating_days.size()); | ||
139 | } | ||
140 | |||
141 | void printIndexSize(const Kv1Index &index) { | ||
142 | fputs("Index size:\n", stderr); | ||
143 | fprintf(stderr, " organizational_units: %lu\n", index.organizational_units.size()); | ||
144 | fprintf(stderr, " user_stop_points: %lu\n", index.user_stop_points.size()); | ||
145 | fprintf(stderr, " user_stop_areas: %lu\n", index.user_stop_areas.size()); | ||
146 | fprintf(stderr, " timing_links: %lu\n", index.timing_links.size()); | ||
147 | fprintf(stderr, " links: %lu\n", index.links.size()); | ||
148 | fprintf(stderr, " lines: %lu\n", index.lines.size()); | ||
149 | fprintf(stderr, " destinations: %lu\n", index.destinations.size()); | ||
150 | fprintf(stderr, " journey_patterns: %lu\n", index.journey_patterns.size()); | ||
151 | fprintf(stderr, " concession_financer_relations: %lu\n", index.concession_financer_relations.size()); | ||
152 | fprintf(stderr, " concession_areas: %lu\n", index.concession_areas.size()); | ||
153 | fprintf(stderr, " financers: %lu\n", index.financers.size()); | ||
154 | fprintf(stderr, " journey_pattern_timing_links: %lu\n", index.journey_pattern_timing_links.size()); | ||
155 | fprintf(stderr, " points: %lu\n", index.points.size()); | ||
156 | fprintf(stderr, " point_on_links: %lu\n", index.point_on_links.size()); | ||
157 | fprintf(stderr, " icons: %lu\n", index.icons.size()); | ||
158 | fprintf(stderr, " notices: %lu\n", index.notices.size()); | ||
159 | fprintf(stderr, " time_demand_groups: %lu\n", index.time_demand_groups.size()); | ||
160 | fprintf(stderr, " time_demand_group_run_times: %lu\n", index.time_demand_group_run_times.size()); | ||
161 | fprintf(stderr, " period_groups: %lu\n", index.period_groups.size()); | ||
162 | fprintf(stderr, " specific_days: %lu\n", index.specific_days.size()); | ||
163 | fprintf(stderr, " timetable_versions: %lu\n", index.timetable_versions.size()); | ||
164 | fprintf(stderr, " public_journeys: %lu\n", index.public_journeys.size()); | ||
165 | fprintf(stderr, " period_group_validities: %lu\n", index.period_group_validities.size()); | ||
166 | fprintf(stderr, " exceptional_operating_days: %lu\n", index.exceptional_operating_days.size()); | ||
167 | fprintf(stderr, " schedule_versions: %lu\n", index.schedule_versions.size()); | ||
168 | fprintf(stderr, " public_journey_passing_times: %lu\n", index.public_journey_passing_times.size()); | ||
169 | fprintf(stderr, " operating_days: %lu\n", index.operating_days.size()); | ||
170 | } | ||
171 | |||
172 | struct BasicJourneyKey { | ||
173 | std::string data_owner_code; | ||
174 | std::string line_planning_number; | ||
175 | int journey_number; | ||
176 | |||
177 | auto operator<=>(const BasicJourneyKey &) const = default; | ||
178 | }; | ||
179 | |||
180 | size_t hash_value(const BasicJourneyKey &k) { | ||
181 | size_t seed = 0; | ||
182 | |||
183 | boost::hash_combine(seed, k.data_owner_code); | ||
184 | boost::hash_combine(seed, k.line_planning_number); | ||
185 | boost::hash_combine(seed, k.journey_number); | ||
186 | |||
187 | return seed; | ||
188 | } | ||
189 | |||
190 | using BasicJourneyKeySet = std::unordered_set<BasicJourneyKey, boost::hash<BasicJourneyKey>>; | ||
191 | |||
192 | arrow::Result<BasicJourneyKeySet> basicJourneys(std::shared_ptr<arrow::Table> table) { | ||
193 | ac::TableSourceNodeOptions table_source_node_options(table); | ||
194 | ac::Declaration table_source("table_source", std::move(table_source_node_options)); | ||
195 | auto aggregate_options = ac::AggregateNodeOptions{ | ||
196 | /* .aggregates = */ {}, | ||
197 | /* .keys = */ { "data_owner_code", "line_planning_number", "journey_number" }, | ||
198 | }; | ||
199 | ac::Declaration aggregate("aggregate", { std::move(table_source) }, std::move(aggregate_options)); | ||
200 | |||
201 | std::shared_ptr<arrow::Table> result; | ||
202 | ARROW_ASSIGN_OR_RAISE(result, ac::DeclarationToTable(std::move(aggregate))); | ||
203 | |||
204 | std::shared_ptr<arrow::ChunkedArray> data_owner_codes = result->GetColumnByName("data_owner_code"); | ||
205 | std::shared_ptr<arrow::ChunkedArray> line_planning_numbers = result->GetColumnByName("line_planning_number"); | ||
206 | std::shared_ptr<arrow::ChunkedArray> journey_numbers = result->GetColumnByName("journey_number"); | ||
207 | |||
208 | int i_data_owner_codes_chunk = 0; | ||
209 | int i_journey_numbers_chunk = 0; | ||
210 | int i_line_planning_numbers_chunk = 0; | ||
211 | int i_in_data_owner_codes_chunk = 0; | ||
212 | int i_in_journey_numbers_chunk = 0; | ||
213 | int i_in_line_planning_numbers_chunk = 0; | ||
214 | |||
215 | BasicJourneyKeySet journeys; | ||
216 | |||
217 | for (int64_t i = 0; i < result->num_rows(); i++) { | ||
218 | auto data_owner_codes_chunk = std::static_pointer_cast<arrow::StringArray>(data_owner_codes->chunk(i_data_owner_codes_chunk)); | ||
219 | auto line_planning_numbers_chunk = std::static_pointer_cast<arrow::StringArray>(line_planning_numbers->chunk(i_line_planning_numbers_chunk)); | ||
220 | auto journey_numbers_chunk = std::static_pointer_cast<arrow::UInt32Array>(journey_numbers->chunk(i_journey_numbers_chunk)); | ||
221 | |||
222 | std::string_view data_owner_code = data_owner_codes_chunk->Value(i_in_data_owner_codes_chunk); | ||
223 | std::string_view line_planning_number = line_planning_numbers_chunk->Value(i_in_line_planning_numbers_chunk); | ||
224 | uint32_t journey_number = journey_numbers_chunk->Value(i_in_journey_numbers_chunk); | ||
225 | |||
226 | journeys.emplace( | ||
227 | std::string(data_owner_code), | ||
228 | std::string(line_planning_number), | ||
229 | journey_number | ||
230 | ); | ||
231 | |||
232 | i_in_data_owner_codes_chunk++; | ||
233 | i_in_line_planning_numbers_chunk++; | ||
234 | i_in_journey_numbers_chunk++; | ||
235 | if (i_in_data_owner_codes_chunk >= data_owner_codes_chunk->length()) { | ||
236 | i_data_owner_codes_chunk++; | ||
237 | i_in_data_owner_codes_chunk = 0; | ||
238 | } | ||
239 | if (i_in_line_planning_numbers_chunk >= line_planning_numbers_chunk->length()) { | ||
240 | i_line_planning_numbers_chunk++; | ||
241 | i_in_line_planning_numbers_chunk = 0; | ||
242 | } | ||
243 | if (i_in_journey_numbers_chunk >= journey_numbers_chunk->length()) { | ||
244 | i_journey_numbers_chunk++; | ||
245 | i_in_journey_numbers_chunk = 0; | ||
246 | } | ||
247 | } | ||
248 | |||
249 | return journeys; | ||
250 | } | ||
251 | |||
252 | struct DistanceKey { | ||
253 | BasicJourneyKey journey; | ||
254 | std::string last_passed_user_stop_code; | ||
255 | |||
256 | auto operator<=>(const DistanceKey &) const = default; | ||
257 | }; | ||
258 | |||
259 | size_t hash_value(const DistanceKey &k) { | ||
260 | size_t seed = 0; | ||
261 | |||
262 | boost::hash_combine(seed, k.journey); | ||
263 | boost::hash_combine(seed, k.last_passed_user_stop_code); | ||
264 | |||
265 | return seed; | ||
266 | } | ||
267 | |||
268 | struct DistanceTimingLink { | ||
269 | const Kv1JourneyPatternTimingLink *jopatili; | ||
270 | double distance_since_start_of_journey = 0; // at the start of the link | ||
271 | }; | ||
272 | |||
273 | using DistanceMap = std::unordered_map<DistanceKey, double, boost::hash<DistanceKey>>; | ||
274 | |||
275 | // Returns a map, where | ||
276 | // DataOwnerCode + LinePlanningNumber + JourneyNumber + UserStopCode -> | ||
277 | // Distance of Last User Stop | ||
278 | DistanceMap makeDistanceMap(Kv1Records &records, Kv1Index &index, BasicJourneyKeySet &journeys) { | ||
279 | std::unordered_map< | ||
280 | Kv1JourneyPattern::Key, | ||
281 | std::vector<DistanceTimingLink>, | ||
282 | boost::hash<Kv1JourneyPattern::Key>> jopatili_index; | ||
283 | std::unordered_map< | ||
284 | BasicJourneyKey, | ||
285 | const Kv1PublicJourney *, | ||
286 | boost::hash<BasicJourneyKey>> journey_index; | ||
287 | for (size_t i = 0; i < records.public_journeys.size(); i++) { | ||
288 | const Kv1PublicJourney *pujo = &records.public_journeys[i]; | ||
289 | |||
290 | BasicJourneyKey journey_key( | ||
291 | pujo->key.data_owner_code, | ||
292 | pujo->key.line_planning_number, | ||
293 | pujo->key.journey_number); | ||
294 | |||
295 | if (journeys.contains(journey_key)) { | ||
296 | journey_index[journey_key] = pujo; | ||
297 | |||
298 | Kv1JourneyPattern::Key jopa_key( | ||
299 | pujo->key.data_owner_code, | ||
300 | pujo->key.line_planning_number, | ||
301 | pujo->journey_pattern_code); | ||
302 | jopatili_index[jopa_key] = {}; | ||
303 | } | ||
304 | } | ||
305 | |||
306 | for (size_t i = 0; i < records.journey_pattern_timing_links.size(); i++) { | ||
307 | const Kv1JourneyPatternTimingLink *jopatili = &records.journey_pattern_timing_links[i]; | ||
308 | Kv1JourneyPattern::Key jopa_key( | ||
309 | jopatili->key.data_owner_code, | ||
310 | jopatili->key.line_planning_number, | ||
311 | jopatili->key.journey_pattern_code); | ||
312 | if (jopatili_index.contains(jopa_key)) { | ||
313 | jopatili_index[jopa_key].push_back(DistanceTimingLink(jopatili, 0)); | ||
314 | } | ||
315 | } | ||
316 | |||
317 | for (auto &[jopa_key, timing_links] : jopatili_index) { | ||
318 | std::sort(timing_links.begin(), timing_links.end(), [](auto a, auto b) { | ||
319 | return a.jopatili->key.timing_link_order < b.jopatili->key.timing_link_order; | ||
320 | }); | ||
321 | |||
322 | const std::string transport_type = index.journey_patterns[jopa_key]->p_line->transport_type; | ||
323 | |||
324 | for (size_t i = 1; i < timing_links.size(); i++) { | ||
325 | DistanceTimingLink *timing_link = &timing_links[i]; | ||
326 | DistanceTimingLink *prev_timing_link = &timing_links[i - 1]; | ||
327 | |||
328 | const Kv1Link::Key link_key( | ||
329 | prev_timing_link->jopatili->key.data_owner_code, | ||
330 | prev_timing_link->jopatili->user_stop_code_begin, | ||
331 | prev_timing_link->jopatili->user_stop_code_end, | ||
332 | transport_type); | ||
333 | double link_distance = index.links[link_key]->distance; | ||
334 | timing_link->distance_since_start_of_journey = | ||
335 | prev_timing_link->distance_since_start_of_journey + link_distance; | ||
336 | } | ||
337 | } | ||
338 | |||
339 | // DataOwnerCode + LinePlanningNumber + JourneyNumber + UserStopCode -> | ||
340 | // Distance of Last User Stop | ||
341 | DistanceMap distance_map; | ||
342 | |||
343 | for (const auto &journey : journeys) { | ||
344 | const Kv1PublicJourney *pujo = journey_index[journey]; | ||
345 | if (pujo == nullptr) { | ||
346 | std::cerr << "Warning: No PUJO found for [" << journey.data_owner_code << "] " | ||
347 | << journey.line_planning_number << "/" << journey.journey_number << std::endl; | ||
348 | continue; | ||
349 | } | ||
350 | Kv1JourneyPattern::Key jopa_key( | ||
351 | pujo->key.data_owner_code, | ||
352 | pujo->key.line_planning_number, | ||
353 | pujo->journey_pattern_code); | ||
354 | for (const auto &timing_link : jopatili_index[jopa_key]) { | ||
355 | DistanceKey key(journey, timing_link.jopatili->user_stop_code_begin); | ||
356 | distance_map[key] = timing_link.distance_since_start_of_journey; | ||
357 | } | ||
358 | } | ||
359 | |||
360 | return distance_map; | ||
361 | } | ||
362 | |||
363 | arrow::Result<std::shared_ptr<arrow::Table>> augment( | ||
364 | std::shared_ptr<arrow::Table> table, | ||
365 | const DistanceMap &distance_map | ||
366 | ) { | ||
367 | for (int i = 0; i < table->num_columns(); i++) { | ||
368 | if (table->column(i)->num_chunks() > 1) { | ||
369 | std::stringstream ss; | ||
370 | ss << "Error: Expected column " << i | ||
371 | << " (" << table->ColumnNames()[i] << ") to have 1 chunk, got " | ||
372 | << table->column(i)->num_chunks(); | ||
373 | return arrow::Status::Invalid(ss.str()); | ||
374 | } | ||
375 | } | ||
376 | |||
377 | auto data_owner_codes = std::static_pointer_cast<arrow::StringArray>(table->GetColumnByName("data_owner_code")->chunk(0)); | ||
378 | auto line_planning_numbers = std::static_pointer_cast<arrow::StringArray>(table->GetColumnByName("line_planning_number")->chunk(0)); | ||
379 | auto journey_numbers = std::static_pointer_cast<arrow::UInt32Array>(table->GetColumnByName("journey_number")->chunk(0)); | ||
380 | auto user_stop_codes = std::static_pointer_cast<arrow::StringArray>(table->GetColumnByName("user_stop_code")->chunk(0)); | ||
381 | auto distance_since_last_user_stops = std::static_pointer_cast<arrow::UInt32Array>(table->GetColumnByName("distance_since_last_user_stop")->chunk(0)); | ||
382 | auto timestamps = std::static_pointer_cast<arrow::TimestampArray>(table->GetColumnByName("timestamp")->chunk(0)); | ||
383 | |||
384 | auto timestamps_type = table->schema()->GetFieldByName("timestamp")->type(); | ||
385 | if (timestamps_type->id() != arrow::Type::TIMESTAMP) | ||
386 | return arrow::Status::Invalid("Field 'timestamp' does not have expected type TIMESTAMP"); | ||
387 | if (std::static_pointer_cast<arrow::TimestampType>(timestamps_type)->unit() != arrow::TimeUnit::MILLI) | ||
388 | return arrow::Status::Invalid("Field 'timestamp' does not have unit MILLI"); | ||
389 | if (!std::static_pointer_cast<arrow::TimestampType>(timestamps_type)->timezone().empty()) | ||
390 | return arrow::Status::Invalid("Field 'timestamp' should have empty time zone name"); | ||
391 | |||
392 | std::shared_ptr<arrow::Field> field_distance_since_start_of_journey = | ||
393 | arrow::field("distance_since_start_of_journey", arrow::uint32()); | ||
394 | std::shared_ptr<arrow::Field> field_day_of_week = | ||
395 | arrow::field("timestamp_iso_day_of_week", arrow::int64()); | ||
396 | std::shared_ptr<arrow::Field> field_date = | ||
397 | arrow::field("timestamp_date", arrow::date32()); | ||
398 | std::shared_ptr<arrow::Field> field_local_time = | ||
399 | arrow::field("timestamp_local_time", arrow::time32(arrow::TimeUnit::SECOND)); | ||
400 | arrow::UInt32Builder distance_since_start_of_journey_builder; | ||
401 | arrow::Int64Builder day_of_week_builder; | ||
402 | arrow::Date32Builder date_builder; | ||
403 | arrow::Time32Builder local_time_builder(arrow::time32(arrow::TimeUnit::SECOND), arrow::default_memory_pool()); | ||
404 | |||
405 | const std::chrono::time_zone *amsterdam = std::chrono::locate_zone("Europe/Amsterdam"); | ||
406 | |||
407 | for (int64_t i = 0; i < table->num_rows(); i++) { | ||
408 | DistanceKey key( | ||
409 | BasicJourneyKey( | ||
410 | std::string(data_owner_codes->Value(i)), | ||
411 | std::string(line_planning_numbers->Value(i)), | ||
412 | journey_numbers->Value(i)), | ||
413 | std::string(user_stop_codes->Value(i))); | ||
414 | |||
415 | uint32_t distance_since_last_user_stop = distance_since_last_user_stops->Value(i); | ||
416 | if (distance_map.contains(key)) { | ||
417 | uint32_t total_distance = distance_since_last_user_stop + static_cast<uint32_t>(distance_map.at(key)); | ||
418 | ARROW_RETURN_NOT_OK(distance_since_start_of_journey_builder.Append(total_distance)); | ||
419 | } else { | ||
420 | ARROW_RETURN_NOT_OK(distance_since_start_of_journey_builder.AppendNull()); | ||
421 | } | ||
422 | |||
423 | // Welp, this has gotten a bit complicated! | ||
424 | std::chrono::sys_seconds timestamp(std::chrono::floor<std::chrono::seconds>(std::chrono::milliseconds(timestamps->Value(i)))); | ||
425 | std::chrono::zoned_seconds zoned_timestamp(amsterdam, timestamp); | ||
426 | std::chrono::local_seconds local_timestamp(zoned_timestamp); | ||
427 | std::chrono::local_days local_date = std::chrono::floor<std::chrono::days>(local_timestamp); | ||
428 | std::chrono::year_month_day date(local_date); | ||
429 | std::chrono::weekday day_of_week(local_date); | ||
430 | std::chrono::hh_mm_ss<std::chrono::seconds> time(local_timestamp - local_date); | ||
431 | std::chrono::sys_days unix_date(date); | ||
432 | |||
433 | int64_t iso_day_of_week = day_of_week.iso_encoding(); | ||
434 | int32_t unix_days = static_cast<int32_t>(unix_date.time_since_epoch().count()); | ||
435 | int32_t secs_since_midnight = static_cast<int32_t>(std::chrono::seconds(time).count()); | ||
436 | |||
437 | ARROW_RETURN_NOT_OK(day_of_week_builder.Append(iso_day_of_week)); | ||
438 | ARROW_RETURN_NOT_OK(date_builder.Append(unix_days)); | ||
439 | ARROW_RETURN_NOT_OK(local_time_builder.Append(secs_since_midnight)); | ||
440 | } | ||
441 | |||
442 | ARROW_ASSIGN_OR_RAISE(auto distance_since_start_of_journey_col_chunk, distance_since_start_of_journey_builder.Finish()); | ||
443 | ARROW_ASSIGN_OR_RAISE(auto day_of_week_col_chunk, day_of_week_builder.Finish()); | ||
444 | ARROW_ASSIGN_OR_RAISE(auto date_col_chunk, date_builder.Finish()); | ||
445 | ARROW_ASSIGN_OR_RAISE(auto local_time_col_chunk, local_time_builder.Finish()); | ||
446 | auto distance_since_start_of_journey_col = | ||
447 | std::make_shared<arrow::ChunkedArray>(distance_since_start_of_journey_col_chunk); | ||
448 | auto day_of_week_col = std::make_shared<arrow::ChunkedArray>(day_of_week_col_chunk); | ||
449 | auto date_col = std::make_shared<arrow::ChunkedArray>(date_col_chunk); | ||
450 | auto local_time_col = std::make_shared<arrow::ChunkedArray>(local_time_col_chunk); | ||
451 | |||
452 | ARROW_ASSIGN_OR_RAISE(table, table->AddColumn( | ||
453 | table->num_columns(), | ||
454 | field_distance_since_start_of_journey, | ||
455 | distance_since_start_of_journey_col)); | ||
456 | ARROW_ASSIGN_OR_RAISE(table, table->AddColumn(table->num_columns(), field_day_of_week, day_of_week_col)); | ||
457 | ARROW_ASSIGN_OR_RAISE(table, table->AddColumn(table->num_columns(), field_date, date_col)); | ||
458 | ARROW_ASSIGN_OR_RAISE(table, table->AddColumn(table->num_columns(), field_local_time, local_time_col)); | ||
459 | |||
460 | return table; | ||
461 | } | ||
462 | |||
463 | arrow::Status processTables(Kv1Records &records, Kv1Index &index) { | ||
464 | std::shared_ptr<arrow::io::RandomAccessFile> input; | ||
465 | ARROW_ASSIGN_OR_RAISE(input, arrow::io::ReadableFile::Open("oeuf-input.parquet")); | ||
466 | |||
467 | std::unique_ptr<parquet::arrow::FileReader> arrow_reader; | ||
468 | ARROW_RETURN_NOT_OK(parquet::arrow::OpenFile(input, arrow::default_memory_pool(), &arrow_reader)); | ||
469 | |||
470 | std::shared_ptr<arrow::Table> table; | ||
471 | ARROW_RETURN_NOT_OK(arrow_reader->ReadTable(&table)); | ||
472 | |||
473 | std::cerr << "Input KV6 file has " << table->num_rows() << " rows" << std::endl; | ||
474 | ARROW_ASSIGN_OR_RAISE(BasicJourneyKeySet journeys, basicJourneys(table)); | ||
475 | std::cerr << "Found " << journeys.size() << " distinct journeys" << std::endl; | ||
476 | DistanceMap distance_map = makeDistanceMap(records, index, journeys); | ||
477 | std::cerr << "Distance map has " << distance_map.size() << " keys" << std::endl; | ||
478 | |||
479 | std::cerr << "Creating augmented table" << std::endl; | ||
480 | ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Table> augmented, augment(table, distance_map)); | ||
481 | |||
482 | std::cerr << "Writing augmented table" << std::endl; | ||
483 | return writeArrowTableAsParquetFile(*augmented, "oeuf-augmented.parquet"); | ||
484 | } | ||
485 | |||
486 | int main(int argc, char *argv[]) { | ||
487 | Kv1Records records; | ||
488 | if (!parse(records)) { | ||
489 | fputs("Error parsing records, exiting\n", stderr); | ||
490 | return EXIT_FAILURE; | ||
491 | } | ||
492 | printParsedRecords(records); | ||
493 | fputs("Indexing...\n", stderr); | ||
494 | Kv1Index index(&records); | ||
495 | fprintf(stderr, "Indexed %lu records\n", index.size()); | ||
496 | // Only notice assignments are not indexed. If this equality is not valid, | ||
497 | // then this means that we had duplicate keys or that something else went | ||
498 | // wrong. That would really not be great. | ||
499 | assert(index.size() == records.size() - records.notice_assignments.size()); | ||
500 | printIndexSize(index); | ||
501 | fputs("Linking records...\n", stderr); | ||
502 | kv1LinkRecords(index); | ||
503 | fputs("Done linking\n", stderr); | ||
504 | |||
505 | arrow::Status st = processTables(records, index); | ||
506 | if (!st.ok()) { | ||
507 | std::cerr << "Failed to process tables: " << st << std::endl; | ||
508 | return EXIT_FAILURE; | ||
509 | } | ||
510 | } | ||