Skip to main content

uor_addr/onnx/protobuf/
wire.rs

1//! Field-level protobuf reader: iterates the fields of a message,
2//! yielding borrowed [`Field`] views. No allocation; the reader holds a
3//! `&[u8]` and a cursor.
4
5use super::tag::{Tag, WireType};
6use super::varint::read_varint;
7
/// Wire-format decode errors.
///
/// Within this file, [`MessageReader`] constructs only
/// [`WireError::Truncated`] and [`WireError::LengthOutOfRange`] itself.
/// The remaining variants presumably originate from `read_varint` and
/// `Tag::from_varint` in sibling modules and reach callers through `?`
/// (NOTE(review): confirm against those modules).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireError {
    /// Buffer ended mid-field.
    ///
    /// [`MessageReader`] reports this when a fixed-width value or a
    /// length-delimited payload would extend past the end of the buffer.
    Truncated,
    /// A varint exceeded 10 bytes.
    VarintOverflow,
    /// Wire type bits `3`/`4` (deprecated groups) or anything else
    /// unrecognized.
    ///
    /// NOTE(review): the payload is presumably the raw wire-type bits;
    /// confirm against `Tag::from_varint`.
    UnknownWireType(u8),
    /// Field number `0` is illegal.
    ZeroFieldNumber,
    /// A length-delimited field declared a length that cannot be
    /// addressed.
    ///
    /// [`MessageReader`] reports this when the declared length does not
    /// fit in `usize`, or when adding it to the cursor position would
    /// overflow. A length that is representable but merely runs past the
    /// end of the buffer is reported as [`WireError::Truncated`] instead.
    LengthOutOfRange,
}
24
/// A decoded field value (borrowed for length-delimited payloads).
///
/// The variant is selected by the wire type found in the field's tag; no
/// schema-level interpretation (signedness, zigzag, UTF-8, nesting) is
/// applied at this layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FieldValue<'a> {
    /// Wire type 0.
    ///
    /// Holds the varint exactly as returned by `read_varint`.
    Varint(u64),
    /// Wire type 1 (little-endian 64-bit).
    ///
    /// Assembled from the next 8 bytes with `u64::from_le_bytes`.
    Fixed64(u64),
    /// Wire type 5 (little-endian 32-bit).
    ///
    /// Assembled from the next 4 bytes with `u32::from_le_bytes`.
    Fixed32(u32),
    /// Wire type 2 — string / bytes / embedded message / packed scalars.
    ///
    /// A sub-slice of the reader's input buffer covering the payload only
    /// (the length prefix is excluded); nothing is copied.
    Bytes(&'a [u8]),
}
37
/// A single decoded field: its number plus value.
///
/// Produced by [`MessageReader::next_field`]. The lifetime `'a` ties any
/// borrowed payload in `value` to the buffer the reader was created over.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Field<'a> {
    /// Field number, taken from the decoded tag.
    pub number: u64,
    /// Field value, shaped by the tag's wire type.
    pub value: FieldValue<'a>,
}
46
/// Iterates the top-level fields of a protobuf message.
///
/// The reader borrows the message bytes and keeps a cursor into them.
/// Length-delimited payloads are surfaced as raw byte slices; the reader
/// does not descend into them.
pub struct MessageReader<'a> {
    /// The message body being decoded.
    buf: &'a [u8],
    /// Byte offset into `buf` of the next unread byte.
    pos: usize,
}
52
53impl<'a> MessageReader<'a> {
54    /// Wrap a message body.
55    #[must_use]
56    pub fn new(buf: &'a [u8]) -> Self {
57        Self { buf, pos: 0 }
58    }
59
60    /// Read the next field, or `None` at end of message.
61    ///
62    /// # Errors
63    ///
64    /// Any [`WireError`] from a malformed encoding.
65    pub fn next_field(&mut self) -> Result<Option<Field<'a>>, WireError> {
66        if self.pos >= self.buf.len() {
67            return Ok(None);
68        }
69        let (tag_v, p) = read_varint(self.buf, self.pos)?;
70        self.pos = p;
71        let tag = Tag::from_varint(tag_v)?;
72        let value = match tag.wire_type {
73            WireType::Varint => {
74                let (v, p) = read_varint(self.buf, self.pos)?;
75                self.pos = p;
76                FieldValue::Varint(v)
77            }
78            WireType::Fixed64 => {
79                let b = self.take(8)?;
80                FieldValue::Fixed64(u64::from_le_bytes([
81                    b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
82                ]))
83            }
84            WireType::Fixed32 => {
85                let b = self.take(4)?;
86                FieldValue::Fixed32(u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
87            }
88            WireType::LengthDelimited => {
89                let (len, p) = read_varint(self.buf, self.pos)?;
90                self.pos = p;
91                let len = usize::try_from(len).map_err(|_| WireError::LengthOutOfRange)?;
92                FieldValue::Bytes(self.take(len)?)
93            }
94        };
95        Ok(Some(Field {
96            number: tag.field_number,
97            value,
98        }))
99    }
100
101    fn take(&mut self, n: usize) -> Result<&'a [u8], WireError> {
102        let end = self.pos.checked_add(n).ok_or(WireError::LengthOutOfRange)?;
103        if end > self.buf.len() {
104            return Err(WireError::Truncated);
105        }
106        let s = &self.buf[self.pos..end];
107        self.pos = end;
108        Ok(s)
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// A varint field followed by a length-delimited field decode in
    /// order, after which the reader reports end of message.
    #[test]
    fn reads_mixed_fields() {
        // field 1 varint = 13 ; field 2 length-delimited = "hi"
        let bytes = [0x08, 0x0D, 0x12, 0x02, b'h', b'i'];
        let mut reader = MessageReader::new(&bytes);

        let first = Field {
            number: 1,
            value: FieldValue::Varint(13),
        };
        assert_eq!(reader.next_field(), Ok(Some(first)));

        let second = Field {
            number: 2,
            value: FieldValue::Bytes(b"hi"),
        };
        assert_eq!(reader.next_field(), Ok(Some(second)));

        assert_eq!(reader.next_field(), Ok(None));
    }

    /// A declared payload length that outruns the buffer is reported as
    /// truncation.
    #[test]
    fn truncated_length_delimited() {
        // Declares 5 payload bytes but supplies only 2.
        let bytes = [0x12, 0x05, b'h', b'i'];
        let outcome = MessageReader::new(&bytes).next_field();
        assert_eq!(outcome, Err(WireError::Truncated));
    }
}