1use serde::de::{self, Deserialize};
2use serde::ser::{self, Serialize};
3
4use crate::ByteUnit;
5
6impl<'de> Deserialize<'de> for ByteUnit {
7 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
8 where D: serde::Deserializer<'de>
9 {
10 if deserializer.is_human_readable() {
11 deserializer.deserialize_any(Visitor)
13 } else {
14 deserializer.deserialize_u64(Visitor)
16 }
17 }
18}
19
20macro_rules! visit_integer_fn {
21 ($name:ident: $T:ty) => (
22 fn $name<E: de::Error>(self, v: $T) -> Result<Self::Value, E> {
23 Ok(v.into())
24 }
25 )
26}
27
28struct Visitor;
29
30impl<'de> de::Visitor<'de> for Visitor {
31 type Value = ByteUnit;
32
33 fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
34 formatter.write_str("a byte unit as an integer or string")
35 }
36
37 visit_integer_fn!(visit_i8: i8);
38 visit_integer_fn!(visit_i16: i16);
39 visit_integer_fn!(visit_i32: i32);
40 visit_integer_fn!(visit_i64: i64);
41 visit_integer_fn!(visit_i128: i128);
42
43 visit_integer_fn!(visit_u8: u8);
44 visit_integer_fn!(visit_u16: u16);
45 visit_integer_fn!(visit_u32: u32);
46 visit_integer_fn!(visit_u64: u64);
47 visit_integer_fn!(visit_u128: u128);
48
49 fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
50 v.parse().map_err(|_| E::invalid_value(de::Unexpected::Str(v), &"byte unit string"))
51 }
52}
53
54impl Serialize for ByteUnit {
55 fn serialize<S: ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
56 serializer.serialize_u64(self.as_u64())
57 }
58}
59
60#[cfg(test)]
61mod serde_tests {
62 use serde_test::{assert_de_tokens, assert_ser_tokens, Configure, Token};
63 use crate::ByteUnit;
64
65 #[test]
66 fn test_de() {
67 let half_mib = ByteUnit::Kibibyte(512).readable();
68 assert_de_tokens(&half_mib, &[Token::Str("512 kib")]);
69 assert_de_tokens(&half_mib, &[Token::Str("512 KiB")]);
70 assert_de_tokens(&half_mib, &[Token::Str("512KiB")]);
71 assert_de_tokens(&half_mib, &[Token::Str("524288")]);
72 assert_de_tokens(&half_mib, &[Token::U32(524288)]);
73 assert_de_tokens(&half_mib, &[Token::U64(524288)]);
74 assert_de_tokens(&half_mib, &[Token::I32(524288)]);
75 assert_de_tokens(&half_mib, &[Token::I64(524288)]);
76
77 let one_mib = ByteUnit::Mebibyte(1).readable();
78 assert_de_tokens(&one_mib, &[Token::Str("1 mib")]);
79 assert_de_tokens(&one_mib, &[Token::Str("1 MiB")]);
80 assert_de_tokens(&one_mib, &[Token::Str("1mib")]);
81
82 let zero = ByteUnit::Byte(0).readable();
83 assert_de_tokens(&zero, &[Token::Str("0")]);
84 assert_de_tokens(&zero, &[Token::Str("0 B")]);
85 assert_de_tokens(&zero, &[Token::U32(0)]);
86 assert_de_tokens(&zero, &[Token::U64(0)]);
87 assert_de_tokens(&zero, &[Token::I32(-34)]);
88 assert_de_tokens(&zero, &[Token::I64(-2483)]);
89 }
90
91 #[test]
92 fn test_de_compact() {
93 let half_mib = ByteUnit::Kibibyte(512).compact();
94 assert_de_tokens(&half_mib, &[Token::U32(524288)]);
95 assert_de_tokens(&half_mib, &[Token::U64(524288)]);
96 assert_de_tokens(&half_mib, &[Token::I32(524288)]);
97 assert_de_tokens(&half_mib, &[Token::I64(524288)]);
98
99 let one_mib = ByteUnit::Mebibyte(1).compact();
100 assert_de_tokens(&one_mib, &[Token::U32(1024 * 1024)]);
101
102 let zero = ByteUnit::Byte(0).compact();
103 assert_de_tokens(&zero, &[Token::U32(0)]);
104 assert_de_tokens(&zero, &[Token::U64(0)]);
105 assert_de_tokens(&zero, &[Token::I32(-34)]);
106 assert_de_tokens(&zero, &[Token::I64(-2483)]);
107 }
108
109 #[test]
110 fn test_ser_compact() {
111 let half_mib = ByteUnit::Kibibyte(512).compact();
112 assert_ser_tokens(&half_mib, &[Token::U64(512 << 10)]);
113
114 let ten_bytes = ByteUnit::Byte(10).compact();
115 assert_ser_tokens(&ten_bytes, &[Token::U64(10)]);
116
117 let zero = ByteUnit::Byte(0).compact();
118 assert_de_tokens(&zero, &[Token::U64(0)]);
119 }
120
121 #[test]
122 fn test_ser_readable() {
123 let half_mib = ByteUnit::Kibibyte(512).readable();
125 assert_ser_tokens(&half_mib, &[Token::U64(512 << 10)]);
126
127 let ten_bytes = ByteUnit::Byte(10).readable();
128 assert_ser_tokens(&ten_bytes, &[Token::U64(10)]);
129
130 let zero = ByteUnit::Byte(0).readable();
131 assert_de_tokens(&zero, &[Token::U64(0)]);
132 }
133}