Skip to main content

uor_addr/onnx/protobuf/
varint.rs

1//! Base-128 varint decoding (protobuf v3 wire format §"Base 128
2//! Varints"): each byte's high bit is a continuation flag; the low 7
3//! bits are data, least-significant group first.
4
5use super::wire::WireError;
6
/// Longest possible encoding of a 64-bit varint, in bytes.
///
/// Each encoded byte contributes 7 payload bits, so covering all 64 bits of
/// a `u64` takes `ceil(64 / 7) = 10` bytes; the tenth byte carries only the
/// single topmost bit.
pub const VARINT_MAX_BYTES: usize = (64 + 6) / 7;
9
10/// Read a varint starting at `buf[pos]`. Returns `(value, new_pos)`.
11///
12/// # Errors
13///
14/// [`WireError::Truncated`] if the buffer ends mid-varint;
15/// [`WireError::VarintOverflow`] if the encoding exceeds 10 bytes.
16pub fn read_varint(buf: &[u8], pos: usize) -> Result<(u64, usize), WireError> {
17    let mut result: u64 = 0;
18    let mut shift: u32 = 0;
19    let mut i = pos;
20    loop {
21        if i >= buf.len() {
22            return Err(WireError::Truncated);
23        }
24        if i - pos >= VARINT_MAX_BYTES {
25            return Err(WireError::VarintOverflow);
26        }
27        let byte = buf[i];
28        result |= u64::from(byte & 0x7F) << shift;
29        i += 1;
30        if byte & 0x80 == 0 {
31            return Ok((result, i));
32        }
33        shift += 7;
34    }
35}
36
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn single_byte() {
        assert_eq!(read_varint(&[0x00], 0).unwrap(), (0, 1));
        assert_eq!(read_varint(&[0x01], 0).unwrap(), (1, 1));
        assert_eq!(read_varint(&[0x7F], 0).unwrap(), (127, 1));
    }

    #[test]
    fn multi_byte() {
        // 150 = 0x96 0x01
        assert_eq!(read_varint(&[0x96, 0x01], 0).unwrap(), (150, 2));
        // 300 = 0xAC 0x02
        assert_eq!(read_varint(&[0xAC, 0x02], 0).unwrap(), (300, 2));
    }

    #[test]
    fn nonzero_start_returns_absolute_position() {
        // The leading 0xFF is not part of the varint; decoding starts at
        // index 1 and `new_pos` is an index into the whole buffer.
        assert_eq!(read_varint(&[0xFF, 0x96, 0x01], 1).unwrap(), (150, 3));
    }

    #[test]
    fn max_u64_uses_all_ten_bytes() {
        // Nine 0xFF bytes supply bits 0..=62; a final 0x01 supplies bit 63.
        let mut enc = [0xFF_u8; VARINT_MAX_BYTES];
        enc[VARINT_MAX_BYTES - 1] = 0x01;
        assert_eq!(
            read_varint(&enc, 0).unwrap(),
            (u64::MAX, VARINT_MAX_BYTES)
        );
    }

    #[test]
    fn truncated_is_error() {
        assert_eq!(read_varint(&[0x80], 0), Err(WireError::Truncated));
        assert_eq!(read_varint(&[], 0), Err(WireError::Truncated));
        // `pos` already at the end of the buffer.
        assert_eq!(read_varint(&[0x01], 1), Err(WireError::Truncated));
    }

    #[test]
    fn overflow_is_error() {
        assert_eq!(read_varint(&[0x80; 11], 0), Err(WireError::VarintOverflow));
    }
}