aboutsummaryrefslogtreecommitdiffstats
path: root/src/augmentkv6
diff options
context:
space:
mode:
authorLibravatar Rutger Broekhoff2024-05-02 20:27:40 +0200
committerLibravatar Rutger Broekhoff2024-05-02 20:27:40 +0200
commit17a3ea880402338420699e03bcb24181e4ff3924 (patch)
treeda666ef91e0b60d20aa0b01529644c136fd1f4ab /src/augmentkv6
downloadoeuf-17a3ea880402338420699e03bcb24181e4ff3924.tar.gz
oeuf-17a3ea880402338420699e03bcb24181e4ff3924.zip
Initial commit
Based on dc4ba6a
Diffstat (limited to 'src/augmentkv6')
-rw-r--r--src/augmentkv6/.envrc2
-rw-r--r--src/augmentkv6/Makefile21
-rw-r--r--src/augmentkv6/main.cpp510
3 files changed, 533 insertions, 0 deletions
diff --git a/src/augmentkv6/.envrc b/src/augmentkv6/.envrc
new file mode 100644
index 0000000..694e74f
--- /dev/null
+++ b/src/augmentkv6/.envrc
@@ -0,0 +1,2 @@
1source_env ../../
2export DEVMODE=1
diff --git a/src/augmentkv6/Makefile b/src/augmentkv6/Makefile
new file mode 100644
index 0000000..cebb291
--- /dev/null
+++ b/src/augmentkv6/Makefile
@@ -0,0 +1,21 @@
1# Taken from:
2# Open Source Security Foundation (OpenSSF), “Compiler Options Hardening Guide
3# for C and C++,” OpenSSF Best Practices Working Group. Accessed: Dec. 01,
4# 2023. [Online]. Available:
5# https://best.openssf.org/Compiler-Hardening-Guides/Compiler-Options-Hardening-Guide-for-C-and-C++.html
6CXXFLAGS=-std=c++2b -g -fno-omit-frame-pointer $(if $(DEVMODE),-Werror,)\
7 -O2 -Wall -Wformat=2 -Wconversion -Wtrampolines -Wimplicit-fallthrough \
8 -U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=3 \
9 -D_GLIBCXX_ASSERTIONS \
10 -fstrict-flex-arrays=3 \
11 -fstack-clash-protection -fstack-protector-strong
12LDFLAGS=-larrow -larrow_acero -larrow_dataset -lparquet -ltmi8 -Wl,-z,defs \
13 -Wl,-z,nodlopen -Wl,-z,noexecstack \
14 -Wl,-z,relro -Wl,-z,now
15
16augmentkv6: main.cpp
17 $(CXX) -fPIE -pie -o $@ $^ $(CXXFLAGS) $(LDFLAGS)
18
19.PHONY: clean
20clean:
21 rm augmentkv6
diff --git a/src/augmentkv6/main.cpp b/src/augmentkv6/main.cpp
new file mode 100644
index 0000000..81a54d3
--- /dev/null
+++ b/src/augmentkv6/main.cpp
@@ -0,0 +1,510 @@
1// vim:set sw=2 ts=2 sts et:
2
3#include <chrono>
4#include <cstdio>
5#include <deque>
6#include <filesystem>
7#include <format>
8#include <fstream>
9#include <iostream>
10#include <string>
11#include <string_view>
12#include <vector>
13
14#include <arrow/acero/exec_plan.h>
15#include <arrow/api.h>
16#include <arrow/compute/api.h>
17#include <arrow/dataset/api.h>
18#include <arrow/filesystem/api.h>
19#include <arrow/io/api.h>
20#include <parquet/arrow/reader.h>
21
22#include <tmi8/kv1_index.hpp>
23#include <tmi8/kv1_lexer.hpp>
24#include <tmi8/kv1_parser.hpp>
25#include <tmi8/kv1_types.hpp>
26#include <tmi8/kv6_parquet.hpp>
27
28using namespace std::string_view_literals;
29
30namespace ac = arrow::acero;
31namespace ds = arrow::dataset;
32namespace cp = arrow::compute;
33using namespace arrow;
34
35using TimingClock = std::conditional_t<
36 std::chrono::high_resolution_clock::is_steady,
37 std::chrono::high_resolution_clock,
38 std::chrono::steady_clock>;
39
40std::string readKv1() {
41 fputs("Reading KV1 from standard input\n", stderr);
42
43 char buf[4096];
44 std::string data;
45 while (!feof(stdin) && !ferror(stdin)) {
46 size_t read = fread(buf, sizeof(char), 4096, stdin);
47 data.append(buf, read);
48 }
49 if (ferror(stdin)) {
50 fputs("Error when reading from stdin\n", stderr);
51 exit(1);
52 }
53 fprintf(stderr, "Read %lu bytes\n", data.size());
54
55 return data;
56}
57
58std::vector<Kv1Token> lex() {
59 std::string data = readKv1();
60
61 auto start = TimingClock::now();
62 Kv1Lexer lexer(data);
63 lexer.lex();
64 auto end = TimingClock::now();
65
66 std::chrono::duration<double> elapsed{end - start};
67 double bytes = static_cast<double>(data.size()) / 1'000'000;
68 double speed = bytes / elapsed.count();
69
70 if (!lexer.errors.empty()) {
71 fputs("Lexer reported errors:\n", stderr);
72 for (const auto &error : lexer.errors)
73 fprintf(stderr, "- %s\n", error.c_str());
74 exit(1);
75 }
76
77 fprintf(stderr, "Got %lu tokens\n", lexer.tokens.size());
78 fprintf(stderr, "Duration: %f s\n", elapsed.count());
79 fprintf(stderr, "Speed: %f MB/s\n", speed);
80
81 return std::move(lexer.tokens);
82}
83
84bool parse(Kv1Records &into) {
85 std::vector<Kv1Token> tokens = lex();
86
87 Kv1Parser parser(tokens, into);
88 parser.parse();
89
90 bool ok = true;
91 if (!parser.gerrors.empty()) {
92 ok = false;
93 fputs("Parser reported errors:\n", stderr);
94 for (const auto &error : parser.gerrors)
95 fprintf(stderr, "- %s\n", error.c_str());
96 }
97 if (!parser.warns.empty()) {
98 fputs("Parser reported warnings:\n", stderr);
99 for (const auto &warn : parser.warns)
100 fprintf(stderr, "- %s\n", warn.c_str());
101 }
102
103 fprintf(stderr, "Parsed %lu records\n", into.size());
104
105 return ok;
106}
107
108void printParsedRecords(const Kv1Records &records) {
109 fputs("Parsed records:\n", stderr);
110 fprintf(stderr, " organizational_units: %lu\n", records.organizational_units.size());
111 fprintf(stderr, " higher_organizational_units: %lu\n", records.higher_organizational_units.size());
112 fprintf(stderr, " user_stop_points: %lu\n", records.user_stop_points.size());
113 fprintf(stderr, " user_stop_areas: %lu\n", records.user_stop_areas.size());
114 fprintf(stderr, " timing_links: %lu\n", records.timing_links.size());
115 fprintf(stderr, " links: %lu\n", records.links.size());
116 fprintf(stderr, " lines: %lu\n", records.lines.size());
117 fprintf(stderr, " destinations: %lu\n", records.destinations.size());
118 fprintf(stderr, " journey_patterns: %lu\n", records.journey_patterns.size());
119 fprintf(stderr, " concession_financer_relations: %lu\n", records.concession_financer_relations.size());
120 fprintf(stderr, " concession_areas: %lu\n", records.concession_areas.size());
121 fprintf(stderr, " financers: %lu\n", records.financers.size());
122 fprintf(stderr, " journey_pattern_timing_links: %lu\n", records.journey_pattern_timing_links.size());
123 fprintf(stderr, " points: %lu\n", records.points.size());
124 fprintf(stderr, " point_on_links: %lu\n", records.point_on_links.size());
125 fprintf(stderr, " icons: %lu\n", records.icons.size());
126 fprintf(stderr, " notices: %lu\n", records.notices.size());
127 fprintf(stderr, " notice_assignments: %lu\n", records.notice_assignments.size());
128 fprintf(stderr, " time_demand_groups: %lu\n", records.time_demand_groups.size());
129 fprintf(stderr, " time_demand_group_run_times: %lu\n", records.time_demand_group_run_times.size());
130 fprintf(stderr, " period_groups: %lu\n", records.period_groups.size());
131 fprintf(stderr, " specific_days: %lu\n", records.specific_days.size());
132 fprintf(stderr, " timetable_versions: %lu\n", records.timetable_versions.size());
133 fprintf(stderr, " public_journeys: %lu\n", records.public_journeys.size());
134 fprintf(stderr, " period_group_validities: %lu\n", records.period_group_validities.size());
135 fprintf(stderr, " exceptional_operating_days: %lu\n", records.exceptional_operating_days.size());
136 fprintf(stderr, " schedule_versions: %lu\n", records.schedule_versions.size());
137 fprintf(stderr, " public_journey_passing_times: %lu\n", records.public_journey_passing_times.size());
138 fprintf(stderr, " operating_days: %lu\n", records.operating_days.size());
139}
140
141void printIndexSize(const Kv1Index &index) {
142 fputs("Index size:\n", stderr);
143 fprintf(stderr, " organizational_units: %lu\n", index.organizational_units.size());
144 fprintf(stderr, " user_stop_points: %lu\n", index.user_stop_points.size());
145 fprintf(stderr, " user_stop_areas: %lu\n", index.user_stop_areas.size());
146 fprintf(stderr, " timing_links: %lu\n", index.timing_links.size());
147 fprintf(stderr, " links: %lu\n", index.links.size());
148 fprintf(stderr, " lines: %lu\n", index.lines.size());
149 fprintf(stderr, " destinations: %lu\n", index.destinations.size());
150 fprintf(stderr, " journey_patterns: %lu\n", index.journey_patterns.size());
151 fprintf(stderr, " concession_financer_relations: %lu\n", index.concession_financer_relations.size());
152 fprintf(stderr, " concession_areas: %lu\n", index.concession_areas.size());
153 fprintf(stderr, " financers: %lu\n", index.financers.size());
154 fprintf(stderr, " journey_pattern_timing_links: %lu\n", index.journey_pattern_timing_links.size());
155 fprintf(stderr, " points: %lu\n", index.points.size());
156 fprintf(stderr, " point_on_links: %lu\n", index.point_on_links.size());
157 fprintf(stderr, " icons: %lu\n", index.icons.size());
158 fprintf(stderr, " notices: %lu\n", index.notices.size());
159 fprintf(stderr, " time_demand_groups: %lu\n", index.time_demand_groups.size());
160 fprintf(stderr, " time_demand_group_run_times: %lu\n", index.time_demand_group_run_times.size());
161 fprintf(stderr, " period_groups: %lu\n", index.period_groups.size());
162 fprintf(stderr, " specific_days: %lu\n", index.specific_days.size());
163 fprintf(stderr, " timetable_versions: %lu\n", index.timetable_versions.size());
164 fprintf(stderr, " public_journeys: %lu\n", index.public_journeys.size());
165 fprintf(stderr, " period_group_validities: %lu\n", index.period_group_validities.size());
166 fprintf(stderr, " exceptional_operating_days: %lu\n", index.exceptional_operating_days.size());
167 fprintf(stderr, " schedule_versions: %lu\n", index.schedule_versions.size());
168 fprintf(stderr, " public_journey_passing_times: %lu\n", index.public_journey_passing_times.size());
169 fprintf(stderr, " operating_days: %lu\n", index.operating_days.size());
170}
171
172struct BasicJourneyKey {
173 std::string data_owner_code;
174 std::string line_planning_number;
175 int journey_number;
176
177 auto operator<=>(const BasicJourneyKey &) const = default;
178};
179
180size_t hash_value(const BasicJourneyKey &k) {
181 size_t seed = 0;
182
183 boost::hash_combine(seed, k.data_owner_code);
184 boost::hash_combine(seed, k.line_planning_number);
185 boost::hash_combine(seed, k.journey_number);
186
187 return seed;
188}
189
190using BasicJourneyKeySet = std::unordered_set<BasicJourneyKey, boost::hash<BasicJourneyKey>>;
191
192arrow::Result<BasicJourneyKeySet> basicJourneys(std::shared_ptr<arrow::Table> table) {
193 ac::TableSourceNodeOptions table_source_node_options(table);
194 ac::Declaration table_source("table_source", std::move(table_source_node_options));
195 auto aggregate_options = ac::AggregateNodeOptions{
196 /* .aggregates = */ {},
197 /* .keys = */ { "data_owner_code", "line_planning_number", "journey_number" },
198 };
199 ac::Declaration aggregate("aggregate", { std::move(table_source) }, std::move(aggregate_options));
200
201 std::shared_ptr<arrow::Table> result;
202 ARROW_ASSIGN_OR_RAISE(result, ac::DeclarationToTable(std::move(aggregate)));
203
204 std::shared_ptr<arrow::ChunkedArray> data_owner_codes = result->GetColumnByName("data_owner_code");
205 std::shared_ptr<arrow::ChunkedArray> line_planning_numbers = result->GetColumnByName("line_planning_number");
206 std::shared_ptr<arrow::ChunkedArray> journey_numbers = result->GetColumnByName("journey_number");
207
208 int i_data_owner_codes_chunk = 0;
209 int i_journey_numbers_chunk = 0;
210 int i_line_planning_numbers_chunk = 0;
211 int i_in_data_owner_codes_chunk = 0;
212 int i_in_journey_numbers_chunk = 0;
213 int i_in_line_planning_numbers_chunk = 0;
214
215 BasicJourneyKeySet journeys;
216
217 for (int64_t i = 0; i < result->num_rows(); i++) {
218 auto data_owner_codes_chunk = std::static_pointer_cast<arrow::StringArray>(data_owner_codes->chunk(i_data_owner_codes_chunk));
219 auto line_planning_numbers_chunk = std::static_pointer_cast<arrow::StringArray>(line_planning_numbers->chunk(i_line_planning_numbers_chunk));
220 auto journey_numbers_chunk = std::static_pointer_cast<arrow::UInt32Array>(journey_numbers->chunk(i_journey_numbers_chunk));
221
222 std::string_view data_owner_code = data_owner_codes_chunk->Value(i_in_data_owner_codes_chunk);
223 std::string_view line_planning_number = line_planning_numbers_chunk->Value(i_in_line_planning_numbers_chunk);
224 uint32_t journey_number = journey_numbers_chunk->Value(i_in_journey_numbers_chunk);
225
226 journeys.emplace(
227 std::string(data_owner_code),
228 std::string(line_planning_number),
229 journey_number
230 );
231
232 i_in_data_owner_codes_chunk++;
233 i_in_line_planning_numbers_chunk++;
234 i_in_journey_numbers_chunk++;
235 if (i_in_data_owner_codes_chunk >= data_owner_codes_chunk->length()) {
236 i_data_owner_codes_chunk++;
237 i_in_data_owner_codes_chunk = 0;
238 }
239 if (i_in_line_planning_numbers_chunk >= line_planning_numbers_chunk->length()) {
240 i_line_planning_numbers_chunk++;
241 i_in_line_planning_numbers_chunk = 0;
242 }
243 if (i_in_journey_numbers_chunk >= journey_numbers_chunk->length()) {
244 i_journey_numbers_chunk++;
245 i_in_journey_numbers_chunk = 0;
246 }
247 }
248
249 return journeys;
250}
251
252struct DistanceKey {
253 BasicJourneyKey journey;
254 std::string last_passed_user_stop_code;
255
256 auto operator<=>(const DistanceKey &) const = default;
257};
258
259size_t hash_value(const DistanceKey &k) {
260 size_t seed = 0;
261
262 boost::hash_combine(seed, k.journey);
263 boost::hash_combine(seed, k.last_passed_user_stop_code);
264
265 return seed;
266}
267
268struct DistanceTimingLink {
269 const Kv1JourneyPatternTimingLink *jopatili;
270 double distance_since_start_of_journey = 0; // at the start of the link
271};
272
273using DistanceMap = std::unordered_map<DistanceKey, double, boost::hash<DistanceKey>>;
274
275// Returns a map, where
276// DataOwnerCode + LinePlanningNumber + JourneyNumber + UserStopCode ->
277// Distance of Last User Stop
278DistanceMap makeDistanceMap(Kv1Records &records, Kv1Index &index, BasicJourneyKeySet &journeys) {
279 std::unordered_map<
280 Kv1JourneyPattern::Key,
281 std::vector<DistanceTimingLink>,
282 boost::hash<Kv1JourneyPattern::Key>> jopatili_index;
283 std::unordered_map<
284 BasicJourneyKey,
285 const Kv1PublicJourney *,
286 boost::hash<BasicJourneyKey>> journey_index;
287 for (size_t i = 0; i < records.public_journeys.size(); i++) {
288 const Kv1PublicJourney *pujo = &records.public_journeys[i];
289
290 BasicJourneyKey journey_key(
291 pujo->key.data_owner_code,
292 pujo->key.line_planning_number,
293 pujo->key.journey_number);
294
295 if (journeys.contains(journey_key)) {
296 journey_index[journey_key] = pujo;
297
298 Kv1JourneyPattern::Key jopa_key(
299 pujo->key.data_owner_code,
300 pujo->key.line_planning_number,
301 pujo->journey_pattern_code);
302 jopatili_index[jopa_key] = {};
303 }
304 }
305
306 for (size_t i = 0; i < records.journey_pattern_timing_links.size(); i++) {
307 const Kv1JourneyPatternTimingLink *jopatili = &records.journey_pattern_timing_links[i];
308 Kv1JourneyPattern::Key jopa_key(
309 jopatili->key.data_owner_code,
310 jopatili->key.line_planning_number,
311 jopatili->key.journey_pattern_code);
312 if (jopatili_index.contains(jopa_key)) {
313 jopatili_index[jopa_key].push_back(DistanceTimingLink(jopatili, 0));
314 }
315 }
316
317 for (auto &[jopa_key, timing_links] : jopatili_index) {
318 std::sort(timing_links.begin(), timing_links.end(), [](auto a, auto b) {
319 return a.jopatili->key.timing_link_order < b.jopatili->key.timing_link_order;
320 });
321
322 const std::string transport_type = index.journey_patterns[jopa_key]->p_line->transport_type;
323
324 for (size_t i = 1; i < timing_links.size(); i++) {
325 DistanceTimingLink *timing_link = &timing_links[i];
326 DistanceTimingLink *prev_timing_link = &timing_links[i - 1];
327
328 const Kv1Link::Key link_key(
329 prev_timing_link->jopatili->key.data_owner_code,
330 prev_timing_link->jopatili->user_stop_code_begin,
331 prev_timing_link->jopatili->user_stop_code_end,
332 transport_type);
333 double link_distance = index.links[link_key]->distance;
334 timing_link->distance_since_start_of_journey =
335 prev_timing_link->distance_since_start_of_journey + link_distance;
336 }
337 }
338
339 // DataOwnerCode + LinePlanningNumber + JourneyNumber + UserStopCode ->
340 // Distance of Last User Stop
341 DistanceMap distance_map;
342
343 for (const auto &journey : journeys) {
344 const Kv1PublicJourney *pujo = journey_index[journey];
345 if (pujo == nullptr) {
346 std::cerr << "Warning: No PUJO found for [" << journey.data_owner_code << "] "
347 << journey.line_planning_number << "/" << journey.journey_number << std::endl;
348 continue;
349 }
350 Kv1JourneyPattern::Key jopa_key(
351 pujo->key.data_owner_code,
352 pujo->key.line_planning_number,
353 pujo->journey_pattern_code);
354 for (const auto &timing_link : jopatili_index[jopa_key]) {
355 DistanceKey key(journey, timing_link.jopatili->user_stop_code_begin);
356 distance_map[key] = timing_link.distance_since_start_of_journey;
357 }
358 }
359
360 return distance_map;
361}
362
363arrow::Result<std::shared_ptr<arrow::Table>> augment(
364 std::shared_ptr<arrow::Table> table,
365 const DistanceMap &distance_map
366) {
367 for (int i = 0; i < table->num_columns(); i++) {
368 if (table->column(i)->num_chunks() > 1) {
369 std::stringstream ss;
370 ss << "Error: Expected column " << i
371 << " (" << table->ColumnNames()[i] << ") to have 1 chunk, got "
372 << table->column(i)->num_chunks();
373 return arrow::Status::Invalid(ss.str());
374 }
375 }
376
377 auto data_owner_codes = std::static_pointer_cast<arrow::StringArray>(table->GetColumnByName("data_owner_code")->chunk(0));
378 auto line_planning_numbers = std::static_pointer_cast<arrow::StringArray>(table->GetColumnByName("line_planning_number")->chunk(0));
379 auto journey_numbers = std::static_pointer_cast<arrow::UInt32Array>(table->GetColumnByName("journey_number")->chunk(0));
380 auto user_stop_codes = std::static_pointer_cast<arrow::StringArray>(table->GetColumnByName("user_stop_code")->chunk(0));
381 auto distance_since_last_user_stops = std::static_pointer_cast<arrow::UInt32Array>(table->GetColumnByName("distance_since_last_user_stop")->chunk(0));
382 auto timestamps = std::static_pointer_cast<arrow::TimestampArray>(table->GetColumnByName("timestamp")->chunk(0));
383
384 auto timestamps_type = table->schema()->GetFieldByName("timestamp")->type();
385 if (timestamps_type->id() != arrow::Type::TIMESTAMP)
386 return arrow::Status::Invalid("Field 'timestamp' does not have expected type TIMESTAMP");
387 if (std::static_pointer_cast<arrow::TimestampType>(timestamps_type)->unit() != arrow::TimeUnit::MILLI)
388 return arrow::Status::Invalid("Field 'timestamp' does not have unit MILLI");
389 if (!std::static_pointer_cast<arrow::TimestampType>(timestamps_type)->timezone().empty())
390 return arrow::Status::Invalid("Field 'timestamp' should have empty time zone name");
391
392 std::shared_ptr<arrow::Field> field_distance_since_start_of_journey =
393 arrow::field("distance_since_start_of_journey", arrow::uint32());
394 std::shared_ptr<arrow::Field> field_day_of_week =
395 arrow::field("timestamp_iso_day_of_week", arrow::int64());
396 std::shared_ptr<arrow::Field> field_date =
397 arrow::field("timestamp_date", arrow::date32());
398 std::shared_ptr<arrow::Field> field_local_time =
399 arrow::field("timestamp_local_time", arrow::time32(arrow::TimeUnit::SECOND));
400 arrow::UInt32Builder distance_since_start_of_journey_builder;
401 arrow::Int64Builder day_of_week_builder;
402 arrow::Date32Builder date_builder;
403 arrow::Time32Builder local_time_builder(arrow::time32(arrow::TimeUnit::SECOND), arrow::default_memory_pool());
404
405 const std::chrono::time_zone *amsterdam = std::chrono::locate_zone("Europe/Amsterdam");
406
407 for (int64_t i = 0; i < table->num_rows(); i++) {
408 DistanceKey key(
409 BasicJourneyKey(
410 std::string(data_owner_codes->Value(i)),
411 std::string(line_planning_numbers->Value(i)),
412 journey_numbers->Value(i)),
413 std::string(user_stop_codes->Value(i)));
414
415 uint32_t distance_since_last_user_stop = distance_since_last_user_stops->Value(i);
416 if (distance_map.contains(key)) {
417 uint32_t total_distance = distance_since_last_user_stop + static_cast<uint32_t>(distance_map.at(key));
418 ARROW_RETURN_NOT_OK(distance_since_start_of_journey_builder.Append(total_distance));
419 } else {
420 ARROW_RETURN_NOT_OK(distance_since_start_of_journey_builder.AppendNull());
421 }
422
423 // Welp, this has gotten a bit complicated!
424 std::chrono::sys_seconds timestamp(std::chrono::floor<std::chrono::seconds>(std::chrono::milliseconds(timestamps->Value(i))));
425 std::chrono::zoned_seconds zoned_timestamp(amsterdam, timestamp);
426 std::chrono::local_seconds local_timestamp(zoned_timestamp);
427 std::chrono::local_days local_date = std::chrono::floor<std::chrono::days>(local_timestamp);
428 std::chrono::year_month_day date(local_date);
429 std::chrono::weekday day_of_week(local_date);
430 std::chrono::hh_mm_ss<std::chrono::seconds> time(local_timestamp - local_date);
431 std::chrono::sys_days unix_date(date);
432
433 int64_t iso_day_of_week = day_of_week.iso_encoding();
434 int32_t unix_days = static_cast<int32_t>(unix_date.time_since_epoch().count());
435 int32_t secs_since_midnight = static_cast<int32_t>(std::chrono::seconds(time).count());
436
437 ARROW_RETURN_NOT_OK(day_of_week_builder.Append(iso_day_of_week));
438 ARROW_RETURN_NOT_OK(date_builder.Append(unix_days));
439 ARROW_RETURN_NOT_OK(local_time_builder.Append(secs_since_midnight));
440 }
441
442 ARROW_ASSIGN_OR_RAISE(auto distance_since_start_of_journey_col_chunk, distance_since_start_of_journey_builder.Finish());
443 ARROW_ASSIGN_OR_RAISE(auto day_of_week_col_chunk, day_of_week_builder.Finish());
444 ARROW_ASSIGN_OR_RAISE(auto date_col_chunk, date_builder.Finish());
445 ARROW_ASSIGN_OR_RAISE(auto local_time_col_chunk, local_time_builder.Finish());
446 auto distance_since_start_of_journey_col =
447 std::make_shared<arrow::ChunkedArray>(distance_since_start_of_journey_col_chunk);
448 auto day_of_week_col = std::make_shared<arrow::ChunkedArray>(day_of_week_col_chunk);
449 auto date_col = std::make_shared<arrow::ChunkedArray>(date_col_chunk);
450 auto local_time_col = std::make_shared<arrow::ChunkedArray>(local_time_col_chunk);
451
452 ARROW_ASSIGN_OR_RAISE(table, table->AddColumn(
453 table->num_columns(),
454 field_distance_since_start_of_journey,
455 distance_since_start_of_journey_col));
456 ARROW_ASSIGN_OR_RAISE(table, table->AddColumn(table->num_columns(), field_day_of_week, day_of_week_col));
457 ARROW_ASSIGN_OR_RAISE(table, table->AddColumn(table->num_columns(), field_date, date_col));
458 ARROW_ASSIGN_OR_RAISE(table, table->AddColumn(table->num_columns(), field_local_time, local_time_col));
459
460 return table;
461}
462
463arrow::Status processTables(Kv1Records &records, Kv1Index &index) {
464 std::shared_ptr<arrow::io::RandomAccessFile> input;
465 ARROW_ASSIGN_OR_RAISE(input, arrow::io::ReadableFile::Open("oeuf-input.parquet"));
466
467 std::unique_ptr<parquet::arrow::FileReader> arrow_reader;
468 ARROW_RETURN_NOT_OK(parquet::arrow::OpenFile(input, arrow::default_memory_pool(), &arrow_reader));
469
470 std::shared_ptr<arrow::Table> table;
471 ARROW_RETURN_NOT_OK(arrow_reader->ReadTable(&table));
472
473 std::cerr << "Input KV6 file has " << table->num_rows() << " rows" << std::endl;
474 ARROW_ASSIGN_OR_RAISE(BasicJourneyKeySet journeys, basicJourneys(table));
475 std::cerr << "Found " << journeys.size() << " distinct journeys" << std::endl;
476 DistanceMap distance_map = makeDistanceMap(records, index, journeys);
477 std::cerr << "Distance map has " << distance_map.size() << " keys" << std::endl;
478
479 std::cerr << "Creating augmented table" << std::endl;
480 ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Table> augmented, augment(table, distance_map));
481
482 std::cerr << "Writing augmented table" << std::endl;
483 return writeArrowTableAsParquetFile(*augmented, "oeuf-augmented.parquet");
484}
485
486int main(int argc, char *argv[]) {
487 Kv1Records records;
488 if (!parse(records)) {
489 fputs("Error parsing records, exiting\n", stderr);
490 return EXIT_FAILURE;
491 }
492 printParsedRecords(records);
493 fputs("Indexing...\n", stderr);
494 Kv1Index index(&records);
495 fprintf(stderr, "Indexed %lu records\n", index.size());
496 // Only notice assignments are not indexed. If this equality is not valid,
497 // then this means that we had duplicate keys or that something else went
498 // wrong. That would really not be great.
499 assert(index.size() == records.size() - records.notice_assignments.size());
500 printIndexSize(index);
501 fputs("Linking records...\n", stderr);
502 kv1LinkRecords(index);
503 fputs("Done linking\n", stderr);
504
505 arrow::Status st = processTables(records, index);
506 if (!st.ok()) {
507 std::cerr << "Failed to process tables: " << st << std::endl;
508 return EXIT_FAILURE;
509 }
510}