Skip to main content

uor_addr/onnx/
value.rs

1//! ONNX typed input (IR ≤ v13) (ADR-023 amended by ADR-060).
2//!
3//! Protobuf v3 admits many byte-representations of the same logical
4//! message; this realization defines a canonical form — a **flat
5//! skeleton** — that collapses that freedom. Two ONNX models that decode
6//! to the same logical content (regardless of protobuf field order, node
7//! ordering among valid topological orderings, or whether tensor data is
8//! stored in `raw_data` or the typed-data fields) canonicalize to
9//! byte-identical skeletons and therefore to the same κ-label.
10//!
11//! ```text
12//! LE_i64(ir_version)
13//! ── opset imports, sorted by (domain, version) ──
14//!   for op: sha256(domain) || LE_i64(version)
15//! ── graph (recursive) ──
//!   sha256(graph_name) || LE_u32(n_nodes)
//!   nodes in Kahn-topological order (lex (name, op_type, domain) tie-break):
//!     sha256(name) || sha256(op_type) || sha256(domain) || sha256(overload)
//!       || LE_u32(n_in)  || (sha256(input_name)  × n_in)
//!       || LE_u32(n_out) || (sha256(output_name) × n_out)
//!       || LE_u32(n_attr) || attributes sorted by name, each
//!            sha256(attr_name) || LE_i32(attr_type) || value
//!            (GRAPH/GRAPHS recurse inline)
//!   LE_u32(n_init) || initializers (#5), sorted by name, each a canonical
//!     TensorProto record
//!   graph input (#11) / output (#12) / value_info (#13), sorted by name
//! ── model metadata ──
//!   sha256(producer_name) || sha256(producer_version) || sha256(domain)
//!     || LE_i64(model_version)
//!     || LE_u32(n_props) || (sha256(key) || sha256(value)) × n_props,
//!        sorted by key
27//! ```
28//!
29//! Under ADR-060 the **full skeleton** flows through the pipeline as a
30//! [`TermValue::Borrowed`] carrier and ψ₉ folds it through the σ-axis —
31//! there is no two-level commitment, no carrier ceiling, and no node /
32//! attribute / initializer / IO count cap. Variable-length leaves (tensor
33//! data bytes, strings, opaque sub-message payloads) are still replaced by
34//! their `sha256(...)` digest so the skeleton stays bounded by structure
35//! size, not data size, while still binding every weight byte into the
36//! κ-label.
37//!
38//! [`OnnxValue`] (the owned parsed value, `alloc`-gated) holds the
39//! skeleton; [`OnnxCarrier`] is the borrowed model-input handle the
40//! pipeline binds.
41
42use prism::operation::TermValue;
43use prism::pipeline::{
44    ConstrainedTypeShape, ConstraintRef, IntoBindingValue, PartitionProductFields,
45};
46
47// ─── OnnxCarrier — the borrowed model-input handle (no_alloc) ───────────
48
/// Borrowed handle over a canonical ONNX skeleton (the ADR-060 borrowed
/// carrier).
///
/// The handle is a `Copy` view of the bytes produced by [`canonicalize`].
/// It owns nothing. Binding it into the pipeline (`as_binding_value`)
/// hands out a `Borrowed` carrier without copying the bytes.
#[derive(Clone, Copy, Debug)]
pub struct OnnxCarrier<'a>(&'a [u8]);

impl<'a> OnnxCarrier<'a> {
    /// Creates a model-input handle over a canonical skeleton.
    ///
    /// The bytes are wrapped as-is; no validation happens here, so callers
    /// are expected to pass a skeleton produced by the canonicalizer.
    #[must_use]
    pub fn new(skeleton: &'a [u8]) -> Self {
        OnnxCarrier(skeleton)
    }

    /// Returns the wrapped skeleton bytes.
    ///
    /// The slice carries the handle's full lifetime `'a`, not the lifetime
    /// of the `&self` borrow.
    #[must_use]
    pub fn canonical_bytes(&self) -> &'a [u8] {
        let OnnxCarrier(bytes) = *self;
        bytes
    }
}
68
/// Shape descriptor for the borrowed carrier. `IRI` is the same
/// `https://uor.foundation/addr/OnnxValue` shape IRI that the alloc-gated
/// parser's `ShapeViolation`s report.
impl ConstrainedTypeShape for OnnxCarrier<'_> {
    const IRI: &'static str = "https://uor.foundation/addr/OnnxValue";
    // NOTE(review): presumably a single site (the whole skeleton); confirm
    // against the `ConstrainedTypeShape` contract.
    const SITE_COUNT: usize = 1;
    // No declarative constraints are attached to the shape. Structural
    // validation is expected to have happened when the skeleton was produced.
    const CONSTRAINTS: &'static [ConstraintRef] = &[];
    // NOTE(review): `u64::MAX` looks like an "unbounded" sentinel; confirm
    // the intended meaning of `CYCLE_SIZE`.
    const CYCLE_SIZE: u64 = u64::MAX;
}
75
// Marker impl placing `OnnxCarrier` in the pipeline's sealed set.
// NOTE(review): presumably required by the sealed pipeline traits
// implemented below; confirm against the SDK docs.
impl prism::uor_foundation::pipeline::__sdk_seal::Sealed for OnnxCarrier<'_> {}
77
78impl<'a> IntoBindingValue<'a> for OnnxCarrier<'a> {
79    fn as_binding_value<const INLINE_BYTES: usize>(&self) -> TermValue<'a, INLINE_BYTES> {
80        TermValue::borrowed(self.0)
81    }
82}
83
/// No partition-product fields are declared for the carrier.
// NOTE(review): presumably because the skeleton is bound as one opaque byte
// run rather than as structured fields; confirm.
impl PartitionProductFields for OnnxCarrier<'_> {
    const FIELDS: &'static [(u32, u32)] = &[];
    const FIELD_NAMES: &'static [&'static str] = &[];
}
88
89// ═════════════════════════════════════════════════════════════════════
90// alloc-gated parser + owned value
91// ═════════════════════════════════════════════════════════════════════
92
93#[cfg(feature = "alloc")]
94pub use alloc_impl::{canonicalize, OnnxValue};
95
96#[cfg(feature = "alloc")]
97mod alloc_impl {
98    use alloc::vec::Vec;
99
100    use prism::crypto::Sha256Hasher;
101    use prism::pipeline::{ShapeViolation, ViolationKind};
102    use prism::vocabulary::Hasher;
103
104    use crate::onnx::dtype::OnnxDataType;
105    use crate::onnx::protobuf::{read_varint, FieldValue, MessageReader};
106    use crate::onnx::shapes::bounds::{
107        ONNX_IR_VERSION_MAX, ONNX_OPSET_VERSION_MIN, ONNX_SUBGRAPH_DEPTH_MAX,
108    };
109
110    // ─── ShapeViolation IRIs ─────────────────────────────────────────────
111
112    macro_rules! violation {
113        ($name:ident, $constraint:literal, $kind:expr) => {
114            const $name: ShapeViolation = ShapeViolation {
115                shape_iri: "https://uor.foundation/addr/OnnxValue",
116                constraint_iri: concat!("https://uor.foundation/addr/OnnxValue/", $constraint),
117                property_iri: concat!("https://uor.foundation/addr/OnnxValue/", $constraint),
118                expected_range: "http://www.w3.org/2001/XMLSchema#nonNegativeInteger",
119                min_count: 0,
120                max_count: 1,
121                kind: $kind,
122            };
123        };
124    }
125
126    violation!(PROTOBUF_FAILURE, "validProtobuf", ViolationKind::ValueCheck);
127    violation!(
128        UNSUPPORTED_IR,
129        "supportedIrVersion",
130        ViolationKind::ValueCheck
131    );
132    violation!(OPSET_TOO_OLD, "opsetVersionMin", ViolationKind::ValueCheck);
133    violation!(MISSING_GRAPH, "graphPresent", ViolationKind::ValueCheck);
134    violation!(
135        SUBGRAPH_DEPTH,
136        "subgraphDepthBound",
137        ViolationKind::CardinalityViolation
138    );
139    violation!(GRAPH_CYCLE, "acyclicGraph", ViolationKind::ValueCheck);
140    violation!(
141        UNKNOWN_DTYPE,
142        "knownTensorDataType",
143        ViolationKind::ValueCheck
144    );
145
146    fn from_wire(_e: crate::onnx::protobuf::WireError) -> ShapeViolation {
147        PROTOBUF_FAILURE
148    }
149
150    #[inline]
151    fn sha256(bytes: &[u8]) -> [u8; 32] {
152        Sha256Hasher::initial().fold_bytes(bytes).finalize()
153    }
154
155    /// Recursion ceiling for the opaque-message field-order canonicalizer.
156    const CANON_PROTO_DEPTH_MAX: usize = 32;
157
158    /// Field-order-canonical digest of an opaque protobuf message — folds
159    /// its fields in ascending field-number order (stable within a number,
160    /// so repeated-field order is preserved), recursing into
161    /// length-delimited fields. This applies canonicalization rule 1
162    /// (field-number ordering) to sub-messages the realization otherwise
163    /// treats opaquely (`TypeProto`, `SparseTensorProto`), so two
164    /// serializations of the same logical value canonicalize identically.
165    /// Returns a 32-byte **leaf digest** (an opaque sub-message is a leaf
166    /// — its digest is appended inline, never expanded into the skeleton).
167    ///
168    /// A length-delimited field that is genuinely a string / bytes leaf
169    /// (e.g. `dim_param`) generally fails to re-parse as a well-formed
170    /// message; that case falls back to a digest of the raw payload. The
171    /// transform is deterministic either way: identical bytes always take
172    /// the same path.
173    fn canonical_proto_digest(body: &[u8], depth: usize) -> Result<[u8; 32], ShapeViolation> {
174        #[derive(Clone, Copy)]
175        struct F {
176            number: u64,
177            wt: u8,
178            off: usize,
179            len: usize,
180            val: u64,
181        }
182        let mut fs: Vec<F> = Vec::new();
183        let mut r = MessageReader::new(body);
184        while let Some(f) = r.next_field().map_err(from_wire)? {
185            fs.push(match f.value {
186                FieldValue::Varint(v) => F {
187                    number: f.number,
188                    wt: 0,
189                    off: 0,
190                    len: 0,
191                    val: v,
192                },
193                FieldValue::Fixed64(v) => F {
194                    number: f.number,
195                    wt: 1,
196                    off: 0,
197                    len: 0,
198                    val: v,
199                },
200                FieldValue::Fixed32(v) => F {
201                    number: f.number,
202                    wt: 5,
203                    off: 0,
204                    len: 0,
205                    val: u64::from(v),
206                },
207                FieldValue::Bytes(b) => F {
208                    number: f.number,
209                    wt: 2,
210                    off: b.as_ptr() as usize - body.as_ptr() as usize,
211                    len: b.len(),
212                    val: 0,
213                },
214            });
215        }
216        // Stable sort by field number (preserves repeated-field order).
217        fs.sort_by_key(|f| f.number);
218
219        let mut h = Sha256Hasher::initial();
220        for f in fs.iter() {
221            fold(&mut h, &f.number.to_le_bytes());
222            fold(&mut h, &[f.wt]);
223            match f.wt {
224                0 | 1 => fold(&mut h, &f.val.to_le_bytes()),
225                5 => fold(&mut h, &(f.val as u32).to_le_bytes()),
226                _ => {
227                    let payload = &body[f.off..f.off + f.len];
228                    let sub = if depth < CANON_PROTO_DEPTH_MAX && !payload.is_empty() {
229                        canonical_proto_digest(payload, depth + 1)
230                            .unwrap_or_else(|_| sha256(payload))
231                    } else {
232                        sha256(payload)
233                    };
234                    fold(&mut h, &sub);
235                }
236            }
237        }
238        Ok(h.finalize())
239    }
240
241    /// Fold `bytes` into the running hasher behind a mutable reference (the
242    /// `Hasher::fold_bytes` consume-and-return API is awkward inside
243    /// `FnMut` closures; this wraps the take-replace dance once).
244    #[inline]
245    fn fold(h: &mut Sha256Hasher, bytes: &[u8]) {
246        let cur = core::mem::replace(h, Sha256Hasher::initial());
247        *h = cur.fold_bytes(bytes);
248    }
249
250    // ─── Protobuf field accessors over a message body ──────────────────
251
252    /// First occurrence of `field_no` in `body`, or `None`.
253    fn first_field(body: &[u8], field_no: u64) -> Result<Option<FieldValue<'_>>, ShapeViolation> {
254        let mut r = MessageReader::new(body);
255        while let Some(f) = r.next_field().map_err(from_wire)? {
256            if f.number == field_no {
257                return Ok(Some(f.value));
258            }
259        }
260        Ok(None)
261    }
262
263    fn first_varint(body: &[u8], field_no: u64) -> Result<Option<u64>, ShapeViolation> {
264        Ok(match first_field(body, field_no)? {
265            Some(FieldValue::Varint(v)) => Some(v),
266            _ => None,
267        })
268    }
269
270    fn first_bytes(body: &[u8], field_no: u64) -> Result<&[u8], ShapeViolation> {
271        Ok(match first_field(body, field_no)? {
272            Some(FieldValue::Bytes(b)) => b,
273            _ => &[],
274        })
275    }
276
277    /// Invoke `f` for every occurrence of `field_no` (the repeated-field
278    /// iterator). Stops and propagates the first error `f` returns.
279    fn for_each_field(
280        body: &[u8],
281        field_no: u64,
282        mut f: impl FnMut(FieldValue<'_>) -> Result<(), ShapeViolation>,
283    ) -> Result<(), ShapeViolation> {
284        let mut r = MessageReader::new(body);
285        while let Some(field) = r.next_field().map_err(from_wire)? {
286            if field.number == field_no {
287                f(field.value)?;
288            }
289        }
290        Ok(())
291    }
292
293    fn count_field(body: &[u8], field_no: u64) -> Result<usize, ShapeViolation> {
294        let mut n = 0;
295        for_each_field(body, field_no, |_| {
296            n += 1;
297            Ok(())
298        })?;
299        Ok(n)
300    }
301
302    /// A `(offset, len)` span into a parent buffer.
303    #[derive(Clone, Copy)]
304    struct Span {
305        off: usize,
306        len: usize,
307    }
308
309    /// Collect every occurrence of `field_no` (length-delimited) in `body`
310    /// as a `(offset, len)` span.
311    fn collect_spans(body: &[u8], field_no: u64) -> Result<Vec<Span>, ShapeViolation> {
312        let mut spans: Vec<Span> = Vec::new();
313        let mut r = MessageReader::new(body);
314        while let Some(f) = r.next_field().map_err(from_wire)? {
315            if f.number == field_no {
316                if let FieldValue::Bytes(b) = f.value {
317                    spans.push(Span {
318                        off: b.as_ptr() as usize - body.as_ptr() as usize,
319                        len: b.len(),
320                    });
321                }
322            }
323        }
324        Ok(spans)
325    }
326
327    /// A parsed, canonicalized ONNX `ModelProto`. The stored bytes are the
328    /// flat canonical skeleton (see [module docs](super)). **`alloc`-gated**
329    /// — the pipeline binds the borrowed [`OnnxCarrier`](super::OnnxCarrier).
330    #[derive(Clone, PartialEq, Eq)]
331    pub struct OnnxValue {
332        bytes: Vec<u8>,
333    }
334
335    impl core::fmt::Debug for OnnxValue {
336        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
337            f.debug_struct("OnnxValue")
338                .field("canonical_len", &self.bytes.len())
339                .finish_non_exhaustive()
340        }
341    }
342
343    impl OnnxValue {
344        /// Borrow the canonical-skeleton bytes.
345        #[must_use]
346        pub fn canonical_bytes(&self) -> &[u8] {
347            &self.bytes
348        }
349
350        /// Parse an ONNX `ModelProto` wire buffer into a canonicalized
351        /// skeleton.
352        ///
353        /// # Errors
354        ///
355        /// A [`ShapeViolation`] whose `constraint_iri` names the violated
356        /// invariant (protobuf decode failure, unsupported IR version,
357        /// opset below the minimum, missing graph, a subgraph cycle, an
358        /// over-deep subgraph nesting, or an unknown tensor data type).
359        pub fn parse(raw: &[u8]) -> Result<Self, ShapeViolation> {
360            let mut out: Vec<u8> = Vec::new();
361
362            // ── ir_version (ModelProto #1) ──
363            // Accept any known IR revision (1..=ONNX_IR_VERSION_MAX); the
364            // canonical skeleton is IR-version-agnostic and binds the
365            // `ir_version` value, so distinct revisions canonicalize
366            // distinctly. Reject absent / 0 / a future unknown revision.
367            let ir_version = first_varint(raw, 1)?.ok_or(UNSUPPORTED_IR)? as i64;
368            if !(1..=ONNX_IR_VERSION_MAX).contains(&ir_version) {
369                return Err(UNSUPPORTED_IR);
370            }
371            out.extend_from_slice(&ir_version.to_le_bytes());
372
373            // ── opset imports (ModelProto #8, repeated OperatorSetIdProto) ──
374            emit_opsets(&mut out, raw)?;
375
376            // ── graph (ModelProto #7) ──
377            let graph = first_bytes(raw, 7)?;
378            if graph.is_empty() {
379                return Err(MISSING_GRAPH);
380            }
381            emit_canonical_graph(&mut out, graph, 0)?;
382
383            // ── model metadata ──
384            emit_model_meta(&mut out, raw)?;
385
386            Ok(Self { bytes: out })
387        }
388    }
389
390    /// Emit opset imports sorted by `(domain, version)`. Enforces at least
391    /// one default-domain (`""`) import at or above
392    /// [`ONNX_OPSET_VERSION_MIN`].
393    fn emit_opsets(out: &mut Vec<u8>, model: &[u8]) -> Result<(), ShapeViolation> {
394        let entries = collect_spans(model, 8)?;
395
396        // Default-domain minimum-version check.
397        let mut ok_min = false;
398        for e in &entries {
399            let body = &model[e.off..e.off + e.len];
400            let domain = first_bytes(body, 1)?;
401            let version = first_varint(body, 2)?.unwrap_or(0) as i64;
402            if domain.is_empty() && version >= ONNX_OPSET_VERSION_MIN {
403                ok_min = true;
404            }
405        }
406        if !ok_min && !entries.is_empty() {
407            return Err(OPSET_TOO_OLD);
408        }
409
410        let mut order: Vec<usize> = (0..entries.len()).collect();
411        order.sort_by(|&a, &b| {
412            let ea = &model[entries[a].off..entries[a].off + entries[a].len];
413            let eb = &model[entries[b].off..entries[b].off + entries[b].len];
414            let ka = (
415                first_bytes(ea, 1).unwrap_or(&[]),
416                first_varint(ea, 2).ok().flatten().unwrap_or(0),
417            );
418            let kb = (
419                first_bytes(eb, 1).unwrap_or(&[]),
420                first_varint(eb, 2).ok().flatten().unwrap_or(0),
421            );
422            ka.cmp(&kb)
423        });
424
425        for &idx in &order {
426            let body = &model[entries[idx].off..entries[idx].off + entries[idx].len];
427            let domain = first_bytes(body, 1)?;
428            let version = first_varint(body, 2)?.unwrap_or(0) as i64;
429            out.extend_from_slice(&sha256(domain));
430            out.extend_from_slice(&version.to_le_bytes());
431        }
432        Ok(())
433    }
434
435    /// Emit producer / domain / model_version + sorted `metadata_props`.
436    fn emit_model_meta(out: &mut Vec<u8>, model: &[u8]) -> Result<(), ShapeViolation> {
437        out.extend_from_slice(&sha256(first_bytes(model, 2)?)); // producer_name
438        out.extend_from_slice(&sha256(first_bytes(model, 3)?)); // producer_version
439        out.extend_from_slice(&sha256(first_bytes(model, 4)?)); // domain
440        out.extend_from_slice(&(first_varint(model, 5)?.unwrap_or(0) as i64).to_le_bytes()); // model_version
441        emit_string_string(out, model, 14) // metadata_props
442    }
443
444    /// Emit a repeated `StringStringEntryProto` map (`field_no`), sorted by
445    /// key, inline: for each entry `sha256(key) || sha256(value)`.
446    fn emit_string_string(
447        out: &mut Vec<u8>,
448        body: &[u8],
449        field_no: u64,
450    ) -> Result<(), ShapeViolation> {
451        let entries = collect_spans(body, field_no)?;
452        let mut order: Vec<usize> = (0..entries.len()).collect();
453        order.sort_by(|&a, &b| {
454            let ka = first_bytes(&body[entries[a].off..entries[a].off + entries[a].len], 1)
455                .unwrap_or(&[]);
456            let kb = first_bytes(&body[entries[b].off..entries[b].off + entries[b].len], 1)
457                .unwrap_or(&[]);
458            ka.cmp(kb)
459        });
460        out.extend_from_slice(&(order.len() as u32).to_le_bytes());
461        for &idx in &order {
462            let e = &body[entries[idx].off..entries[idx].off + entries[idx].len];
463            out.extend_from_slice(&sha256(first_bytes(e, 1)?));
464            out.extend_from_slice(&sha256(first_bytes(e, 2)?));
465        }
466        Ok(())
467    }
468
469    /// Emit a `GraphProto` body inline, recursing into subgraphs bounded by
470    /// [`ONNX_SUBGRAPH_DEPTH_MAX`].
471    fn emit_canonical_graph(
472        out: &mut Vec<u8>,
473        graph: &[u8],
474        depth: usize,
475    ) -> Result<(), ShapeViolation> {
476        if depth > ONNX_SUBGRAPH_DEPTH_MAX {
477            return Err(SUBGRAPH_DEPTH);
478        }
479
480        out.extend_from_slice(&sha256(first_bytes(graph, 2)?)); // graph name
481
482        // ── Nodes in Kahn-topological order (lex tie-break) ──
483        let nodes = collect_spans(graph, 1)?;
484        let node_count = nodes.len();
485        out.extend_from_slice(&(node_count as u32).to_le_bytes());
486
487        let mut emitted: Vec<bool> = alloc::vec![false; node_count];
488        for _ in 0..node_count {
489            // Find the lex-min ready (all producers emitted), unemitted node.
490            let mut best: Option<usize> = None;
491            for (cand, node) in nodes.iter().enumerate() {
492                if emitted[cand] {
493                    continue;
494                }
495                if !node_ready(graph, &nodes, &emitted, node)? {
496                    continue;
497                }
498                best = Some(match best {
499                    None => cand,
500                    Some(b) => {
501                        if node_lex_le(graph, &nodes[cand], &nodes[b])? {
502                            cand
503                        } else {
504                            b
505                        }
506                    }
507                });
508            }
509            let pick = best.ok_or(GRAPH_CYCLE)?; // no ready node ⇒ cycle
510            let body = &graph[nodes[pick].off..nodes[pick].off + nodes[pick].len];
511            emit_node(out, body, depth)?;
512            emitted[pick] = true;
513        }
514
515        // ── Initializers (#5), sorted by name, with tensor-data digests ──
516        emit_tensor_section(out, graph, 5)?;
517
518        // ── Graph IO: inputs (#11), outputs (#12), value_info (#13) ──
519        emit_value_info(out, graph, 11)?;
520        emit_value_info(out, graph, 12)?;
521        emit_value_info(out, graph, 13)?;
522
523        Ok(())
524    }
525
526    /// A node is ready when every input name that is *produced by another
527    /// node in this graph* has had its producer emitted.
528    fn node_ready(
529        graph: &[u8],
530        nodes: &[Span],
531        emitted: &[bool],
532        node: &Span,
533    ) -> Result<bool, ShapeViolation> {
534        let body = &graph[node.off..node.off + node.len];
535        let mut ready = true;
536        for_each_field(body, 1, |v| {
537            if let FieldValue::Bytes(name) = v {
538                if !name.is_empty() {
539                    for (k, prod) in nodes.iter().enumerate() {
540                        let pbody = &graph[prod.off..prod.off + prod.len];
541                        let mut produces = false;
542                        for_each_field(pbody, 2, |ov| {
543                            if let FieldValue::Bytes(on) = ov {
544                                if on == name {
545                                    produces = true;
546                                }
547                            }
548                            Ok(())
549                        })?;
550                        if produces && !emitted[k] {
551                            ready = false;
552                        }
553                    }
554                }
555            }
556            Ok(())
557        })?;
558        Ok(ready)
559    }
560
561    /// Lexicographic order on `(name, op_type, domain)`.
562    fn node_lex_le(graph: &[u8], a: &Span, b: &Span) -> Result<bool, ShapeViolation> {
563        let ba = &graph[a.off..a.off + a.len];
564        let bb = &graph[b.off..b.off + b.len];
565        let ka = (
566            first_bytes(ba, 3)?,
567            first_bytes(ba, 4)?,
568            first_bytes(ba, 7)?,
569        );
570        let kb = (
571            first_bytes(bb, 3)?,
572            first_bytes(bb, 4)?,
573            first_bytes(bb, 7)?,
574        );
575        Ok(ka <= kb)
576    }
577
578    /// Emit a `NodeProto` inline: identity fields, positional inputs /
579    /// outputs, then the name-sorted attributes (which recurse into
580    /// subgraphs inline).
581    fn emit_node(out: &mut Vec<u8>, node: &[u8], depth: usize) -> Result<(), ShapeViolation> {
582        out.extend_from_slice(&sha256(first_bytes(node, 3)?)); // name
583        out.extend_from_slice(&sha256(first_bytes(node, 4)?)); // op_type
584        out.extend_from_slice(&sha256(first_bytes(node, 7)?)); // domain
585        out.extend_from_slice(&sha256(first_bytes(node, 8)?)); // overload (IR v10+)
586
587        let n_in = count_field(node, 1)?;
588        out.extend_from_slice(&(n_in as u32).to_le_bytes());
589        for_each_field(node, 1, |v| {
590            if let FieldValue::Bytes(name) = v {
591                out.extend_from_slice(&sha256(name));
592            }
593            Ok(())
594        })?;
595
596        let n_out = count_field(node, 2)?;
597        out.extend_from_slice(&(n_out as u32).to_le_bytes());
598        for_each_field(node, 2, |v| {
599            if let FieldValue::Bytes(name) = v {
600                out.extend_from_slice(&sha256(name));
601            }
602            Ok(())
603        })?;
604
605        emit_attributes(out, node, depth)
606    }
607
608    /// Emit a node's `attribute` field (#5), sorted by name, inline.
609    fn emit_attributes(out: &mut Vec<u8>, node: &[u8], depth: usize) -> Result<(), ShapeViolation> {
610        let attrs = collect_spans(node, 5)?;
611        let mut order: Vec<usize> = (0..attrs.len()).collect();
612        order.sort_by(|&a, &b| {
613            let na =
614                first_bytes(&node[attrs[a].off..attrs[a].off + attrs[a].len], 1).unwrap_or(&[]);
615            let nb =
616                first_bytes(&node[attrs[b].off..attrs[b].off + attrs[b].len], 1).unwrap_or(&[]);
617            na.cmp(nb)
618        });
619        out.extend_from_slice(&(order.len() as u32).to_le_bytes());
620        for &idx in &order {
621            let a = &node[attrs[idx].off..attrs[idx].off + attrs[idx].len];
622            out.extend_from_slice(&sha256(first_bytes(a, 1)?)); // name
623            let atype = first_varint(a, 20)?.unwrap_or(0) as i32;
624            out.extend_from_slice(&atype.to_le_bytes());
625            emit_attribute_value(out, a, atype, depth)?;
626        }
627        Ok(())
628    }
629
630    /// Emit an attribute's value inline, dispatched on its `AttributeType`.
631    fn emit_attribute_value(
632        out: &mut Vec<u8>,
633        a: &[u8],
634        atype: i32,
635        depth: usize,
636    ) -> Result<(), ShapeViolation> {
637        match atype {
638            1 => {
639                // FLOAT (#2, fixed32)
640                if let Some(FieldValue::Fixed32(bits)) = first_field(a, 2)? {
641                    out.extend_from_slice(&bits.to_le_bytes());
642                }
643            }
644            2 => {
645                // INT (#3, varint)
646                out.extend_from_slice(&(first_varint(a, 3)?.unwrap_or(0) as i64).to_le_bytes());
647            }
648            3 => {
649                // STRING (#4, bytes)
650                out.extend_from_slice(&sha256(first_bytes(a, 4)?));
651            }
652            4 => {
653                // TENSOR (#5)
654                emit_tensor(out, first_bytes(a, 5)?)?;
655            }
656            5 => {
657                // GRAPH (#6) — recurse inline
658                emit_canonical_graph(out, first_bytes(a, 6)?, depth + 1)?;
659            }
660            6 => {
661                // FLOATS (#7, packed fixed32)
662                for_each_field(a, 7, |v| {
663                    if let FieldValue::Bytes(p) = v {
664                        out.extend_from_slice(&sha256(p));
665                    } else if let FieldValue::Fixed32(b) = v {
666                        out.extend_from_slice(&b.to_le_bytes());
667                    }
668                    Ok(())
669                })?;
670            }
671            7 => {
672                // INTS (#8, packed varint)
673                emit_packed_varints(out, a, 8)?;
674            }
675            8 => {
676                // STRINGS (#9, repeated bytes)
677                for_each_field(a, 9, |v| {
678                    if let FieldValue::Bytes(s) = v {
679                        out.extend_from_slice(&sha256(s));
680                    }
681                    Ok(())
682                })?;
683            }
684            9 => {
685                // TENSORS (#10)
686                let spans = collect_spans(a, 10)?;
687                for s in &spans {
688                    emit_tensor(out, &a[s.off..s.off + s.len])?;
689                }
690            }
691            10 => {
692                // GRAPHS (#11) — recurse inline
693                let spans = collect_spans(a, 11)?;
694                for s in &spans {
695                    emit_canonical_graph(out, &a[s.off..s.off + s.len], depth + 1)?;
696                }
697            }
698            11 => out.extend_from_slice(&canonical_proto_digest(first_bytes(a, 22)?, 0)?), // SPARSE_TENSOR
699            12 => {
700                // SPARSE_TENSORS (#23)
701                for_each_field(a, 23, |v| {
702                    if let FieldValue::Bytes(s) = v {
703                        out.extend_from_slice(&canonical_proto_digest(s, 0)?);
704                    }
705                    Ok(())
706                })?;
707            }
708            13 => out.extend_from_slice(&canonical_proto_digest(first_bytes(a, 14)?, 0)?), // TYPE_PROTO
709            14 => {
710                // TYPE_PROTOS (#15)
711                for_each_field(a, 15, |v| {
712                    if let FieldValue::Bytes(s) = v {
713                        out.extend_from_slice(&canonical_proto_digest(s, 0)?);
714                    }
715                    Ok(())
716                })?;
717            }
718            _ => {}
719        }
720        Ok(())
721    }
722
723    fn emit_packed_varints(
724        out: &mut Vec<u8>,
725        body: &[u8],
726        field_no: u64,
727    ) -> Result<(), ShapeViolation> {
728        for_each_field(body, field_no, |v| {
729            match v {
730                FieldValue::Bytes(p) => {
731                    let mut pos = 0;
732                    while pos < p.len() {
733                        let (val, np) = read_varint(p, pos).map_err(from_wire)?;
734                        out.extend_from_slice(&(val as i64).to_le_bytes());
735                        pos = np;
736                    }
737                }
738                FieldValue::Varint(val) => out.extend_from_slice(&(val as i64).to_le_bytes()),
739                _ => {}
740            }
741            Ok(())
742        })
743    }
744
745    /// Emit a name-sorted section of repeated `TensorProto` (initializers).
746    fn emit_tensor_section(
747        out: &mut Vec<u8>,
748        graph: &[u8],
749        field_no: u64,
750    ) -> Result<(), ShapeViolation> {
751        let spans = collect_spans(graph, field_no)?;
752        let mut order: Vec<usize> = (0..spans.len()).collect();
753        order.sort_by(|&a, &b| {
754            let na =
755                first_bytes(&graph[spans[a].off..spans[a].off + spans[a].len], 8).unwrap_or(&[]);
756            let nb =
757                first_bytes(&graph[spans[b].off..spans[b].off + spans[b].len], 8).unwrap_or(&[]);
758            na.cmp(nb)
759        });
760        out.extend_from_slice(&(order.len() as u32).to_le_bytes());
761        for &idx in &order {
762            let body = &graph[spans[idx].off..spans[idx].off + spans[idx].len];
763            emit_tensor(out, body)?;
764        }
765        Ok(())
766    }
767
768    /// Emit a canonical `TensorProto` record inline: `sha256(name) ||
769    /// LE_i32(dtype) || LE_u32(rank) || (LE_i64 dim …) || tensor_data_digest`,
770    /// where `tensor_data_digest` is a 32-byte leaf digest streaming
771    /// `raw_data` if present, else the typed-data field re-encoded to the
772    /// canonical little-endian `raw_data` layout (so the two storage forms
773    /// canonicalize identically).
774    fn emit_tensor(out: &mut Vec<u8>, t: &[u8]) -> Result<(), ShapeViolation> {
775        let dtype_id = first_varint(t, 2)?.unwrap_or(0) as i32;
776        let dtype = OnnxDataType::from_i32(dtype_id).ok_or(UNKNOWN_DTYPE)?;
777
778        out.extend_from_slice(&sha256(first_bytes(t, 8)?)); // name
779        out.extend_from_slice(&dtype_id.to_le_bytes());
780
781        // dims (#1, repeated int64; packed or unpacked).
782        let rank = count_dims(t)?;
783        out.extend_from_slice(&(rank as u32).to_le_bytes());
784        emit_packed_varints(out, t, 1)?;
785
786        // data digest (a leaf — appended inline as 32 bytes).
787        out.extend_from_slice(&tensor_data_digest(t, dtype)?);
788        Ok(())
789    }
790
791    fn count_dims(t: &[u8]) -> Result<usize, ShapeViolation> {
792        let mut n = 0;
793        for_each_field(t, 1, |v| {
794            match v {
795                FieldValue::Bytes(p) => {
796                    let mut pos = 0;
797                    while pos < p.len() {
798                        let (_, np) = read_varint(p, pos).map_err(from_wire)?;
799                        n += 1;
800                        pos = np;
801                    }
802                }
803                FieldValue::Varint(_) => n += 1,
804                _ => {}
805            }
806            Ok(())
807        })?;
808        Ok(n)
809    }
810
811    /// Stream the tensor's data through SHA-256 in canonical `raw_data`
812    /// layout, returning the 32-byte leaf digest.
813    fn tensor_data_digest(t: &[u8], dtype: OnnxDataType) -> Result<[u8; 32], ShapeViolation> {
814        // External data (`data_location` #14 == EXTERNAL = 1): the core
815        // cannot open the referenced sibling file, so the κ-label binds the
816        // external *reference* (`external_data` #13 — location / offset /
817        // length / checksum, sorted by key) rather than the dereferenced
818        // bytes. A domain tag keeps external digests disjoint from inline
819        // ones. Hosts requiring inline≡external equivalence dereference
820        // before calling.
821        if first_varint(t, 14)?.unwrap_or(0) == 1 {
822            let mut h = Sha256Hasher::initial();
823            fold(&mut h, b"onnx:external-data:v1");
824            // metadata-style sorted digest of external_data (#13).
825            let mut sub: Vec<u8> = Vec::new();
826            emit_string_string(&mut sub, t, 13)?;
827            fold(&mut h, &sub);
828            return Ok(h.finalize());
829        }
830        // raw_data (#9) takes precedence and is already canonical.
831        if let Some(FieldValue::Bytes(raw)) = first_field(t, 9)? {
832            if !raw.is_empty() {
833                return Ok(sha256(raw));
834            }
835        }
836        let mut h = Sha256Hasher::initial();
837        match dtype {
838            // float_data (#4) / double_data (#10): packed fixed-width — the
839            // packed payload IS the canonical raw layout.
840            OnnxDataType::Float => fold_fixed_payload(t, 4, &mut h)?,
841            OnnxDataType::Double | OnnxDataType::Complex128 => fold_fixed_payload(t, 10, &mut h)?,
842            OnnxDataType::Complex64 => fold_fixed_payload(t, 4, &mut h)?,
843            // int64_data (#7): re-encode each varint to 8-byte LE.
844            OnnxDataType::Int64 => fold_typed_varints(t, 7, 8, &mut h)?,
845            // uint64_data (#11): UINT64 → 8-byte LE; UINT32 → 4-byte LE.
846            OnnxDataType::Uint64 => fold_typed_varints(t, 11, 8, &mut h)?,
847            OnnxDataType::Uint32 => fold_typed_varints(t, 11, 4, &mut h)?,
848            // int32_data (#5) carries INT32/INT16/INT8/UINT16/UINT8/BOOL and
849            // the bit-packed small floats — re-encode to the dtype's width.
850            OnnxDataType::Int32 => fold_typed_varints(t, 5, 4, &mut h)?,
851            OnnxDataType::Int16
852            | OnnxDataType::Uint16
853            | OnnxDataType::Float16
854            | OnnxDataType::Bfloat16 => fold_typed_varints(t, 5, 2, &mut h)?,
855            OnnxDataType::Int8
856            | OnnxDataType::Uint8
857            | OnnxDataType::Bool
858            | OnnxDataType::Float8E4M3Fn
859            | OnnxDataType::Float8E4M3Fnuz
860            | OnnxDataType::Float8E5M2
861            | OnnxDataType::Float8E5M2Fnuz
862            | OnnxDataType::Int4
863            | OnnxDataType::Uint4
864            | OnnxDataType::Float4E2M1 => fold_typed_varints(t, 5, 1, &mut h)?,
865            // string_data (#6): fold each element's digest.
866            OnnxDataType::String => {
867                for_each_field(t, 6, |v| {
868                    if let FieldValue::Bytes(s) = v {
869                        fold(&mut h, &sha256(s));
870                    }
871                    Ok(())
872                })?;
873            }
874        }
875        Ok(h.finalize())
876    }
877
878    /// Fold the (already-canonical) packed payload of a fixed-width
879    /// repeated field directly.
880    fn fold_fixed_payload(
881        body: &[u8],
882        field_no: u64,
883        h: &mut Sha256Hasher,
884    ) -> Result<(), ShapeViolation> {
885        for_each_field(body, field_no, |v| {
886            match v {
887                FieldValue::Bytes(p) => fold(h, p),
888                FieldValue::Fixed32(b) => fold(h, &b.to_le_bytes()),
889                FieldValue::Fixed64(b) => fold(h, &b.to_le_bytes()),
890                _ => {}
891            }
892            Ok(())
893        })
894    }
895
896    /// Re-encode each varint of a packed/unpacked repeated field to `width`
897    /// little-endian bytes (the canonical `raw_data` element layout).
898    fn fold_typed_varints(
899        body: &[u8],
900        field_no: u64,
901        width: usize,
902        h: &mut Sha256Hasher,
903    ) -> Result<(), ShapeViolation> {
904        for_each_field(body, field_no, |v| {
905            match v {
906                FieldValue::Bytes(p) => {
907                    let mut pos = 0;
908                    while pos < p.len() {
909                        let (val, np) = read_varint(p, pos).map_err(from_wire)?;
910                        fold(h, &val.to_le_bytes()[..width]);
911                        pos = np;
912                    }
913                }
914                FieldValue::Varint(val) => fold(h, &val.to_le_bytes()[..width]),
915                _ => {}
916            }
917            Ok(())
918        })
919    }
920
921    /// Emit a name-sorted section of repeated `ValueInfoProto` (graph
922    /// input / output / value_info). Binds the name plus a field-order-
923    /// canonical leaf digest of the `TypeProto`.
924    fn emit_value_info(
925        out: &mut Vec<u8>,
926        graph: &[u8],
927        field_no: u64,
928    ) -> Result<(), ShapeViolation> {
929        let spans = collect_spans(graph, field_no)?;
930        let mut order: Vec<usize> = (0..spans.len()).collect();
931        order.sort_by(|&a, &b| {
932            let na =
933                first_bytes(&graph[spans[a].off..spans[a].off + spans[a].len], 1).unwrap_or(&[]);
934            let nb =
935                first_bytes(&graph[spans[b].off..spans[b].off + spans[b].len], 1).unwrap_or(&[]);
936            na.cmp(nb)
937        });
938        out.extend_from_slice(&(order.len() as u32).to_le_bytes());
939        for &idx in &order {
940            let body = &graph[spans[idx].off..spans[idx].off + spans[idx].len];
941            out.extend_from_slice(&sha256(first_bytes(body, 1)?)); // name
942            out.extend_from_slice(&canonical_proto_digest(first_bytes(body, 2)?, 0)?);
943            // type (TypeProto)
944        }
945        Ok(())
946    }
947
948    /// Canonical skeleton as an owned `Vec<u8>`.
949    ///
950    /// # Errors
951    ///
952    /// Surfaces the [`ShapeViolation`] [`OnnxValue::parse`] would raise.
953    pub fn canonicalize(raw: &[u8]) -> Result<Vec<u8>, ShapeViolation> {
954        Ok(OnnxValue::parse(raw)?.bytes)
955    }
956
957    #[cfg(test)]
958    mod tests {
959        use super::*;
960
961        // ── Minimal protobuf encoders for building test `ModelProto`s ──
962
963        fn put_varint(out: &mut Vec<u8>, mut v: u64) {
964            loop {
965                let mut byte = (v & 0x7f) as u8;
966                v >>= 7;
967                if v != 0 {
968                    byte |= 0x80;
969                }
970                out.push(byte);
971                if v == 0 {
972                    break;
973                }
974            }
975        }
976
977        fn tag(out: &mut Vec<u8>, field_no: u64, wire: u64) {
978            put_varint(out, (field_no << 3) | wire);
979        }
980
981        fn field_varint(out: &mut Vec<u8>, field_no: u64, v: u64) {
982            tag(out, field_no, 0);
983            put_varint(out, v);
984        }
985
986        fn field_bytes(out: &mut Vec<u8>, field_no: u64, b: &[u8]) {
987            tag(out, field_no, 2);
988            put_varint(out, b.len() as u64);
989            out.extend_from_slice(b);
990        }
991
992        /// Smallest valid ONNX `ModelProto`: ir_version=13, one
993        /// default-domain opset import (version 1), and a non-empty graph.
994        fn minimal_onnx() -> Vec<u8> {
995            // OperatorSetIdProto { domain = "", version = 1 }
996            let mut opset = Vec::new();
997            field_bytes(&mut opset, 1, b""); // domain
998            field_varint(&mut opset, 2, 1); // version
999
1000            // GraphProto { name = "g" }
1001            let mut graph = Vec::new();
1002            field_bytes(&mut graph, 2, b"g");
1003
1004            // ModelProto
1005            let mut model = Vec::new();
1006            field_varint(&mut model, 1, ONNX_IR_VERSION_MAX as u64); // ir_version
1007            field_bytes(&mut model, 7, &graph); // graph
1008            field_bytes(&mut model, 8, &opset); // opset_import
1009            model
1010        }
1011
1012        #[test]
1013        fn parses_minimal_model() {
1014            let canon = canonicalize(&minimal_onnx()).expect("valid");
1015            // ir_version(8) + opset(domain digest 32 + version 8)
1016            //   + graph: name(32) + node_count(4) + init_count(4)
1017            //     + 3× IO counts(4 each)
1018            //   + meta: producer(32) + producer_ver(32) + domain(32)
1019            //     + model_ver(8) + metadata_props count(4)
1020            assert_eq!(
1021                canon.len(),
1022                8 + 40 + (32 + 4 + 4 + 12) + (32 + 32 + 32 + 8 + 4)
1023            );
1024        }
1025
1026        #[test]
1027        fn rejects_out_of_range_ir() {
1028            // IR 7 is in range (1..=13) → accepted (would reach MISSING_GRAPH);
1029            // 14 is a future/unknown revision → rejected at the IR gate.
1030            let mut model = Vec::new();
1031            field_varint(&mut model, 1, (ONNX_IR_VERSION_MAX + 1) as u64);
1032            let err = OnnxValue::parse(&model).expect_err("unsupported ir");
1033            assert_eq!(err.constraint_iri, UNSUPPORTED_IR.constraint_iri);
1034        }
1035
1036        #[test]
1037        fn rejects_missing_graph() {
1038            let mut opset = Vec::new();
1039            field_bytes(&mut opset, 1, b"");
1040            field_varint(&mut opset, 2, 1);
1041            let mut model = Vec::new();
1042            field_varint(&mut model, 1, ONNX_IR_VERSION_MAX as u64);
1043            field_bytes(&mut model, 8, &opset);
1044            let err = OnnxValue::parse(&model).expect_err("no graph");
1045            assert_eq!(err.constraint_iri, MISSING_GRAPH.constraint_iri);
1046        }
1047
1048        #[test]
1049        fn deterministic() {
1050            let a = canonicalize(&minimal_onnx()).expect("valid");
1051            let b = canonicalize(&minimal_onnx()).expect("valid");
1052            assert_eq!(a, b);
1053        }
1054    }
1055}