1use prism::operation::TermValue;
43use prism::pipeline::{
44 ConstrainedTypeShape, ConstraintRef, IntoBindingValue, PartitionProductFields,
45};
46
/// Borrowed, copyable handle to a byte slice.
///
/// The wrapper neither copies nor inspects the bytes; it only carries the
/// borrow so the trait implementations in this file have a type to hang on.
#[derive(Clone, Copy, Debug)]
pub struct OnnxCarrier<'a>(&'a [u8]);

impl<'a> OnnxCarrier<'a> {
    /// Wraps `skeleton` as-is, without validation.
    #[must_use]
    pub fn new(skeleton: &'a [u8]) -> Self {
        OnnxCarrier(skeleton)
    }

    /// Returns the wrapped slice with its full `'a` lifetime, so the result
    /// may outlive this (copyable) handle.
    #[must_use]
    pub fn canonical_bytes(&self) -> &'a [u8] {
        let OnnxCarrier(bytes) = *self;
        bytes
    }
}
68
69impl ConstrainedTypeShape for OnnxCarrier<'_> {
70 const IRI: &'static str = "https://uor.foundation/addr/OnnxValue";
71 const SITE_COUNT: usize = 1;
72 const CONSTRAINTS: &'static [ConstraintRef] = &[];
73 const CYCLE_SIZE: u64 = u64::MAX;
74}
75
76impl prism::uor_foundation::pipeline::__sdk_seal::Sealed for OnnxCarrier<'_> {}
77
78impl<'a> IntoBindingValue<'a> for OnnxCarrier<'a> {
79 fn as_binding_value<const INLINE_BYTES: usize>(&self) -> TermValue<'a, INLINE_BYTES> {
80 TermValue::borrowed(self.0)
81 }
82}
83
84impl PartitionProductFields for OnnxCarrier<'_> {
85 const FIELDS: &'static [(u32, u32)] = &[];
86 const FIELD_NAMES: &'static [&'static str] = &[];
87}
88
89#[cfg(feature = "alloc")]
94pub use alloc_impl::{canonicalize, OnnxValue};
95
96#[cfg(feature = "alloc")]
97mod alloc_impl {
98 use alloc::vec::Vec;
99
100 use prism::crypto::Sha256Hasher;
101 use prism::pipeline::{ShapeViolation, ViolationKind};
102 use prism::vocabulary::Hasher;
103
104 use crate::onnx::dtype::OnnxDataType;
105 use crate::onnx::protobuf::{read_varint, FieldValue, MessageReader};
106 use crate::onnx::shapes::bounds::{
107 ONNX_IR_VERSION_MAX, ONNX_OPSET_VERSION_MIN, ONNX_SUBGRAPH_DEPTH_MAX,
108 };
109
// Declares one `ShapeViolation` constant for the `OnnxValue` shape.
//
// * `$ident`          – name of the generated constant.
// * `$slug`           – literal appended to the shape IRI; used for both
//                       `constraint_iri` and `property_iri`.
// * `$violation_kind` – expression stored in the `kind` field.
macro_rules! violation {
    ($ident:ident, $slug:literal, $violation_kind:expr) => {
        const $ident: ShapeViolation = ShapeViolation {
            shape_iri: "https://uor.foundation/addr/OnnxValue",
            constraint_iri: concat!("https://uor.foundation/addr/OnnxValue/", $slug),
            property_iri: concat!("https://uor.foundation/addr/OnnxValue/", $slug),
            expected_range: "http://www.w3.org/2001/XMLSchema#nonNegativeInteger",
            min_count: 0,
            max_count: 1,
            kind: $violation_kind,
        };
    };
}
125
126 violation!(PROTOBUF_FAILURE, "validProtobuf", ViolationKind::ValueCheck);
127 violation!(
128 UNSUPPORTED_IR,
129 "supportedIrVersion",
130 ViolationKind::ValueCheck
131 );
132 violation!(OPSET_TOO_OLD, "opsetVersionMin", ViolationKind::ValueCheck);
133 violation!(MISSING_GRAPH, "graphPresent", ViolationKind::ValueCheck);
134 violation!(
135 SUBGRAPH_DEPTH,
136 "subgraphDepthBound",
137 ViolationKind::CardinalityViolation
138 );
139 violation!(GRAPH_CYCLE, "acyclicGraph", ViolationKind::ValueCheck);
140 violation!(
141 UNKNOWN_DTYPE,
142 "knownTensorDataType",
143 ViolationKind::ValueCheck
144 );
145
146 fn from_wire(_e: crate::onnx::protobuf::WireError) -> ShapeViolation {
147 PROTOBUF_FAILURE
148 }
149
150 #[inline]
151 fn sha256(bytes: &[u8]) -> [u8; 32] {
152 Sha256Hasher::initial().fold_bytes(bytes).finalize()
153 }
154
155 const CANON_PROTO_DEPTH_MAX: usize = 32;
157
158 fn canonical_proto_digest(body: &[u8], depth: usize) -> Result<[u8; 32], ShapeViolation> {
174 #[derive(Clone, Copy)]
175 struct F {
176 number: u64,
177 wt: u8,
178 off: usize,
179 len: usize,
180 val: u64,
181 }
182 let mut fs: Vec<F> = Vec::new();
183 let mut r = MessageReader::new(body);
184 while let Some(f) = r.next_field().map_err(from_wire)? {
185 fs.push(match f.value {
186 FieldValue::Varint(v) => F {
187 number: f.number,
188 wt: 0,
189 off: 0,
190 len: 0,
191 val: v,
192 },
193 FieldValue::Fixed64(v) => F {
194 number: f.number,
195 wt: 1,
196 off: 0,
197 len: 0,
198 val: v,
199 },
200 FieldValue::Fixed32(v) => F {
201 number: f.number,
202 wt: 5,
203 off: 0,
204 len: 0,
205 val: u64::from(v),
206 },
207 FieldValue::Bytes(b) => F {
208 number: f.number,
209 wt: 2,
210 off: b.as_ptr() as usize - body.as_ptr() as usize,
211 len: b.len(),
212 val: 0,
213 },
214 });
215 }
216 fs.sort_by_key(|f| f.number);
218
219 let mut h = Sha256Hasher::initial();
220 for f in fs.iter() {
221 fold(&mut h, &f.number.to_le_bytes());
222 fold(&mut h, &[f.wt]);
223 match f.wt {
224 0 | 1 => fold(&mut h, &f.val.to_le_bytes()),
225 5 => fold(&mut h, &(f.val as u32).to_le_bytes()),
226 _ => {
227 let payload = &body[f.off..f.off + f.len];
228 let sub = if depth < CANON_PROTO_DEPTH_MAX && !payload.is_empty() {
229 canonical_proto_digest(payload, depth + 1)
230 .unwrap_or_else(|_| sha256(payload))
231 } else {
232 sha256(payload)
233 };
234 fold(&mut h, &sub);
235 }
236 }
237 }
238 Ok(h.finalize())
239 }
240
241 #[inline]
245 fn fold(h: &mut Sha256Hasher, bytes: &[u8]) {
246 let cur = core::mem::replace(h, Sha256Hasher::initial());
247 *h = cur.fold_bytes(bytes);
248 }
249
250 fn first_field(body: &[u8], field_no: u64) -> Result<Option<FieldValue<'_>>, ShapeViolation> {
254 let mut r = MessageReader::new(body);
255 while let Some(f) = r.next_field().map_err(from_wire)? {
256 if f.number == field_no {
257 return Ok(Some(f.value));
258 }
259 }
260 Ok(None)
261 }
262
263 fn first_varint(body: &[u8], field_no: u64) -> Result<Option<u64>, ShapeViolation> {
264 Ok(match first_field(body, field_no)? {
265 Some(FieldValue::Varint(v)) => Some(v),
266 _ => None,
267 })
268 }
269
270 fn first_bytes(body: &[u8], field_no: u64) -> Result<&[u8], ShapeViolation> {
271 Ok(match first_field(body, field_no)? {
272 Some(FieldValue::Bytes(b)) => b,
273 _ => &[],
274 })
275 }
276
277 fn for_each_field(
280 body: &[u8],
281 field_no: u64,
282 mut f: impl FnMut(FieldValue<'_>) -> Result<(), ShapeViolation>,
283 ) -> Result<(), ShapeViolation> {
284 let mut r = MessageReader::new(body);
285 while let Some(field) = r.next_field().map_err(from_wire)? {
286 if field.number == field_no {
287 f(field.value)?;
288 }
289 }
290 Ok(())
291 }
292
293 fn count_field(body: &[u8], field_no: u64) -> Result<usize, ShapeViolation> {
294 let mut n = 0;
295 for_each_field(body, field_no, |_| {
296 n += 1;
297 Ok(())
298 })?;
299 Ok(n)
300 }
301
302 #[derive(Clone, Copy)]
304 struct Span {
305 off: usize,
306 len: usize,
307 }
308
309 fn collect_spans(body: &[u8], field_no: u64) -> Result<Vec<Span>, ShapeViolation> {
312 let mut spans: Vec<Span> = Vec::new();
313 let mut r = MessageReader::new(body);
314 while let Some(f) = r.next_field().map_err(from_wire)? {
315 if f.number == field_no {
316 if let FieldValue::Bytes(b) = f.value {
317 spans.push(Span {
318 off: b.as_ptr() as usize - body.as_ptr() as usize,
319 len: b.len(),
320 });
321 }
322 }
323 }
324 Ok(spans)
325 }
326
/// Owned canonical byte form produced by `OnnxValue::parse`.
///
/// Equality compares the canonical bytes.
#[derive(Clone, PartialEq, Eq)]
pub struct OnnxValue {
    bytes: Vec<u8>,
}

/// Prints only the length of the canonical form, never its contents.
impl core::fmt::Debug for OnnxValue {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let canonical_len = self.bytes.len();
        let mut builder = f.debug_struct("OnnxValue");
        builder.field("canonical_len", &canonical_len);
        builder.finish_non_exhaustive()
    }
}
342
343 impl OnnxValue {
344 #[must_use]
346 pub fn canonical_bytes(&self) -> &[u8] {
347 &self.bytes
348 }
349
350 pub fn parse(raw: &[u8]) -> Result<Self, ShapeViolation> {
360 let mut out: Vec<u8> = Vec::new();
361
362 let ir_version = first_varint(raw, 1)?.ok_or(UNSUPPORTED_IR)? as i64;
368 if !(1..=ONNX_IR_VERSION_MAX).contains(&ir_version) {
369 return Err(UNSUPPORTED_IR);
370 }
371 out.extend_from_slice(&ir_version.to_le_bytes());
372
373 emit_opsets(&mut out, raw)?;
375
376 let graph = first_bytes(raw, 7)?;
378 if graph.is_empty() {
379 return Err(MISSING_GRAPH);
380 }
381 emit_canonical_graph(&mut out, graph, 0)?;
382
383 emit_model_meta(&mut out, raw)?;
385
386 Ok(Self { bytes: out })
387 }
388 }
389
390 fn emit_opsets(out: &mut Vec<u8>, model: &[u8]) -> Result<(), ShapeViolation> {
394 let entries = collect_spans(model, 8)?;
395
396 let mut ok_min = false;
398 for e in &entries {
399 let body = &model[e.off..e.off + e.len];
400 let domain = first_bytes(body, 1)?;
401 let version = first_varint(body, 2)?.unwrap_or(0) as i64;
402 if domain.is_empty() && version >= ONNX_OPSET_VERSION_MIN {
403 ok_min = true;
404 }
405 }
406 if !ok_min && !entries.is_empty() {
407 return Err(OPSET_TOO_OLD);
408 }
409
410 let mut order: Vec<usize> = (0..entries.len()).collect();
411 order.sort_by(|&a, &b| {
412 let ea = &model[entries[a].off..entries[a].off + entries[a].len];
413 let eb = &model[entries[b].off..entries[b].off + entries[b].len];
414 let ka = (
415 first_bytes(ea, 1).unwrap_or(&[]),
416 first_varint(ea, 2).ok().flatten().unwrap_or(0),
417 );
418 let kb = (
419 first_bytes(eb, 1).unwrap_or(&[]),
420 first_varint(eb, 2).ok().flatten().unwrap_or(0),
421 );
422 ka.cmp(&kb)
423 });
424
425 for &idx in &order {
426 let body = &model[entries[idx].off..entries[idx].off + entries[idx].len];
427 let domain = first_bytes(body, 1)?;
428 let version = first_varint(body, 2)?.unwrap_or(0) as i64;
429 out.extend_from_slice(&sha256(domain));
430 out.extend_from_slice(&version.to_le_bytes());
431 }
432 Ok(())
433 }
434
435 fn emit_model_meta(out: &mut Vec<u8>, model: &[u8]) -> Result<(), ShapeViolation> {
437 out.extend_from_slice(&sha256(first_bytes(model, 2)?)); out.extend_from_slice(&sha256(first_bytes(model, 3)?)); out.extend_from_slice(&sha256(first_bytes(model, 4)?)); out.extend_from_slice(&(first_varint(model, 5)?.unwrap_or(0) as i64).to_le_bytes()); emit_string_string(out, model, 14) }
443
444 fn emit_string_string(
447 out: &mut Vec<u8>,
448 body: &[u8],
449 field_no: u64,
450 ) -> Result<(), ShapeViolation> {
451 let entries = collect_spans(body, field_no)?;
452 let mut order: Vec<usize> = (0..entries.len()).collect();
453 order.sort_by(|&a, &b| {
454 let ka = first_bytes(&body[entries[a].off..entries[a].off + entries[a].len], 1)
455 .unwrap_or(&[]);
456 let kb = first_bytes(&body[entries[b].off..entries[b].off + entries[b].len], 1)
457 .unwrap_or(&[]);
458 ka.cmp(kb)
459 });
460 out.extend_from_slice(&(order.len() as u32).to_le_bytes());
461 for &idx in &order {
462 let e = &body[entries[idx].off..entries[idx].off + entries[idx].len];
463 out.extend_from_slice(&sha256(first_bytes(e, 1)?));
464 out.extend_from_slice(&sha256(first_bytes(e, 2)?));
465 }
466 Ok(())
467 }
468
469 fn emit_canonical_graph(
472 out: &mut Vec<u8>,
473 graph: &[u8],
474 depth: usize,
475 ) -> Result<(), ShapeViolation> {
476 if depth > ONNX_SUBGRAPH_DEPTH_MAX {
477 return Err(SUBGRAPH_DEPTH);
478 }
479
480 out.extend_from_slice(&sha256(first_bytes(graph, 2)?)); let nodes = collect_spans(graph, 1)?;
484 let node_count = nodes.len();
485 out.extend_from_slice(&(node_count as u32).to_le_bytes());
486
487 let mut emitted: Vec<bool> = alloc::vec![false; node_count];
488 for _ in 0..node_count {
489 let mut best: Option<usize> = None;
491 for (cand, node) in nodes.iter().enumerate() {
492 if emitted[cand] {
493 continue;
494 }
495 if !node_ready(graph, &nodes, &emitted, node)? {
496 continue;
497 }
498 best = Some(match best {
499 None => cand,
500 Some(b) => {
501 if node_lex_le(graph, &nodes[cand], &nodes[b])? {
502 cand
503 } else {
504 b
505 }
506 }
507 });
508 }
509 let pick = best.ok_or(GRAPH_CYCLE)?; let body = &graph[nodes[pick].off..nodes[pick].off + nodes[pick].len];
511 emit_node(out, body, depth)?;
512 emitted[pick] = true;
513 }
514
515 emit_tensor_section(out, graph, 5)?;
517
518 emit_value_info(out, graph, 11)?;
520 emit_value_info(out, graph, 12)?;
521 emit_value_info(out, graph, 13)?;
522
523 Ok(())
524 }
525
526 fn node_ready(
529 graph: &[u8],
530 nodes: &[Span],
531 emitted: &[bool],
532 node: &Span,
533 ) -> Result<bool, ShapeViolation> {
534 let body = &graph[node.off..node.off + node.len];
535 let mut ready = true;
536 for_each_field(body, 1, |v| {
537 if let FieldValue::Bytes(name) = v {
538 if !name.is_empty() {
539 for (k, prod) in nodes.iter().enumerate() {
540 let pbody = &graph[prod.off..prod.off + prod.len];
541 let mut produces = false;
542 for_each_field(pbody, 2, |ov| {
543 if let FieldValue::Bytes(on) = ov {
544 if on == name {
545 produces = true;
546 }
547 }
548 Ok(())
549 })?;
550 if produces && !emitted[k] {
551 ready = false;
552 }
553 }
554 }
555 }
556 Ok(())
557 })?;
558 Ok(ready)
559 }
560
561 fn node_lex_le(graph: &[u8], a: &Span, b: &Span) -> Result<bool, ShapeViolation> {
563 let ba = &graph[a.off..a.off + a.len];
564 let bb = &graph[b.off..b.off + b.len];
565 let ka = (
566 first_bytes(ba, 3)?,
567 first_bytes(ba, 4)?,
568 first_bytes(ba, 7)?,
569 );
570 let kb = (
571 first_bytes(bb, 3)?,
572 first_bytes(bb, 4)?,
573 first_bytes(bb, 7)?,
574 );
575 Ok(ka <= kb)
576 }
577
578 fn emit_node(out: &mut Vec<u8>, node: &[u8], depth: usize) -> Result<(), ShapeViolation> {
582 out.extend_from_slice(&sha256(first_bytes(node, 3)?)); out.extend_from_slice(&sha256(first_bytes(node, 4)?)); out.extend_from_slice(&sha256(first_bytes(node, 7)?)); out.extend_from_slice(&sha256(first_bytes(node, 8)?)); let n_in = count_field(node, 1)?;
588 out.extend_from_slice(&(n_in as u32).to_le_bytes());
589 for_each_field(node, 1, |v| {
590 if let FieldValue::Bytes(name) = v {
591 out.extend_from_slice(&sha256(name));
592 }
593 Ok(())
594 })?;
595
596 let n_out = count_field(node, 2)?;
597 out.extend_from_slice(&(n_out as u32).to_le_bytes());
598 for_each_field(node, 2, |v| {
599 if let FieldValue::Bytes(name) = v {
600 out.extend_from_slice(&sha256(name));
601 }
602 Ok(())
603 })?;
604
605 emit_attributes(out, node, depth)
606 }
607
608 fn emit_attributes(out: &mut Vec<u8>, node: &[u8], depth: usize) -> Result<(), ShapeViolation> {
610 let attrs = collect_spans(node, 5)?;
611 let mut order: Vec<usize> = (0..attrs.len()).collect();
612 order.sort_by(|&a, &b| {
613 let na =
614 first_bytes(&node[attrs[a].off..attrs[a].off + attrs[a].len], 1).unwrap_or(&[]);
615 let nb =
616 first_bytes(&node[attrs[b].off..attrs[b].off + attrs[b].len], 1).unwrap_or(&[]);
617 na.cmp(nb)
618 });
619 out.extend_from_slice(&(order.len() as u32).to_le_bytes());
620 for &idx in &order {
621 let a = &node[attrs[idx].off..attrs[idx].off + attrs[idx].len];
622 out.extend_from_slice(&sha256(first_bytes(a, 1)?)); let atype = first_varint(a, 20)?.unwrap_or(0) as i32;
624 out.extend_from_slice(&atype.to_le_bytes());
625 emit_attribute_value(out, a, atype, depth)?;
626 }
627 Ok(())
628 }
629
630 fn emit_attribute_value(
632 out: &mut Vec<u8>,
633 a: &[u8],
634 atype: i32,
635 depth: usize,
636 ) -> Result<(), ShapeViolation> {
637 match atype {
638 1 => {
639 if let Some(FieldValue::Fixed32(bits)) = first_field(a, 2)? {
641 out.extend_from_slice(&bits.to_le_bytes());
642 }
643 }
644 2 => {
645 out.extend_from_slice(&(first_varint(a, 3)?.unwrap_or(0) as i64).to_le_bytes());
647 }
648 3 => {
649 out.extend_from_slice(&sha256(first_bytes(a, 4)?));
651 }
652 4 => {
653 emit_tensor(out, first_bytes(a, 5)?)?;
655 }
656 5 => {
657 emit_canonical_graph(out, first_bytes(a, 6)?, depth + 1)?;
659 }
660 6 => {
661 for_each_field(a, 7, |v| {
663 if let FieldValue::Bytes(p) = v {
664 out.extend_from_slice(&sha256(p));
665 } else if let FieldValue::Fixed32(b) = v {
666 out.extend_from_slice(&b.to_le_bytes());
667 }
668 Ok(())
669 })?;
670 }
671 7 => {
672 emit_packed_varints(out, a, 8)?;
674 }
675 8 => {
676 for_each_field(a, 9, |v| {
678 if let FieldValue::Bytes(s) = v {
679 out.extend_from_slice(&sha256(s));
680 }
681 Ok(())
682 })?;
683 }
684 9 => {
685 let spans = collect_spans(a, 10)?;
687 for s in &spans {
688 emit_tensor(out, &a[s.off..s.off + s.len])?;
689 }
690 }
691 10 => {
692 let spans = collect_spans(a, 11)?;
694 for s in &spans {
695 emit_canonical_graph(out, &a[s.off..s.off + s.len], depth + 1)?;
696 }
697 }
698 11 => out.extend_from_slice(&canonical_proto_digest(first_bytes(a, 22)?, 0)?), 12 => {
700 for_each_field(a, 23, |v| {
702 if let FieldValue::Bytes(s) = v {
703 out.extend_from_slice(&canonical_proto_digest(s, 0)?);
704 }
705 Ok(())
706 })?;
707 }
708 13 => out.extend_from_slice(&canonical_proto_digest(first_bytes(a, 14)?, 0)?), 14 => {
710 for_each_field(a, 15, |v| {
712 if let FieldValue::Bytes(s) = v {
713 out.extend_from_slice(&canonical_proto_digest(s, 0)?);
714 }
715 Ok(())
716 })?;
717 }
718 _ => {}
719 }
720 Ok(())
721 }
722
723 fn emit_packed_varints(
724 out: &mut Vec<u8>,
725 body: &[u8],
726 field_no: u64,
727 ) -> Result<(), ShapeViolation> {
728 for_each_field(body, field_no, |v| {
729 match v {
730 FieldValue::Bytes(p) => {
731 let mut pos = 0;
732 while pos < p.len() {
733 let (val, np) = read_varint(p, pos).map_err(from_wire)?;
734 out.extend_from_slice(&(val as i64).to_le_bytes());
735 pos = np;
736 }
737 }
738 FieldValue::Varint(val) => out.extend_from_slice(&(val as i64).to_le_bytes()),
739 _ => {}
740 }
741 Ok(())
742 })
743 }
744
745 fn emit_tensor_section(
747 out: &mut Vec<u8>,
748 graph: &[u8],
749 field_no: u64,
750 ) -> Result<(), ShapeViolation> {
751 let spans = collect_spans(graph, field_no)?;
752 let mut order: Vec<usize> = (0..spans.len()).collect();
753 order.sort_by(|&a, &b| {
754 let na =
755 first_bytes(&graph[spans[a].off..spans[a].off + spans[a].len], 8).unwrap_or(&[]);
756 let nb =
757 first_bytes(&graph[spans[b].off..spans[b].off + spans[b].len], 8).unwrap_or(&[]);
758 na.cmp(nb)
759 });
760 out.extend_from_slice(&(order.len() as u32).to_le_bytes());
761 for &idx in &order {
762 let body = &graph[spans[idx].off..spans[idx].off + spans[idx].len];
763 emit_tensor(out, body)?;
764 }
765 Ok(())
766 }
767
768 fn emit_tensor(out: &mut Vec<u8>, t: &[u8]) -> Result<(), ShapeViolation> {
775 let dtype_id = first_varint(t, 2)?.unwrap_or(0) as i32;
776 let dtype = OnnxDataType::from_i32(dtype_id).ok_or(UNKNOWN_DTYPE)?;
777
778 out.extend_from_slice(&sha256(first_bytes(t, 8)?)); out.extend_from_slice(&dtype_id.to_le_bytes());
780
781 let rank = count_dims(t)?;
783 out.extend_from_slice(&(rank as u32).to_le_bytes());
784 emit_packed_varints(out, t, 1)?;
785
786 out.extend_from_slice(&tensor_data_digest(t, dtype)?);
788 Ok(())
789 }
790
791 fn count_dims(t: &[u8]) -> Result<usize, ShapeViolation> {
792 let mut n = 0;
793 for_each_field(t, 1, |v| {
794 match v {
795 FieldValue::Bytes(p) => {
796 let mut pos = 0;
797 while pos < p.len() {
798 let (_, np) = read_varint(p, pos).map_err(from_wire)?;
799 n += 1;
800 pos = np;
801 }
802 }
803 FieldValue::Varint(_) => n += 1,
804 _ => {}
805 }
806 Ok(())
807 })?;
808 Ok(n)
809 }
810
811 fn tensor_data_digest(t: &[u8], dtype: OnnxDataType) -> Result<[u8; 32], ShapeViolation> {
814 if first_varint(t, 14)?.unwrap_or(0) == 1 {
822 let mut h = Sha256Hasher::initial();
823 fold(&mut h, b"onnx:external-data:v1");
824 let mut sub: Vec<u8> = Vec::new();
826 emit_string_string(&mut sub, t, 13)?;
827 fold(&mut h, &sub);
828 return Ok(h.finalize());
829 }
830 if let Some(FieldValue::Bytes(raw)) = first_field(t, 9)? {
832 if !raw.is_empty() {
833 return Ok(sha256(raw));
834 }
835 }
836 let mut h = Sha256Hasher::initial();
837 match dtype {
838 OnnxDataType::Float => fold_fixed_payload(t, 4, &mut h)?,
841 OnnxDataType::Double | OnnxDataType::Complex128 => fold_fixed_payload(t, 10, &mut h)?,
842 OnnxDataType::Complex64 => fold_fixed_payload(t, 4, &mut h)?,
843 OnnxDataType::Int64 => fold_typed_varints(t, 7, 8, &mut h)?,
845 OnnxDataType::Uint64 => fold_typed_varints(t, 11, 8, &mut h)?,
847 OnnxDataType::Uint32 => fold_typed_varints(t, 11, 4, &mut h)?,
848 OnnxDataType::Int32 => fold_typed_varints(t, 5, 4, &mut h)?,
851 OnnxDataType::Int16
852 | OnnxDataType::Uint16
853 | OnnxDataType::Float16
854 | OnnxDataType::Bfloat16 => fold_typed_varints(t, 5, 2, &mut h)?,
855 OnnxDataType::Int8
856 | OnnxDataType::Uint8
857 | OnnxDataType::Bool
858 | OnnxDataType::Float8E4M3Fn
859 | OnnxDataType::Float8E4M3Fnuz
860 | OnnxDataType::Float8E5M2
861 | OnnxDataType::Float8E5M2Fnuz
862 | OnnxDataType::Int4
863 | OnnxDataType::Uint4
864 | OnnxDataType::Float4E2M1 => fold_typed_varints(t, 5, 1, &mut h)?,
865 OnnxDataType::String => {
867 for_each_field(t, 6, |v| {
868 if let FieldValue::Bytes(s) = v {
869 fold(&mut h, &sha256(s));
870 }
871 Ok(())
872 })?;
873 }
874 }
875 Ok(h.finalize())
876 }
877
878 fn fold_fixed_payload(
881 body: &[u8],
882 field_no: u64,
883 h: &mut Sha256Hasher,
884 ) -> Result<(), ShapeViolation> {
885 for_each_field(body, field_no, |v| {
886 match v {
887 FieldValue::Bytes(p) => fold(h, p),
888 FieldValue::Fixed32(b) => fold(h, &b.to_le_bytes()),
889 FieldValue::Fixed64(b) => fold(h, &b.to_le_bytes()),
890 _ => {}
891 }
892 Ok(())
893 })
894 }
895
896 fn fold_typed_varints(
899 body: &[u8],
900 field_no: u64,
901 width: usize,
902 h: &mut Sha256Hasher,
903 ) -> Result<(), ShapeViolation> {
904 for_each_field(body, field_no, |v| {
905 match v {
906 FieldValue::Bytes(p) => {
907 let mut pos = 0;
908 while pos < p.len() {
909 let (val, np) = read_varint(p, pos).map_err(from_wire)?;
910 fold(h, &val.to_le_bytes()[..width]);
911 pos = np;
912 }
913 }
914 FieldValue::Varint(val) => fold(h, &val.to_le_bytes()[..width]),
915 _ => {}
916 }
917 Ok(())
918 })
919 }
920
921 fn emit_value_info(
925 out: &mut Vec<u8>,
926 graph: &[u8],
927 field_no: u64,
928 ) -> Result<(), ShapeViolation> {
929 let spans = collect_spans(graph, field_no)?;
930 let mut order: Vec<usize> = (0..spans.len()).collect();
931 order.sort_by(|&a, &b| {
932 let na =
933 first_bytes(&graph[spans[a].off..spans[a].off + spans[a].len], 1).unwrap_or(&[]);
934 let nb =
935 first_bytes(&graph[spans[b].off..spans[b].off + spans[b].len], 1).unwrap_or(&[]);
936 na.cmp(nb)
937 });
938 out.extend_from_slice(&(order.len() as u32).to_le_bytes());
939 for &idx in &order {
940 let body = &graph[spans[idx].off..spans[idx].off + spans[idx].len];
941 out.extend_from_slice(&sha256(first_bytes(body, 1)?)); out.extend_from_slice(&canonical_proto_digest(first_bytes(body, 2)?, 0)?);
943 }
945 Ok(())
946 }
947
948 pub fn canonicalize(raw: &[u8]) -> Result<Vec<u8>, ShapeViolation> {
954 Ok(OnnxValue::parse(raw)?.bytes)
955 }
956
#[cfg(test)]
mod tests {
    use super::*;

    /// Appends `v` as a base-128 varint, least-significant group first.
    fn put_varint(out: &mut Vec<u8>, mut v: u64) {
        while v >= 0x80 {
            out.push((v & 0x7f) as u8 | 0x80);
            v >>= 7;
        }
        out.push(v as u8);
    }

    /// Appends a field tag: field number in the high bits, wire type in the
    /// low three.
    fn tag(out: &mut Vec<u8>, field_no: u64, wire: u64) {
        put_varint(out, wire | (field_no << 3));
    }

    /// Appends a varint field (wire type 0).
    fn field_varint(out: &mut Vec<u8>, field_no: u64, v: u64) {
        tag(out, field_no, 0);
        put_varint(out, v);
    }

    /// Appends a length-delimited field (wire type 2).
    fn field_bytes(out: &mut Vec<u8>, field_no: u64, b: &[u8]) {
        tag(out, field_no, 2);
        put_varint(out, b.len() as u64);
        out.extend_from_slice(b);
    }

    /// Smallest accepted model: IR version, a graph carrying only field 2,
    /// and one opset import for the empty domain.
    fn minimal_onnx() -> Vec<u8> {
        let mut opset = Vec::new();
        field_bytes(&mut opset, 1, b"");
        field_varint(&mut opset, 2, 1);

        let mut graph = Vec::new();
        field_bytes(&mut graph, 2, b"g");

        let mut model = Vec::new();
        field_varint(&mut model, 1, ONNX_IR_VERSION_MAX as u64);
        field_bytes(&mut model, 7, &graph);
        field_bytes(&mut model, 8, &opset);
        model
    }

    #[test]
    fn parses_minimal_model() {
        let canon = canonicalize(&minimal_onnx()).expect("valid");

        let ir_version = 8;
        let opsets = 40;
        let graph = 32 + 4 + 4 + 12;
        let model_meta = 32 + 32 + 32 + 8 + 4;
        assert_eq!(canon.len(), ir_version + opsets + graph + model_meta);
    }

    #[test]
    fn rejects_out_of_range_ir() {
        let mut model = Vec::new();
        field_varint(&mut model, 1, (ONNX_IR_VERSION_MAX + 1) as u64);

        let err = OnnxValue::parse(&model).expect_err("unsupported ir");
        assert_eq!(err.constraint_iri, UNSUPPORTED_IR.constraint_iri);
    }

    #[test]
    fn rejects_missing_graph() {
        let mut opset = Vec::new();
        field_bytes(&mut opset, 1, b"");
        field_varint(&mut opset, 2, 1);

        let mut model = Vec::new();
        field_varint(&mut model, 1, ONNX_IR_VERSION_MAX as u64);
        field_bytes(&mut model, 8, &opset);

        let err = OnnxValue::parse(&model).expect_err("no graph");
        assert_eq!(err.constraint_iri, MISSING_GRAPH.constraint_iri);
    }

    #[test]
    fn deterministic() {
        let first = canonicalize(&minimal_onnx()).expect("valid");
        let second = canonicalize(&minimal_onnx()).expect("valid");
        assert_eq!(first, second);
    }
}
1055}