// uor_addr/onnx/dtype.rs

//! `OnnxDataType` — the ONNX `TensorProto.DataType` alphabet: a mapping
//! from the enum IDs `1..=23` to [`prism::tensor::dtype`] shapes, plus the
//! ONNX-specific `STRING` (ID 8), which carries no numeric dtype.
//!
//! Authoritative source: the `TensorProto.DataType` enum in
//! <https://github.com/onnx/onnx/blob/main/onnx/onnx.proto>.
7
8use prism::tensor::dtype::{
9    Dtype, BF16, BOOL, C128, C64, F16, F32, F4_E2M1, F64, F8_E4M3, F8_E4M3_FNUZ, F8_E5M2,
10    F8_E5M2_FNUZ, I16, I32, I4, I64, I8, U16, U32, U4, U64, U8,
11};
12
/// An ONNX tensor element type: one variant per defined
/// `TensorProto.DataType` ID in `1..=23`.
///
/// All variants except [`Self::String`] map 1:1 to a
/// [`prism::tensor::dtype`] shape. The number in parentheses on each variant
/// is the ONNX wire ID.
///
/// The variants carry no explicit discriminants, so an `as` cast does
/// **not** produce the ONNX ID (`OnnxDataType::Float as i32` is `0`, not `1`).
///
/// `Hash` is derived alongside `Eq` so the type can key a `HashMap` or live
/// in a `HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OnnxDataType {
    /// `FLOAT` (1) → [`F32`].
    Float,
    /// `UINT8` (2) → [`U8`].
    Uint8,
    /// `INT8` (3) → [`I8`].
    Int8,
    /// `UINT16` (4) → [`U16`].
    Uint16,
    /// `INT16` (5) → [`I16`].
    Int16,
    /// `INT32` (6) → [`I32`].
    Int32,
    /// `INT64` (7) → [`I64`].
    Int64,
    /// `STRING` (8) — ONNX-specific; not in the numeric dtype alphabet.
    String,
    /// `BOOL` (9) → [`BOOL`].
    Bool,
    /// `FLOAT16` (10) → [`F16`].
    Float16,
    /// `DOUBLE` (11) → [`F64`].
    Double,
    /// `UINT32` (12) → [`U32`].
    Uint32,
    /// `UINT64` (13) → [`U64`].
    Uint64,
    /// `COMPLEX64` (14) → [`C64`].
    Complex64,
    /// `COMPLEX128` (15) → [`C128`].
    Complex128,
    /// `BFLOAT16` (16) → [`BF16`].
    Bfloat16,
    /// `FLOAT8E4M3FN` (17) → [`F8_E4M3`].
    Float8E4M3Fn,
    /// `FLOAT8E4M3FNUZ` (18) → [`F8_E4M3_FNUZ`].
    Float8E4M3Fnuz,
    /// `FLOAT8E5M2` (19) → [`F8_E5M2`].
    Float8E5M2,
    /// `FLOAT8E5M2FNUZ` (20) → [`F8_E5M2_FNUZ`].
    Float8E5M2Fnuz,
    /// `UINT4` (21) → [`U4`].
    Uint4,
    /// `INT4` (22) → [`I4`].
    Int4,
    /// `FLOAT4E2M1` (23) → [`F4_E2M1`].
    Float4E2M1,
}
64
impl OnnxDataType {
    /// Decode a raw `TensorProto.DataType` ID into its variant.
    ///
    /// Returns `None` for `0` (`UNDEFINED`) and for every ID outside
    /// `1..=23`, including negative values.
    #[must_use]
    pub const fn from_i32(id: i32) -> Option<Self> {
        let dtype = match id {
            1 => Self::Float,
            2 => Self::Uint8,
            3 => Self::Int8,
            4 => Self::Uint16,
            5 => Self::Int16,
            6 => Self::Int32,
            7 => Self::Int64,
            8 => Self::String,
            9 => Self::Bool,
            10 => Self::Float16,
            11 => Self::Double,
            12 => Self::Uint32,
            13 => Self::Uint64,
            14 => Self::Complex64,
            15 => Self::Complex128,
            16 => Self::Bfloat16,
            17 => Self::Float8E4M3Fn,
            18 => Self::Float8E4M3Fnuz,
            19 => Self::Float8E5M2,
            20 => Self::Float8E5M2Fnuz,
            21 => Self::Uint4,
            22 => Self::Int4,
            23 => Self::Float4E2M1,
            // `0` is `UNDEFINED`; everything else is outside the mapped range.
            _ => return None,
        };
        Some(dtype)
    }

    /// Block bytes of the corresponding [`prism::tensor::dtype`] shape, or
    /// `None` for [`Self::String`] (no fixed element width).
    ///
    /// The match is deliberately exhaustive with no catch-all arm, so adding
    /// a variant to the enum forces a decision here at compile time.
    ///
    /// NOTE(review): for the sub-byte types (`UINT4`, `INT4`, `FLOAT4E2M1`)
    /// the value is whatever `BLOCK_BYTES` the prism shape defines. This
    /// module's tests expect `Some(1)` for `Int4`, which suggests a storage
    /// block size rather than a per-element width — confirm against prism
    /// before using it to size buffers.
    #[must_use]
    pub const fn block_bytes(self) -> Option<usize> {
        let bytes = match self {
            Self::Float => F32::BLOCK_BYTES,
            Self::Uint8 => U8::BLOCK_BYTES,
            Self::Int8 => I8::BLOCK_BYTES,
            Self::Uint16 => U16::BLOCK_BYTES,
            Self::Int16 => I16::BLOCK_BYTES,
            Self::Int32 => I32::BLOCK_BYTES,
            Self::Int64 => I64::BLOCK_BYTES,
            // Strings are variable-length; there is no numeric shape to consult.
            Self::String => return None,
            Self::Bool => BOOL::BLOCK_BYTES,
            Self::Float16 => F16::BLOCK_BYTES,
            Self::Double => F64::BLOCK_BYTES,
            Self::Uint32 => U32::BLOCK_BYTES,
            Self::Uint64 => U64::BLOCK_BYTES,
            Self::Complex64 => C64::BLOCK_BYTES,
            Self::Complex128 => C128::BLOCK_BYTES,
            Self::Bfloat16 => BF16::BLOCK_BYTES,
            Self::Float8E4M3Fn => F8_E4M3::BLOCK_BYTES,
            Self::Float8E4M3Fnuz => F8_E4M3_FNUZ::BLOCK_BYTES,
            Self::Float8E5M2 => F8_E5M2::BLOCK_BYTES,
            Self::Float8E5M2Fnuz => F8_E5M2_FNUZ::BLOCK_BYTES,
            Self::Uint4 => U4::BLOCK_BYTES,
            Self::Int4 => I4::BLOCK_BYTES,
            Self::Float4E2M1 => F4_E2M1::BLOCK_BYTES,
        };
        Some(bytes)
    }
}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// `0` is ONNX `UNDEFINED`; the other probes lie outside `1..=23`,
    /// including the extremes of `i32`.
    #[test]
    fn undefined_and_out_of_range_rejected() {
        for id in [0, 24, -1, i32::MIN, i32::MAX] {
            assert!(OnnxDataType::from_i32(id).is_none(), "id {id} accepted");
        }
    }

    /// Every ID in `1..=23` decodes, and no two IDs decode to the same
    /// variant.
    #[test]
    fn full_range_maps() {
        for id in 1..=23 {
            let decoded = OnnxDataType::from_i32(id);
            assert!(decoded.is_some(), "id {id} unmapped");
            for other in (id + 1)..=23 {
                assert_ne!(
                    decoded,
                    OnnxDataType::from_i32(other),
                    "ids {id} and {other} decode to the same variant"
                );
            }
        }
    }

    /// `String` is the only variant without a block width; a few concrete
    /// widths are pinned as well.
    #[test]
    fn string_has_no_block_width_but_others_do() {
        assert_eq!(OnnxDataType::String.block_bytes(), None);
        assert_eq!(OnnxDataType::Float.block_bytes(), Some(4));
        assert_eq!(OnnxDataType::Double.block_bytes(), Some(8));
        assert_eq!(OnnxDataType::Int4.block_bytes(), Some(1));

        for id in 1..=23 {
            if let Some(dtype) = OnnxDataType::from_i32(id) {
                assert_eq!(
                    dtype.block_bytes().is_none(),
                    dtype == OnnxDataType::String,
                    "unexpected block width presence for id {id}"
                );
            }
        }
    }
}