uor_addr/onnx/protobuf/
wire.rs1use super::tag::{Tag, WireType};
6use super::varint::read_varint;
7
/// Errors reported while decoding protobuf wire-format bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireError {
    /// A read needed more bytes than remain in the buffer.
    Truncated,
    /// Not produced directly in this file; presumably reported by
    /// `read_varint` for a varint that does not fit in 64 bits (TODO confirm).
    VarintOverflow,
    /// Not produced directly in this file; presumably reported by
    /// `Tag::from_varint`, with the payload being the offending wire-type
    /// value (TODO confirm).
    UnknownWireType(u8),
    /// Not produced directly in this file; presumably reported by
    /// `Tag::from_varint` for a tag whose field number is zero (TODO confirm).
    ZeroFieldNumber,
    /// A length prefix does not fit in `usize`, or adding it to the current
    /// read position overflows `usize`.
    LengthOutOfRange,
}

// Fully qualified `core::fmt` paths keep this impl independent of the file's
// `use` lines and usable whether or not the crate links `std`.
impl core::fmt::Display for WireError {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            WireError::Truncated => f.write_str("truncated input"),
            WireError::VarintOverflow => f.write_str("varint overflow"),
            WireError::UnknownWireType(wire_type) => {
                write!(f, "unknown wire type {}", wire_type)
            }
            WireError::ZeroFieldNumber => f.write_str("zero field number"),
            WireError::LengthOutOfRange => f.write_str("length out of range"),
        }
    }
}
24
/// Payload of a single decoded field, one variant per wire type that
/// `MessageReader::next_field` handles.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FieldValue<'a> {
    /// Varint payload, stored exactly as `read_varint` returned it; no further
    /// decoding is applied here.
    Varint(u64),
    /// Eight payload bytes interpreted as a little-endian `u64`.
    Fixed64(u64),
    /// Four payload bytes interpreted as a little-endian `u32`.
    Fixed32(u32),
    /// Length-delimited payload: a slice borrowed from the reader's buffer,
    /// not including the length prefix.
    Bytes(&'a [u8]),
}
37
/// One decoded field: the field number taken from its tag plus its payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Field<'a> {
    /// Field number carried by the tag (`Tag::field_number`).
    pub number: u64,
    /// Decoded payload; the variant reflects the tag's wire type.
    pub value: FieldValue<'a>,
}
46
/// Forward-only cursor over a borrowed byte buffer that decodes one field per
/// `next_field` call.
///
/// Cloning copies only the slice reference and the offset, so a clone can be
/// used as a cheap checkpoint of the current position.
#[derive(Clone)]
pub struct MessageReader<'a> {
    /// Bytes being decoded.
    buf: &'a [u8],
    /// Offset into `buf` of the next unread byte.
    pos: usize,
}

// Hand-written rather than derived: the derived form would print every byte
// of `buf`, which is useless (and potentially enormous) in diagnostics.
impl<'a> core::fmt::Debug for MessageReader<'a> {
    /// Prints the buffer length and the cursor position, not the bytes.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("MessageReader")
            .field("len", &self.buf.len())
            .field("pos", &self.pos)
            .finish()
    }
}
52
53impl<'a> MessageReader<'a> {
54 #[must_use]
56 pub fn new(buf: &'a [u8]) -> Self {
57 Self { buf, pos: 0 }
58 }
59
60 pub fn next_field(&mut self) -> Result<Option<Field<'a>>, WireError> {
66 if self.pos >= self.buf.len() {
67 return Ok(None);
68 }
69 let (tag_v, p) = read_varint(self.buf, self.pos)?;
70 self.pos = p;
71 let tag = Tag::from_varint(tag_v)?;
72 let value = match tag.wire_type {
73 WireType::Varint => {
74 let (v, p) = read_varint(self.buf, self.pos)?;
75 self.pos = p;
76 FieldValue::Varint(v)
77 }
78 WireType::Fixed64 => {
79 let b = self.take(8)?;
80 FieldValue::Fixed64(u64::from_le_bytes([
81 b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
82 ]))
83 }
84 WireType::Fixed32 => {
85 let b = self.take(4)?;
86 FieldValue::Fixed32(u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
87 }
88 WireType::LengthDelimited => {
89 let (len, p) = read_varint(self.buf, self.pos)?;
90 self.pos = p;
91 let len = usize::try_from(len).map_err(|_| WireError::LengthOutOfRange)?;
92 FieldValue::Bytes(self.take(len)?)
93 }
94 };
95 Ok(Some(Field {
96 number: tag.field_number,
97 value,
98 }))
99 }
100
101 fn take(&mut self, n: usize) -> Result<&'a [u8], WireError> {
102 let end = self.pos.checked_add(n).ok_or(WireError::LengthOutOfRange)?;
103 if end > self.buf.len() {
104 return Err(WireError::Truncated);
105 }
106 let s = &self.buf[self.pos..end];
107 self.pos = end;
108 Ok(s)
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// A varint field followed by a length-delimited field decode in order,
    /// after which the reader reports end of input.
    #[test]
    fn reads_mixed_fields() {
        let bytes: [u8; 6] = [
            0x08, 0x0D, // expected below to decode as field 1, varint 13
            0x12, 0x02, b'h', b'i', // expected below to decode as field 2, bytes "hi"
        ];
        let mut reader = MessageReader::new(&bytes);

        let first = Field {
            number: 1,
            value: FieldValue::Varint(13),
        };
        assert_eq!(reader.next_field(), Ok(Some(first)));

        let second = Field {
            number: 2,
            value: FieldValue::Bytes(b"hi"),
        };
        assert_eq!(reader.next_field(), Ok(Some(second)));

        assert_eq!(reader.next_field(), Ok(None));
    }

    /// A length prefix of 5 with only 2 payload bytes present is rejected as
    /// truncated input.
    #[test]
    fn truncated_length_delimited() {
        let bytes: [u8; 4] = [0x12, 0x05, b'h', b'i'];
        let mut reader = MessageReader::new(&bytes);

        let outcome = reader.next_field();

        assert_eq!(outcome, Err(WireError::Truncated));
    }
}