aboutsummaryrefslogtreecommitdiffstats
path: root/src/base32.zig
blob: 149e5f046fc5b97d08fe8cf3f6dfd27a24041785 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
const std = @import("std");

/// Streaming base32 encoder using the alphabet `A-Z2-7`, without `=` padding.
///
/// The input is consumed five bits at a time, most significant bit first; each
/// call to `next` yields one output character. `index` is the byte currently
/// being read (null once the input is exhausted) and `bit_off` is the position
/// of the next unread bit within that byte, counted from the MSB.
pub const Encoder = struct {
    const Self = @This();

    buffer: []const u8,
    index: ?usize,
    bit_off: u3,

    /// Creates an encoder over `buffer`. The buffer is borrowed, not copied.
    pub fn init(buffer: []const u8) Encoder {
        return .{
            .buffer = buffer,
            // An empty input has nothing to read; starting at index 0 would
            // index out of bounds on the first call to `next`.
            .index = if (buffer.len == 0) null else @as(usize, 0),
            .bit_off = 0,
        };
    }

    /// Returns the number of characters needed to encode `source_len` bytes
    /// (unpadded), i.e. ceil(source_len * 8 / 5).
    pub fn calcSize(source_len: usize) usize {
        const source_len_bits = source_len * 8;
        return source_len_bits / 5 + (if (source_len_bits % 5 > 0) @as(usize, 1) else 0);
    }

    /// Encodes `source` into `dest` and returns the written prefix of `dest`.
    /// `dest` must be at least `calcSize(source.len)` bytes long; any space
    /// beyond that is left untouched.
    pub fn encode(dest: []u8, source: []const u8) []const u8 {
        const out_len = calcSize(source.len);
        std.debug.assert(dest.len >= out_len);

        // The encoder yields exactly `out_len` characters, so fill only that
        // prefix. Iterating over all of `dest` would reach `unreachable` as
        // soon as the input ran out.
        const out = dest[0..out_len];
        var e = init(source);
        for (out) |*b| b.* = e.next() orelse unreachable;
        return out;
    }

    /// Number of bits the next 5-bit group takes from the current byte; the
    /// remaining `5 - frontBitCount()` bits come from the following byte.
    ///
    /// bit_off   frontBitCount
    /// 0..3      5
    /// 4         4
    /// 5         3
    /// 6         2
    /// 7         1
    fn frontBitCount(self: *const Self) u3 {
        return if (self.bit_off <= 3) 5 else 7 - self.bit_off + 1;
    }

    /// Returns the bits of byte `index` from `bit_off` onward (at most 5),
    /// placed at the top of the 5-bit group.
    ///
    /// bit_off   bits         shl   shr   front
    /// 0         0b11111000         3     0b11111
    /// 1         0b01111100         2     0b11111
    /// 2         0b00111110         1     0b11111
    /// 3         0b00011111   0     0     0b11111
    /// 4         0b00001111   1           0b11110
    /// 5         0b00000111   2           0b11100
    /// 6         0b00000011   3           0b11000
    /// 7         0b00000001   4           0b10000
    fn front(self: *const Self, index: usize) u5 {
        const bitmask = @as(u8, 0b11111000) >> self.bit_off;
        const bits = self.buffer[index] & bitmask;
        if (self.bit_off >= 4) return @truncate(u5, bits << (self.bit_off - 3));
        return @truncate(u5, bits >> (3 - self.bit_off));
    }

    /// Returns the top `bits` bits of byte `index`, right-aligned, which fill
    /// the low end of a group that straddles a byte boundary.
    fn back(self: *const Self, index: usize, bits: u3) u5 {
        if (bits == 0) return 0;
        return @truncate(u5, self.buffer[index] >> (7 - bits + 1));
    }

    /// Extracts the next 5-bit group, or null once the input is exhausted.
    /// A final group that runs past the end of the input is zero-filled.
    fn nextU5(self: *Self) ?u5 {
        const front_index = self.index orelse return null;
        const num_front_bits = self.frontBitCount();
        const front_bits = self.front(front_index);

        var back_bits: u5 = 0;
        if (self.bit_off >= 3) {
            // The group reaches the end of the current byte: advance to the
            // next byte (bit_off + 5 - 8 == bit_off - 3).
            self.bit_off -= 3;
            const new_index = front_index + 1;
            if (self.buffer.len > new_index) {
                self.index = new_index;
                back_bits = self.back(new_index, 5 - num_front_bits);
            } else {
                self.index = null;
            }
        } else {
            self.bit_off += 5;
        }

        return front_bits | back_bits;
    }

    /// Returns the alphabet character for a 5-bit value: 0..25 -> 'A'..'Z',
    /// 26..31 -> '2'..'7'.
    fn char(unencoded: u5) u8 {
        return unencoded + (if (unencoded < 26) @as(u8, 'A') else '2' - 26);
    }

    /// Returns the next character of the encoded output, or null when done.
    pub fn next(self: *Self) ?u8 {
        const unencoded = self.nextU5() orelse return null;
        return char(unencoded);
    }
};

// TODO(rutgerbrf): simplify the code of the decoder

/// Error returned when the input contains a character outside `A-Z2-7`.
pub const DecodeError = error{CorruptInputError};

/// Incremental base32 decoder for the alphabet `A-Z2-7` (upper case only,
/// no `=` padding). Feed it one character at a time with `read`.
///
/// `buf` holds the `out_off` decoded bits that do not yet form a full byte,
/// aligned to the most significant bit; its remaining low bits are zero.
pub const Decoder = struct {
    const Self = @This();

    out_off: u4 = 0,
    buf: u8 = 0,

    /// Decodes one character. Returns the completed byte if this character
    /// finished one, `null` if more input is needed, or
    /// `error.CorruptInputError` if `c` is not in the alphabet (the decoder
    /// state is not modified in that case).
    pub fn read(self: *Self, c: u8) DecodeError!?u8 {
        const decoded = try decodeChar(c);
        // Build a 16-bit window: the pending bits occupy bits 15..(16 - out_off)
        // and the 5 new bits go directly below them. Since out_off <= 7, the
        // shift is at least 4, so nothing is lost off the bottom.
        const window = (@as(u16, self.buf) << 8) | (@as(u16, decoded) << (11 - self.out_off));
        const total: u4 = self.out_off + 5;
        if (total < 8) {
            self.buf = @truncate(u8, window >> 8);
            self.out_off = total;
            return null;
        }
        // A full byte is ready in the high half; leftover bits (if any) are
        // already MSB-aligned in the low half.
        self.buf = @truncate(u8, window);
        self.out_off = total - 8;
        return @truncate(u8, window >> 8);
    }

    /// Maps a character to its 5-bit value: 'A'..'Z' -> 0..25 and
    /// '2'..'7' -> 26..31. Every other byte is rejected.
    fn decodeChar(p: u8) DecodeError!u5 {
        return switch (p) {
            'A'...'Z' => @truncate(u5, p - 'A'),
            // Only '2'..'7' belong to the alphabet; '8' and '9' would map to
            // 32 and 33 and silently wrap to 0 and 1.
            '2'...'7' => @truncate(u5, p - '2' + 26),
            else => error.CorruptInputError,
        };
    }
};

/// Returns how many whole bytes `enc_len` unpadded base32 characters decode
/// to. Each character carries 5 bits; a trailing run of fewer than 8 bits is
/// not counted.
pub fn decodedLen(enc_len: usize) usize {
    const bits_per_char = 5;
    const bits_per_byte = 8;
    return @divFloor(enc_len * bits_per_char, bits_per_byte);
}

/// Decodes unpadded base32 text `ps` into `out` and returns the number of
/// bytes written.
///
/// Decoding stops as soon as `out` is full, so pass a buffer of at least
/// `decodedLen(ps.len)` bytes to get the complete result; characters after
/// that point are not examined. Returns `error.CorruptInputError` if a
/// character outside `A-Z2-7` is read.
pub fn decode(ps: []const u8, out: []u8) DecodeError!usize {
    var d = Decoder{};
    var written: usize = 0;
    for (ps) |p| {
        if (written >= out.len) break;
        if (try d.read(p)) |b| {
            out[written] = b;
            written += 1;
        }
    }
    // Bits still pending in `d` (always fewer than 8) are padding from the
    // final character, not data. They are deliberately not flushed as an
    // extra byte, which keeps the result consistent with `decodedLen`.
    return written;
}