// serde_yaml/value/tagged.rs

1use crate::value::de::{MapDeserializer, MapRefDeserializer, SeqDeserializer, SeqRefDeserializer};
2use crate::value::Value;
3use crate::Error;
4use serde::de::value::{BorrowedStrDeserializer, StrDeserializer};
5use serde::de::{
6    Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error as _, VariantAccess, Visitor,
7};
8use serde::forward_to_deserialize_any;
9use serde::ser::{Serialize, SerializeMap, Serializer};
10use std::cmp::Ordering;
11use std::fmt::{self, Debug, Display};
12use std::hash::{Hash, Hasher};
13use std::mem;
14
/// The tag half of YAML's `!Tag` syntax, which serde_yaml uses to represent
/// enum variants.
///
/// A tag compares equal to the same name written with or without a leading
/// `!`. The example on [`TaggedValue`] shows how tagged values are
/// deserialized.
#[derive(Clone)]
pub struct Tag {
    /// The tag text as it was supplied; a leading `!` may or may not be
    /// present.
    pub(crate) string: String,
}
23
24/// A `Tag` + `Value` representing a tagged YAML scalar, sequence, or mapping.
25///
26/// ```
27/// use serde_yaml::value::TaggedValue;
28/// use std::collections::BTreeMap;
29///
30/// let yaml = "
31///     scalar: !Thing x
32///     sequence_flow: !Thing [first]
33///     sequence_block: !Thing
34///       - first
35///     mapping_flow: !Thing {k: v}
36///     mapping_block: !Thing
37///       k: v
38/// ";
39///
40/// let data: BTreeMap<String, TaggedValue> = serde_yaml::from_str(yaml).unwrap();
41/// assert!(data["scalar"].tag == "Thing");
42/// assert!(data["sequence_flow"].tag == "Thing");
43/// assert!(data["sequence_block"].tag == "Thing");
44/// assert!(data["mapping_flow"].tag == "Thing");
45/// assert!(data["mapping_block"].tag == "Thing");
46///
47/// // The leading '!' in tags are not significant. The following is also true.
48/// assert!(data["scalar"].tag == "!Thing");
49/// ```
50#[derive(Clone, PartialEq, PartialOrd, Hash, Debug)]
51pub struct TaggedValue {
52    #[allow(missing_docs)]
53    pub tag: Tag,
54    #[allow(missing_docs)]
55    pub value: Value,
56}
57
58impl Tag {
59    /// Create tag.
60    ///
61    /// The leading '!' is not significant. It may be provided, but does not
62    /// have to be. The following are equivalent:
63    ///
64    /// ```
65    /// use serde_yaml::value::Tag;
66    ///
67    /// assert_eq!(Tag::new("!Thing"), Tag::new("Thing"));
68    ///
69    /// let tag = Tag::new("Thing");
70    /// assert!(tag == "Thing");
71    /// assert!(tag == "!Thing");
72    /// assert!(tag.to_string() == "!Thing");
73    ///
74    /// let tag = Tag::new("!Thing");
75    /// assert!(tag == "Thing");
76    /// assert!(tag == "!Thing");
77    /// assert!(tag.to_string() == "!Thing");
78    /// ```
79    ///
80    /// Such a tag would serialize to `!Thing` in YAML regardless of whether a
81    /// '!' was included in the call to `Tag::new`.
82    ///
83    /// # Panics
84    ///
85    /// Panics if `string.is_empty()`. There is no syntax in YAML for an empty
86    /// tag.
87    pub fn new(string: impl Into<String>) -> Self {
88        let tag: String = string.into();
89        assert!(!tag.is_empty(), "empty YAML tag is not allowed");
90        Tag { string: tag }
91    }
92}
93
94impl Value {
95    pub(crate) fn untag(self) -> Self {
96        let mut cur = self;
97        while let Value::Tagged(tagged) = cur {
98            cur = tagged.value;
99        }
100        cur
101    }
102
103    pub(crate) fn untag_ref(&self) -> &Self {
104        let mut cur = self;
105        while let Value::Tagged(tagged) = cur {
106            cur = &tagged.value;
107        }
108        cur
109    }
110
111    pub(crate) fn untag_mut(&mut self) -> &mut Self {
112        let mut cur = self;
113        while let Value::Tagged(tagged) = cur {
114            cur = &mut tagged.value;
115        }
116        cur
117    }
118}
119
/// Returns `maybe_banged` with one leading `'!'` removed.
///
/// The input is returned unchanged when it has no leading `'!'`. It is also
/// returned unchanged when it is exactly `"!"`, so a non-empty input never
/// produces an empty result.
pub(crate) fn nobang(maybe_banged: &str) -> &str {
    maybe_banged
        .strip_prefix('!')
        .filter(|rest| !rest.is_empty())
        .unwrap_or(maybe_banged)
}
126
127impl Eq for Tag {}
128
129impl PartialEq for Tag {
130    fn eq(&self, other: &Tag) -> bool {
131        PartialEq::eq(nobang(&self.string), nobang(&other.string))
132    }
133}
134
135impl<T> PartialEq<T> for Tag
136where
137    T: ?Sized + AsRef<str>,
138{
139    fn eq(&self, other: &T) -> bool {
140        PartialEq::eq(nobang(&self.string), nobang(other.as_ref()))
141    }
142}
143
144impl Ord for Tag {
145    fn cmp(&self, other: &Self) -> Ordering {
146        Ord::cmp(nobang(&self.string), nobang(&other.string))
147    }
148}
149
150impl PartialOrd for Tag {
151    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
152        Some(self.cmp(other))
153    }
154}
155
156impl Hash for Tag {
157    fn hash<H: Hasher>(&self, hasher: &mut H) {
158        nobang(&self.string).hash(hasher);
159    }
160}
161
162impl Display for Tag {
163    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
164        write!(formatter, "!{}", nobang(&self.string))
165    }
166}
167
168impl Debug for Tag {
169    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
170        Display::fmt(self, formatter)
171    }
172}
173
174impl Serialize for TaggedValue {
175    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
176    where
177        S: Serializer,
178    {
179        struct SerializeTag<'a>(&'a Tag);
180
181        impl<'a> Serialize for SerializeTag<'a> {
182            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
183            where
184                S: Serializer,
185            {
186                serializer.collect_str(self.0)
187            }
188        }
189
190        let mut map = serializer.serialize_map(Some(1))?;
191        map.serialize_entry(&SerializeTag(&self.tag), &self.value)?;
192        map.end()
193    }
194}
195
196impl<'de> Deserialize<'de> for TaggedValue {
197    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
198    where
199        D: Deserializer<'de>,
200    {
201        struct TaggedValueVisitor;
202
203        impl<'de> Visitor<'de> for TaggedValueVisitor {
204            type Value = TaggedValue;
205
206            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
207                formatter.write_str("a YAML value with a !Tag")
208            }
209
210            fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
211            where
212                A: EnumAccess<'de>,
213            {
214                let (tag, contents) = data.variant_seed(TagStringVisitor)?;
215                let value = contents.newtype_variant()?;
216                Ok(TaggedValue { tag, value })
217            }
218        }
219
220        deserializer.deserialize_any(TaggedValueVisitor)
221    }
222}
223
224impl<'de> Deserializer<'de> for TaggedValue {
225    type Error = Error;
226
227    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
228    where
229        V: Visitor<'de>,
230    {
231        visitor.visit_enum(self)
232    }
233
234    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Error>
235    where
236        V: Visitor<'de>,
237    {
238        drop(self);
239        visitor.visit_unit()
240    }
241
242    forward_to_deserialize_any! {
243        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
244        byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct
245        map struct enum identifier
246    }
247}
248
249impl<'de> EnumAccess<'de> for TaggedValue {
250    type Error = Error;
251    type Variant = Value;
252
253    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Error>
254    where
255        V: DeserializeSeed<'de>,
256    {
257        let tag = StrDeserializer::<Error>::new(nobang(&self.tag.string));
258        let value = seed.deserialize(tag)?;
259        Ok((value, self.value))
260    }
261}
262
263impl<'de> VariantAccess<'de> for Value {
264    type Error = Error;
265
266    fn unit_variant(self) -> Result<(), Error> {
267        Deserialize::deserialize(self)
268    }
269
270    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Error>
271    where
272        T: DeserializeSeed<'de>,
273    {
274        seed.deserialize(self)
275    }
276
277    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Error>
278    where
279        V: Visitor<'de>,
280    {
281        if let Value::Sequence(v) = self {
282            Deserializer::deserialize_any(SeqDeserializer::new(v), visitor)
283        } else {
284            Err(Error::invalid_type(self.unexpected(), &"tuple variant"))
285        }
286    }
287
288    fn struct_variant<V>(
289        self,
290        _fields: &'static [&'static str],
291        visitor: V,
292    ) -> Result<V::Value, Error>
293    where
294        V: Visitor<'de>,
295    {
296        if let Value::Mapping(v) = self {
297            Deserializer::deserialize_any(MapDeserializer::new(v), visitor)
298        } else {
299            Err(Error::invalid_type(self.unexpected(), &"struct variant"))
300        }
301    }
302}
303
304impl<'de> Deserializer<'de> for &'de TaggedValue {
305    type Error = Error;
306
307    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
308    where
309        V: Visitor<'de>,
310    {
311        visitor.visit_enum(self)
312    }
313
314    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Error>
315    where
316        V: Visitor<'de>,
317    {
318        visitor.visit_unit()
319    }
320
321    forward_to_deserialize_any! {
322        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
323        byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct
324        map struct enum identifier
325    }
326}
327
328impl<'de> EnumAccess<'de> for &'de TaggedValue {
329    type Error = Error;
330    type Variant = &'de Value;
331
332    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Error>
333    where
334        V: DeserializeSeed<'de>,
335    {
336        let tag = BorrowedStrDeserializer::<Error>::new(nobang(&self.tag.string));
337        let value = seed.deserialize(tag)?;
338        Ok((value, &self.value))
339    }
340}
341
342impl<'de> VariantAccess<'de> for &'de Value {
343    type Error = Error;
344
345    fn unit_variant(self) -> Result<(), Error> {
346        Deserialize::deserialize(self)
347    }
348
349    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Error>
350    where
351        T: DeserializeSeed<'de>,
352    {
353        seed.deserialize(self)
354    }
355
356    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Error>
357    where
358        V: Visitor<'de>,
359    {
360        if let Value::Sequence(v) = self {
361            Deserializer::deserialize_any(SeqRefDeserializer::new(v), visitor)
362        } else {
363            Err(Error::invalid_type(self.unexpected(), &"tuple variant"))
364        }
365    }
366
367    fn struct_variant<V>(
368        self,
369        _fields: &'static [&'static str],
370        visitor: V,
371    ) -> Result<V::Value, Error>
372    where
373        V: Visitor<'de>,
374    {
375        if let Value::Mapping(v) = self {
376            Deserializer::deserialize_any(MapRefDeserializer::new(v), visitor)
377        } else {
378            Err(Error::invalid_type(self.unexpected(), &"struct variant"))
379        }
380    }
381}
382
383pub(crate) struct TagStringVisitor;
384
385impl<'de> Visitor<'de> for TagStringVisitor {
386    type Value = Tag;
387
388    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
389        formatter.write_str("a YAML tag string")
390    }
391
392    fn visit_str<E>(self, string: &str) -> Result<Self::Value, E>
393    where
394        E: serde::de::Error,
395    {
396        self.visit_string(string.to_owned())
397    }
398
399    fn visit_string<E>(self, string: String) -> Result<Self::Value, E>
400    where
401        E: serde::de::Error,
402    {
403        if string.is_empty() {
404            return Err(E::custom("empty YAML tag is not allowed"));
405        }
406        Ok(Tag::new(string))
407    }
408}
409
410impl<'de> DeserializeSeed<'de> for TagStringVisitor {
411    type Value = Tag;
412
413    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
414    where
415        D: Deserializer<'de>,
416    {
417        deserializer.deserialize_string(self)
418    }
419}
420
/// Outcome of [`check_for_tag`]: either a tag name, or the complete text of a
/// value that did not look like a tag.
pub(crate) enum MaybeTag<T> {
    /// The text written after a lone leading `"!"`.
    Tag(String),
    /// Not a tag; carries the value's full formatted text.
    NotTag(T),
}

/// Formats `value` with `Display` and reports whether the output looks like a
/// [`Tag`].
///
/// `Tag`'s `Display` impl (`write!(f, "!{}", name)`) emits a lone `"!"`
/// followed by one write of the name. That exact sequence of non-empty writes
/// yields `MaybeTag::Tag(name)`. Any other output yields `MaybeTag::NotTag`
/// with the complete formatted text, unchanged. This includes a single write of
/// `"!name"`, and a `"!"` and name followed by further writes.
///
/// # Panics
///
/// Panics if `value`'s `Display` implementation returns an error, the same way
/// `ToString::to_string` does. The sink used here never fails on its own.
pub(crate) fn check_for_tag<T>(value: &T) -> MaybeTag<String>
where
    T: ?Sized + Display,
{
    /// Incremental classifier fed by successive `write_str` calls.
    enum CheckForTag {
        /// Nothing written yet.
        Empty,
        /// Only a lone `"!"` has been written.
        Bang,
        /// `"!"` followed by exactly one more chunk: the candidate tag name
        /// (stored without the `"!"`).
        Tag(String),
        /// Output that cannot be a tag; holds everything written so far.
        NotTag(String),
    }

    impl fmt::Write for CheckForTag {
        fn write_str(&mut self, s: &str) -> fmt::Result {
            // Empty writes carry no information. Ignore them so they cannot
            // break up the "!" + name pattern.
            if s.is_empty() {
                return Ok(());
            }
            match self {
                CheckForTag::Empty => {
                    if s == "!" {
                        *self = CheckForTag::Bang;
                    } else {
                        *self = CheckForTag::NotTag(s.to_owned());
                    }
                }
                CheckForTag::Bang => {
                    *self = CheckForTag::Tag(s.to_owned());
                }
                CheckForTag::Tag(name) => {
                    // More output after the name means this was not a tag.
                    // Restore the "!" consumed on the way into `Bang` so the
                    // NotTag text matches what was actually formatted.
                    let mut text = mem::take(name);
                    text.insert(0, '!');
                    text.push_str(s);
                    *self = CheckForTag::NotTag(text);
                }
                CheckForTag::NotTag(text) => {
                    text.push_str(s);
                }
            }
            Ok(())
        }
    }

    let mut check_for_tag = CheckForTag::Empty;
    fmt::write(&mut check_for_tag, format_args!("{}", value))
        .expect("a Display implementation returned an error unexpectedly");
    match check_for_tag {
        CheckForTag::Empty => MaybeTag::NotTag(String::new()),
        CheckForTag::Bang => MaybeTag::NotTag("!".to_owned()),
        CheckForTag::Tag(string) => MaybeTag::Tag(string),
        CheckForTag::NotTag(string) => MaybeTag::NotTag(string),
    }
}