Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for [u8; N] #28

Merged
merged 2 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ edition = "2018"
keywords = ["serde", "serialization", "no_std", "bytes"]
license = "MIT OR Apache-2.0"
repository = "https://github.com/serde-rs/bytes"
rust-version = "1.31"
rust-version = "1.53"
sgued marked this conversation as resolved.
Show resolved Hide resolved

[features]
default = ["std"]
Expand Down
225 changes: 225 additions & 0 deletions src/bytearray.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
use crate::Bytes;
use core::borrow::{Borrow, BorrowMut};
use core::cmp::Ordering;
use core::convert::TryInto;
use core::fmt::{self, Debug};
use core::hash::{Hash, Hasher};
use core::ops::{Deref, DerefMut};

use serde::de::{Deserialize, Deserializer, Error, SeqAccess, Visitor};
use serde::ser::{Serialize, Serializer};

/// Wrapper around `[u8; N]` to serialize and deserialize efficiently.
///
/// ```
/// use std::collections::HashMap;
/// use std::io;
///
/// use serde_bytes::ByteArray;
///
/// fn deserialize_bytearrays() -> bincode::Result<()> {
/// let example_data = [
/// 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 116,
/// 119, 111, 1, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 111, 110, 101
/// ];
///
/// let map: HashMap<u32, ByteArray<3>> = bincode::deserialize(&example_data[..])?;
///
/// println!("{:?}", map);
///
/// Ok(())
/// }
/// #
/// # fn main() {
/// # deserialize_bytearrays().unwrap();
/// # }
/// ```
#[derive(Clone, Eq, Ord)]
pub struct ByteArray<const N: usize> {
bytes: [u8; N],
}

impl<const N: usize> ByteArray<N> {
/// Transform an [array](https://doc.rust-lang.org/stable/std/primitive.array.html) to the equivalent `ByteArray`
pub fn new(bytes: [u8; N]) -> Self {
Self { bytes }
}

/// Wrap existing bytes into a `ByteArray`
pub fn from<T: Into<[u8; N]>>(bytes: T) -> Self {
Self {
bytes: bytes.into(),
}
}

/// Unwraps the byte array underlying this `ByteArray`
pub fn into_array(self) -> [u8; N] {
self.bytes
}
}

impl<const N: usize> Debug for ByteArray<N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Debug::fmt(&self.bytes, f)
}
}

impl<const N: usize> AsRef<[u8; N]> for ByteArray<N> {
fn as_ref(&self) -> &[u8; N] {
&self.bytes
}
}
impl<const N: usize> AsMut<[u8; N]> for ByteArray<N> {
fn as_mut(&mut self) -> &mut [u8; N] {
&mut self.bytes
}
}

impl<const N: usize> Borrow<[u8; N]> for ByteArray<N> {
fn borrow(&self) -> &[u8; N] {
&self.bytes
}
}
impl<const N: usize> BorrowMut<[u8; N]> for ByteArray<N> {
fn borrow_mut(&mut self) -> &mut [u8; N] {
&mut self.bytes
}
}

impl<const N: usize> Deref for ByteArray<N> {
type Target = [u8; N];

fn deref(&self) -> &Self::Target {
&self.bytes
}
}

impl<const N: usize> DerefMut for ByteArray<N> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.bytes
}
}

impl<const N: usize> Borrow<Bytes> for ByteArray<N> {
fn borrow(&self) -> &Bytes {
Bytes::new(&self.bytes)
}
}

impl<const N: usize> BorrowMut<Bytes> for ByteArray<N> {
fn borrow_mut(&mut self) -> &mut Bytes {
unsafe { &mut *(&mut self.bytes as &mut [u8] as *mut [u8] as *mut Bytes) }
}
}

impl<Rhs, const N: usize> PartialEq<Rhs> for ByteArray<N>
where
Rhs: ?Sized + Borrow<[u8; N]>,
{
fn eq(&self, other: &Rhs) -> bool {
self.as_ref().eq(other.borrow())
}
}

impl<Rhs, const N: usize> PartialOrd<Rhs> for ByteArray<N>
where
Rhs: ?Sized + Borrow<[u8; N]>,
{
fn partial_cmp(&self, other: &Rhs) -> Option<Ordering> {
self.as_ref().partial_cmp(other.borrow())
}
}

impl<const N: usize> Hash for ByteArray<N> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.bytes.hash(state);
}
}

impl<const N: usize> IntoIterator for ByteArray<N> {
type Item = u8;
type IntoIter = <[u8; N] as IntoIterator>::IntoIter;

fn into_iter(self) -> Self::IntoIter {
IntoIterator::into_iter(self.bytes)
}
}

impl<'a, const N: usize> IntoIterator for &'a ByteArray<N> {
type Item = &'a u8;
type IntoIter = <&'a [u8; N] as IntoIterator>::IntoIter;

fn into_iter(self) -> Self::IntoIter {
self.bytes.iter()
}
}

impl<'a, const N: usize> IntoIterator for &'a mut ByteArray<N> {
type Item = &'a mut u8;
type IntoIter = <&'a mut [u8; N] as IntoIterator>::IntoIter;

fn into_iter(self) -> Self::IntoIter {
self.bytes.iter_mut()
}
}

impl<const N: usize> Serialize for ByteArray<N> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(&self.bytes)
}
}

struct ByteArrayVisitor<const N: usize>;

impl<'de, const N: usize> Visitor<'de> for ByteArrayVisitor<N> {
type Value = ByteArray<N>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "a byte array of length {}", N)
}

fn visit_seq<V>(self, mut seq: V) -> Result<ByteArray<N>, V::Error>
where
V: SeqAccess<'de>,
{
let mut bytes = [0; N];

for (idx, byte) in bytes.iter_mut().enumerate() {
*byte = seq
.next_element()?
.ok_or_else(|| V::Error::invalid_length(idx, &self))?;
}

Ok(ByteArray::from(bytes))
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<ByteArray<N>, E>
where
E: Error,
{
Ok(ByteArray {
bytes: v
.try_into()
.map_err(|_| E::invalid_length(v.len(), &self))?,
})
}

fn visit_str<E>(self, v: &str) -> Result<ByteArray<N>, E>
where
E: Error,
{
self.visit_bytes(v.as_bytes())
}
}

impl<'de, const N: usize> Deserialize<'de> for ByteArray<N> {
fn deserialize<D>(deserializer: D) -> Result<ByteArray<N>, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_bytes(ByteArrayVisitor::<N>)
}
}
22 changes: 21 additions & 1 deletion src/de.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::Bytes;
use crate::{ByteArray, Bytes};
use core::fmt;
use core::marker::PhantomData;
use serde::de::{Error, Visitor};
Expand Down Expand Up @@ -63,6 +63,26 @@ impl<'de: 'a, 'a> Deserialize<'de> for &'a Bytes {
}
}

impl<'de, const N: usize> Deserialize<'de> for [u8; N] {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let arr: ByteArray<N> = serde::Deserialize::deserialize(deserializer)?;
Ok(*arr)
}
}

impl<'de, const N: usize> Deserialize<'de> for ByteArray<N> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
// Via the serde::Deserialize impl for ByteArray
serde::Deserialize::deserialize(deserializer)
}
}

#[cfg(any(feature = "std", feature = "alloc"))]
impl<'de> Deserialize<'de> for ByteBuf {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
Expand Down
11 changes: 11 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
//!
//! #[serde(with = "serde_bytes")]
//! byte_buf: Vec<u8>,
//!
//! #[serde(with = "serde_bytes")]
//! byte_array: [u8; 314],
//! }
//! ```

Expand All @@ -36,6 +39,7 @@
clippy::needless_doctest_main
)]

mod bytearray;
mod bytes;
mod de;
mod ser;
Expand All @@ -51,6 +55,7 @@ use serde::Deserializer;

use serde::Serializer;

pub use crate::bytearray::ByteArray;
pub use crate::bytes::Bytes;
pub use crate::de::Deserialize;
pub use crate::ser::Serialize;
Expand All @@ -76,6 +81,9 @@ pub use crate::bytebuf::ByteBuf;
///
/// #[serde(with = "serde_bytes")]
/// byte_buf: Vec<u8>,
///
/// #[serde(with = "serde_bytes")]
/// byte_array: [u8; 314],
/// }
/// ```
pub fn serialize<T, S>(bytes: &T, serializer: S) -> Result<S::Ok, S::Error>
Expand All @@ -101,6 +109,9 @@ where
/// struct Packet {
/// #[serde(with = "serde_bytes")]
/// payload: Vec<u8>,
///
/// #[serde(with = "serde_bytes")]
/// byte_array: [u8; 314],
/// }
/// ```
#[cfg(any(feature = "std", feature = "alloc"))]
Expand Down
20 changes: 19 additions & 1 deletion src/ser.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::Bytes;
use crate::{ByteArray, Bytes};
use serde::Serializer;

#[cfg(any(feature = "std", feature = "alloc"))]
Expand Down Expand Up @@ -51,6 +51,24 @@ impl Serialize for Bytes {
}
}

impl<const N: usize> Serialize for [u8; N] {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(self)
}
}

impl<const N: usize> Serialize for ByteArray<N> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(&**self)
}
}

#[cfg(any(feature = "std", feature = "alloc"))]
impl Serialize for ByteBuf {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
Expand Down
Loading