1use prism::tensor::dtype::{
9 Dtype, BF16, BOOL, C128, C64, F16, F32, F4_E2M1, F64, F8_E4M3, F8_E4M3_FNUZ, F8_E5M2,
10 F8_E5M2_FNUZ, I16, I32, I4, I64, I8, U16, U32, U4, U64, U8,
11};
12
/// Element type of an ONNX tensor, mirroring the ONNX `TensorProto.DataType`
/// enumeration.
///
/// Numeric ids are decoded with [`OnnxDataType::from_i32`]. Id `0` (`UNDEFINED`)
/// has no variant and is rejected there. Variants are declared in ascending id
/// order, but the Rust discriminant (an `as` cast) starts at `0` and is therefore
/// *not* the ONNX id, so always go through `from_i32` rather than casting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OnnxDataType {
    /// `FLOAT` (id 1).
    Float,
    /// `UINT8` (id 2).
    Uint8,
    /// `INT8` (id 3).
    Int8,
    /// `UINT16` (id 4).
    Uint16,
    /// `INT16` (id 5).
    Int16,
    /// `INT32` (id 6).
    Int32,
    /// `INT64` (id 7).
    Int64,
    /// `STRING` (id 8). Has no fixed block width (`block_bytes` returns `None`).
    String,
    /// `BOOL` (id 9).
    Bool,
    /// `FLOAT16` (id 10).
    Float16,
    /// `DOUBLE` (id 11).
    Double,
    /// `UINT32` (id 12).
    Uint32,
    /// `UINT64` (id 13).
    Uint64,
    /// `COMPLEX64` (id 14).
    Complex64,
    /// `COMPLEX128` (id 15).
    Complex128,
    /// `BFLOAT16` (id 16).
    Bfloat16,
    /// `FLOAT8E4M3FN` (id 17).
    Float8E4M3Fn,
    /// `FLOAT8E4M3FNUZ` (id 18).
    Float8E4M3Fnuz,
    /// `FLOAT8E5M2` (id 19).
    Float8E5M2,
    /// `FLOAT8E5M2FNUZ` (id 20).
    Float8E5M2Fnuz,
    /// `UINT4` (id 21).
    Uint4,
    /// `INT4` (id 22).
    Int4,
    /// `FLOAT4E2M1` (id 23).
    Float4E2M1,
}
64
65impl OnnxDataType {
66 #[must_use]
69 pub const fn from_i32(id: i32) -> Option<Self> {
70 Some(match id {
71 1 => Self::Float,
72 2 => Self::Uint8,
73 3 => Self::Int8,
74 4 => Self::Uint16,
75 5 => Self::Int16,
76 6 => Self::Int32,
77 7 => Self::Int64,
78 8 => Self::String,
79 9 => Self::Bool,
80 10 => Self::Float16,
81 11 => Self::Double,
82 12 => Self::Uint32,
83 13 => Self::Uint64,
84 14 => Self::Complex64,
85 15 => Self::Complex128,
86 16 => Self::Bfloat16,
87 17 => Self::Float8E4M3Fn,
88 18 => Self::Float8E4M3Fnuz,
89 19 => Self::Float8E5M2,
90 20 => Self::Float8E5M2Fnuz,
91 21 => Self::Uint4,
92 22 => Self::Int4,
93 23 => Self::Float4E2M1,
94 _ => return None,
95 })
96 }
97
98 #[must_use]
101 pub const fn block_bytes(self) -> Option<usize> {
102 Some(match self {
103 Self::Float => F32::BLOCK_BYTES,
104 Self::Uint8 => U8::BLOCK_BYTES,
105 Self::Int8 => I8::BLOCK_BYTES,
106 Self::Uint16 => U16::BLOCK_BYTES,
107 Self::Int16 => I16::BLOCK_BYTES,
108 Self::Int32 => I32::BLOCK_BYTES,
109 Self::Int64 => I64::BLOCK_BYTES,
110 Self::String => return None,
111 Self::Bool => BOOL::BLOCK_BYTES,
112 Self::Float16 => F16::BLOCK_BYTES,
113 Self::Double => F64::BLOCK_BYTES,
114 Self::Uint32 => U32::BLOCK_BYTES,
115 Self::Uint64 => U64::BLOCK_BYTES,
116 Self::Complex64 => C64::BLOCK_BYTES,
117 Self::Complex128 => C128::BLOCK_BYTES,
118 Self::Bfloat16 => BF16::BLOCK_BYTES,
119 Self::Float8E4M3Fn => F8_E4M3::BLOCK_BYTES,
120 Self::Float8E4M3Fnuz => F8_E4M3_FNUZ::BLOCK_BYTES,
121 Self::Float8E5M2 => F8_E5M2::BLOCK_BYTES,
122 Self::Float8E5M2Fnuz => F8_E5M2_FNUZ::BLOCK_BYTES,
123 Self::Uint4 => U4::BLOCK_BYTES,
124 Self::Int4 => I4::BLOCK_BYTES,
125 Self::Float4E2M1 => F4_E2M1::BLOCK_BYTES,
126 })
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest ONNX `TensorProto.DataType` id that `from_i32` accepts.
    const FIRST_ID: i32 = 1;
    /// Largest ONNX `TensorProto.DataType` id that `from_i32` accepts.
    const LAST_ID: i32 = 23;

    #[test]
    fn undefined_and_out_of_range_rejected() {
        // 0 is ONNX `UNDEFINED`; everything outside FIRST_ID..=LAST_ID is unknown.
        for id in [0, LAST_ID + 1, -1, i32::MIN, i32::MAX] {
            assert!(OnnxDataType::from_i32(id).is_none(), "id {id} should be rejected");
        }
    }

    #[test]
    fn full_range_maps() {
        for id in FIRST_ID..=LAST_ID {
            assert!(OnnxDataType::from_i32(id).is_some(), "id {id} unmapped");
        }
    }

    #[test]
    fn distinct_ids_map_to_distinct_variants() {
        // Catches a copy-pasted match arm (two ids -> one variant), which
        // `full_range_maps` alone cannot detect.
        let mut seen: Vec<(i32, OnnxDataType)> = Vec::new();
        for id in FIRST_ID..=LAST_ID {
            let ty = OnnxDataType::from_i32(id).expect("id in range");
            if let Some((prev, _)) = seen.iter().find(|(_, t)| *t == ty) {
                panic!("ids {prev} and {id} both map to {ty:?}");
            }
            seen.push((id, ty));
        }
    }

    #[test]
    fn known_ids_decode_to_expected_variants() {
        assert_eq!(OnnxDataType::from_i32(1), Some(OnnxDataType::Float));
        assert_eq!(OnnxDataType::from_i32(8), Some(OnnxDataType::String));
        assert_eq!(OnnxDataType::from_i32(11), Some(OnnxDataType::Double));
        assert_eq!(OnnxDataType::from_i32(16), Some(OnnxDataType::Bfloat16));
        assert_eq!(OnnxDataType::from_i32(17), Some(OnnxDataType::Float8E4M3Fn));
        assert_eq!(OnnxDataType::from_i32(22), Some(OnnxDataType::Int4));
        assert_eq!(OnnxDataType::from_i32(23), Some(OnnxDataType::Float4E2M1));
    }

    #[test]
    fn string_has_no_block_width_but_others_do() {
        // Check every modelled type, not just a sample, so a newly added variant
        // that forgets its block width is caught.
        for id in FIRST_ID..=LAST_ID {
            let ty = OnnxDataType::from_i32(id).expect("id in range");
            let width = ty.block_bytes();
            if ty == OnnxDataType::String {
                assert_eq!(width, None);
            } else {
                assert!(matches!(width, Some(n) if n > 0), "{ty:?} has no block width");
            }
        }
        assert_eq!(OnnxDataType::Float.block_bytes(), Some(4));
        assert_eq!(OnnxDataType::Double.block_bytes(), Some(8));
        assert_eq!(OnnxDataType::Int4.block_bytes(), Some(1));
    }
}