I was reading a recent blog post about reducing memory consumption in librsvg. In it, the author explains that he took an array of unions whose variants had vastly different sizes, and created a much more compact representation. The article is really interesting, and the implementation reduces the memory footprint of librsvg by a significant margin in some corner cases.
In Rust, an enum is more or less a tagged union. If you create an array of tagged unions, and the only thing you need to do is to iterate over it, you are going to waste a lot of space: since you never need random access, a much more compact representation is possible. First of all, you can remove all the padding inside the tagged union itself. Furthermore, if the active variant isn't the biggest one, you can save some additional bytes.
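To make the waste concrete, here is a small Rust sketch (the enum and its variants are mine, not taken from librsvg) showing that an enum takes as much space as its largest variant, plus the tag:

use std::mem::size_of;

// A toy enum whose variants have very different sizes.
#[allow(dead_code)]
enum Value {
    Small(u8),     // 1 byte of payload
    Big([u8; 32]), // 32 bytes of payload
}

fn main() {
    // Every element of a Vec<Value> reserves enough room for Big,
    // even if almost all the elements are Small.
    // On current rustc this prints 33: 32 bytes of payload + 1 byte of tag.
    println!("{}", size_of::<Value>());
}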
The author's implementation works, and it doesn't use any unsafe – which is a good thing – but I felt that it was quite under-engineered. It only works for a variable number of f64 (double in C), instead of any possible type for the variant.
Coming from a C and C++ background, I know that with some clever use of memcpy I would be able to remove all the padding in the unions. So I tried to write a demo in C, then re-wrote it in Rust. librsvg is a Rust library (or “crate” as rustaceans call them), so I wanted to see if that trick was possible in Rust. Rust is a language I started being really interested in more or less a year ago, but I never took the time to start any project in it. I learned a few things along the way, so I thought it would be a good idea to share this experiment.
Version 1: in C using tagged union
The goal is to take an array of unions, make it somehow more compact, and then be able to retrieve the unions in the same order.
The first thing we need is a union, with fields of different sizes to demonstrate the whole transformation.
// A union with 3 variants.
typedef union {
    char array[7];                 // 7 bytes
    float f;                       // 4 bytes
    struct {char x; char y;} pair; // 2 bytes
} Union;
C is a statically typed language, so it needs to allocate (on the stack) enough space for any variant of a Union. This is why the size of Union is the size of the biggest variant, even if the active variant doesn't need that much space.
The processor can efficiently access data in memory only if it is correctly aligned, so the compiler adds some padding to keep the fields correctly aligned. This is usually useful, but in our case, where memory matters more than speed, this is not what we want.
The size of the union is the size of the biggest variant, which is 7 bytes, plus some architecture-dependent padding, for a total of 7 + 1 = 8 bytes on my machine.
In Rust, an enum is a tagged union, so let's create one manually in C from our union.
// A tagged union
typedef struct {
    unsigned char discriminant; // 1 byte
    // architecture-dependent padding
    Union data;                 // 8 bytes
} TaggedUnion;
The discriminant can be encoded with a single byte since there are only 3 variants. Unfortunately there are 3 bytes of padding between the discriminant and the data. The total size is the size of the discriminant + the padding + the size of the data, for a total of 1 + 3 + 8 = 12 bytes.
In the article, the author space-optimized an array of tagged unions. His was a dynamically allocated array, but for simplicity, we are going to use a statically allocated one.
// An array of two tagged unions
// Its size is 2× the size of one tagged union, so 24 bytes
TaggedUnion input[2] = {
    // Create a first union, that uses the variant "f"
    {
        .discriminant = 1, // 0x01 in hexadecimal
        .data.f = 3.0e-18, // 66 5c 5d 22 in hexadecimal (little endian)
    },
    // Create a second union, that uses the variant "pair"
    {
        .discriminant = 2,
        .data.pair = {
            .x = 'x', // 0x78 in hexadecimal
            .y = 'y', // 0x79 in hexadecimal
        },
    },
};
If we display the bytes in the array, we will get
01 00 00 00 66 5c 5d 22 00 00 00 00
02 00 00 00 78 79 00 00 00 00 00 00
You can see the discriminant (01) of the first tagged union, then some padding (the zeroes), the float (66 5c 5d 22), some more padding, the discriminant of the second tagged union (02), some padding, the pair of characters (78 and 79) and the last bytes of padding. As you can guess, this isn't really space efficient. It uses 12 bytes per tagged union, for a total of 24 bytes, even though we only have 4 + 2 bytes of data plus two times 1 byte for the discriminants. That's 16 bytes lost out of 24!
When you saw what I did, you may have immediately thought that it is possible to remove the padding with #pragma pack(1) or __attribute__((packed)). This is true for the padding between the discriminant and the data, but not for the extra space inside the union that the active variant doesn't use (since the variants don't all have the same size).
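As an aside, you can see the same limitation in Rust (which we will switch to in version 2). This is just a quick sketch using the same types as later in this post, with PackedTaggedUnion being a name I made up for it: packing removes the padding between the discriminant and the data, but the union itself stays as big as its largest variant.

use std::mem::size_of;

#[allow(dead_code)]
#[repr(C)]
union Union {
    array: [u8; 7],
    f: f32,
    pair: (u8, u8),
}

// The same tagged union, but packed: no padding between the fields.
#[allow(dead_code)]
#[repr(C, packed)]
struct PackedTaggedUnion {
    discriminant: u8,
    data: Union,
}

fn main() {
    println!("{}", size_of::<Union>());             // 8: still sized for the biggest variant
    println!("{}", size_of::<PackedTaggedUnion>()); // 9: 1 byte of discriminant + 8 bytes of union
}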
The serialization method I chose is to remove all padding, and store both the discriminant and the data in a single contiguous buffer. Just by knowing the discriminant, you know how many bytes you need to read next, and what the type of the data you are reading is. If it's 0, then we have the variant array, which is 7 bytes long. If it's 1, it will be the variant f, a float which is 4 bytes long. And finally if it's 2, it will be the variant pair, a pair of characters, for a total of 2 bytes. By storing the discriminant first, we will know, when reading, what kind of data comes next.
After serialization, since all padding will be removed, we want to get:
01 66 5c 5d 22 02 78 79
To do so, we are going to copy each useful byte into a buffer. Once again, for simplicity's sake, we are going to statically allocate it (and I just created a “big enough” buffer, not an “exactly as big as needed” one).
unsigned char buffer[20] = {0}; // 20 bytes should be enough
Technically it isn't even necessary to zero-initialize it, but it made debugging easier.
Next we need to do the serialization. For that we need to read the discriminant, store it in the buffer, then read exactly as many bytes as needed (depending on the discriminant) and copy those bytes into the destination buffer.
// Get the size of a variant from its discriminant
size_t size_of_variant(unsigned char discriminant) {
    switch (discriminant) {
        case 0: return sizeof(char[7]);
        case 1: return sizeof(float);
        case 2: return sizeof(struct {char x; char y;});
        default: return 0; // unreachable with a valid discriminant
    }
}
// Serialize a tagged union into a buffer of bytes
// The buffer must be big enough
// Returns the number of bytes used in the buffer.
size_t serialize(TaggedUnion *tagged_union, unsigned char *buffer) {
    // Add the variant type in the buffer
    buffer[0] = tagged_union->discriminant;
    // Add the variant data in the buffer
    size_t nb_bytes = size_of_variant(tagged_union->discriminant);
    (void) memcpy(&buffer[1], &tagged_union->data, nb_bytes);
    return nb_bytes + 1; // NB: +1 for the discriminant
}
In C, arrays decay to pointers, and a pointer can be indexed like an array. buffer[0] accesses the first byte of the array. As such, &buffer[0] is the address of the first byte of the array, which is equivalent to buffer.
void* memcpy(void* dest, const void* src, size_t count);
memcpy copies count bytes of the data pointed to by src into dest. It returns a pointer to dest, but I ignored the return value (this is why I cast it to void). This kind of transformation is really unsafe, because we are copying raw bytes and reinterpreting them as another type, but we can do it because we know the exact layout of both the source and the destination.
If you are familiar with C, all of this should be relatively straightforward.
To be able to debug what we are doing, we will need a way to display the content of the buffer.
void display(unsigned char *buffer, size_t bytes) {
    printf("consumed bytes: %zu, buffer content: ", bytes);
    for (size_t i = 0; i < bytes; i++) {
        printf("%02x ", buffer[i]);
    }
    printf("\n");
}
Let’s try this!
unsigned char buffer[20] = {0};
size_t bytes_written = 0;
for (size_t i = 0; i < sizeof(input)/sizeof(input[0]); i++) {
    // Serialize the i-th union at the end of the buffer
    bytes_written += serialize(&input[i], &buffer[bytes_written]);
    display(buffer, bytes_written);
    // The lines above output:
    // consumed bytes: 5, buffer content: 01 66 5c 5d 22
    // Then:
    // consumed bytes: 8, buffer content: 01 66 5c 5d 22 02 78 79
}
Perfect! The size of the serialized output is only 8 bytes, instead of the initial 24. Now we need to be able to do the reverse and deserialize it.
When reading back the buffer, the first thing to do is to read the first byte. This is the discriminant. The discriminant tells us how many bytes we need to read next, and how to interpret them, as we saw earlier.
// Copy a union from the buffer into an output tagged union.
TaggedUnion deserialize(unsigned char *buffer, size_t *index) {
    TaggedUnion tagged_union;
    // Read the discriminant
    unsigned char discriminant = buffer[*index];
    tagged_union.discriminant = discriminant;
    *index += 1;
    // Read the data
    size_t bytes_to_read = size_of_variant(discriminant);
    memcpy(&tagged_union.data, &buffer[*index], bytes_to_read);
    *index += bytes_to_read;
    return tagged_union;
}
Let’s try it!
// Unpack the unions
TaggedUnion output[2];
size_t bytes_read = 0;
// Stop once every serialized byte has been consumed
for (int i = 0; bytes_read < bytes_written; ++i) {
    output[i] = deserialize(buffer, &bytes_read);
}
assert(input[0].discriminant == output[0].discriminant);
assert(input[0].data.f == output[0].data.f);
assert(input[1].discriminant == output[1].discriminant);
assert(input[1].data.pair.x == output[1].data.pair.x && input[1].data.pair.y == output[1].data.pair.y);
All the asserts passed successfully. Our packing/unpacking method works well. Now it’s time to do the same in Rust.
You can find the whole code on Compiler Explorer.
Version 2: in Rust, using tagged union
This is one of the first times I'm writing code in this language, but I've been interested in it for a long time. I read a lot of things, but I never really used it myself. Having my first project involve this kind of low-level manipulation may not be the best idea, especially since I will need unsafe blocks to do the low-level memory manipulation, but let's try it together!
Like in C, the first thing we are going to do is to create a tagged union. I
will intentionally use union
and not enum
for this first version to be
closer to the C code. Then I will re-write it a second time to be more
idiomatic.
#[repr(C)]
pub union Union {
    pub array: [u8; 7],
    pub f: f32,
    pub pair: (u8, u8),
}

#[repr(C)]
pub struct TaggedUnion {
    pub discriminant: u8,
    pub data: Union,
}
I used #[repr(C)] to tell the compiler to lay the types out the way a C compiler would, which gives guarantees about the padding and the size of each variant. Like in C, the Union is 8 bytes (including the byte of padding), and the TaggedUnion is 12 bytes long.
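If you want to double-check those numbers, here is a quick sketch (meant to be run alongside the two types above) using std::mem::size_of and align_of:

use std::mem::{align_of, size_of};

fn main() {
    // On a typical 64-bit target, the union is padded from 7 to 8 bytes
    // because of the 4-byte alignment of f32...
    assert_eq!(size_of::<Union>(), 8);
    assert_eq!(align_of::<Union>(), 4);
    // ...and the tagged union is 1 byte of discriminant + 3 bytes of padding
    // + 8 bytes of data = 12 bytes.
    assert_eq!(size_of::<TaggedUnion>(), 12);
}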
Given that Rust gives easy access to dynamically growing arrays through the Vec<T> type, we will use it instead of a static array like in C.
let input = vec![
    TaggedUnion {
        discriminant: 1,
        data: Union {
            f: 3.0e-18,
        },
    },
    TaggedUnion {
        discriminant: 2,
        data: Union {
            pair: (
                b'x',
                b'y',
            ),
        },
    },
];
In Rust, the char type is a Unicode scalar value. This is really useful because it means that it can store any character from a String (which is encoded in UTF-8, like in any sane language). However, this also means that a char is 4 bytes long, to be able to represent any codepoint. In the C version, the pair was a tuple of two ASCII characters (1 byte each). To have the same representation in Rust, I had to use single unsigned bytes, and fill them with byte literals (this is what the b in b'x' and b'y' means).
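You can check the size difference directly:

use std::mem::size_of;

fn main() {
    // A char is a full Unicode scalar value, so it always takes 4 bytes...
    assert_eq!(size_of::<char>(), 4);
    // ...while a u8 (filled with a byte literal such as b'x') is a single byte.
    assert_eq!(size_of::<u8>(), 1);
    assert_eq!(size_of::<(u8, u8)>(), 2);
}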
Now, let’s implement the serialization.
use std::mem::{size_of, transmute};

fn serialize(buffer: &mut Vec<u8>, tagged_union: &TaggedUnion) {
    buffer.push(tagged_union.discriminant);
    buffer.extend_from_slice(
        match tagged_union.discriminant {
            0 => { let data: &[u8; size_of::<[u8; 7]>()] = unsafe { transmute(&tagged_union.data) }; data },
            1 => { let data: &[u8; size_of::<f32>()] = unsafe { transmute(&tagged_union.data) }; data },
            2 => { let data: &[u8; size_of::<(u8, u8)>()] = unsafe { transmute(&tagged_union.data) }; data },
            _ => unsafe { ::core::hint::unreachable_unchecked() },
        }
    );
}
Taking a slice of memory and re-interpreting it as another type of data is inherently unsafe and error prone. The alignment, the size of the type, the endianness, and many other things that I probably don't know about may change the data layout. In C we used memcpy to copy the bytes. In Rust, we can use transmute. In C++ it would have been std::bit_cast (since C++20). Since this operation, if not done correctly, can lead to undefined behavior, we have to use unsafe in Rust. unsafe doesn't mean that the code contains undefined behavior, but that it is the responsibility of the programmer to prove that it doesn't. In safe Rust (any code that isn't unsafe), the compiler does this for you and guarantees that no undefined behavior can happen.
If you want to know exactly what kind of superpower (as well as the
responsibilities associated with it) unsafe
offers you, you can take a look at
the Rust book.
The deserialization is a bit more verbose because of the absence of const generics and generic associated types (GATs). Both of those features are being worked on, and will eventually land, but for the moment, we will have to wait a bit. This just means that we will have to repeat ourselves in the extraction logic.
First, let’s take a look at the general structure of the deserialize
function.
// Copy a union from the buffer into an output tagged union.
// SAFETY: The buffer must have been filled by the `serialize` function.
// The `offset` must point to the beginning of a serialized `TaggedUnion`.
unsafe fn deserialize(buffer: &Vec<u8>, offset: &mut usize) -> TaggedUnion
{
    let discriminant = buffer[*offset];
    *offset += 1;
    TaggedUnion {
        discriminant: discriminant,
        data: match discriminant {
            0 => /* extract the array */,
            1 => /* extract the float */,
            2 => /* extract the pair of characters */,
            _ => unsafe { ::core::hint::unreachable_unchecked() },
        }
    }
}
If you are not familiar with Rust, you need to know that nearly everything is an expression, and the last expression in a function is what it returns. This means that the TaggedUnion is built from the match expression, and then returned from the function.
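Here is a tiny, stand-alone illustration of that (the function and its values are just an example, not part of the serializer):

// Both the match and the final expression produce values:
// no `return` keyword is needed.
fn variant_name(discriminant: u8) -> &'static str {
    let name = match discriminant {
        0 => "array",
        1 => "f",
        2 => "pair",
        _ => "unknown",
    };
    name // last expression: this is the returned value
}

fn main() {
    println!("{}", variant_name(1)); // prints "f"
}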
You may notice that I added unsafe around ::core::hint::unreachable_unchecked(), even though we are already in an unsafe function. I agree with the currently discussed proposal to stop treating the body of an unsafe fn as an implicit unsafe block, so I added #![allow(unused_unsafe)] at the beginning of my file.
The logic is exactly the same for all 3 variants, so I will just paste the one for the array of 7 bytes here. The extraction of the float and of the pair of characters is left as an exercise for the reader. If you are really lazy, you can just look at the code at the end of this post!
type T = [u8; 7];
const SIZE: usize = size_of::<T>();
let pointer: *const [u8; SIZE] = unsafe{ transmute(&buffer[*offset]) };
let value: T = unsafe{ transmute(*pointer) };
*offset += SIZE;
/* return */ Union { array: value }
It is just a bit of pointer casting. First I transmute the pointer to the data inside the buffer into a pointer to a known number of bytes. Then I transmute the pointed data into the real type.
This structure is a bit verbose, but you only need to modify the first line (by
specifying another T
), and the last (by specifying another variant type) when
implementing the logic for the float and the pair of characters.
The whole implementation can be found in the playground.
This code works, but it isn't idiomatic Rust. First of all, I should have used an enum, as explained in the introduction. It's a bit like a tagged union, but safer, and the steroids are included! Secondly, my deserialize function is unsafe, because we give it raw data. If instead I returned an iterator over an opaque type from serialize, it wouldn't be possible to read data that wasn't created by the serialize function, nor to read at an invalid index. And finally, I think I should probably not re-invent the wheel, and instead implement the Serialize and Deserialize traits from Serde to make it compatible with the rest of the Rust ecosystem.
Version 3: in Rust, using an enum
It's sad to say, but serializing an enum by hand is anything but trivial. The alignment and the size of the discriminant are not easily accessible. Luckily, the well-known serialization crate Serde exists. One of its backends, bincode, gives it the ability to directly generate a binary buffer.
For that, you need to add this to your Cargo.toml:
bincode = "1.2"
serde = { version = "1.0", features = ["derive"] }
And then all the boilerplate is going to be generated for you.
use serde::{Serialize, Deserialize};

#[derive(PartialEq, Debug)]
#[derive(Serialize, Deserialize)]
pub enum Enum {
    Array([u8; 7]),
    F(f32),
    Pair(u8, u8),
}

fn main() {
    let input = vec![
        Enum::F(3.0e-18),
        Enum::Pair(b'x', b'y'),
    ];
    let buffer = dbg!(bincode::serialize(&input).unwrap());
    let output: Vec<Enum> = dbg!(bincode::deserialize::<Vec<Enum>>(&buffer[..]).unwrap());
    assert_eq!(input, output);
}
Unfortunately, the encoded discriminant is always stored using 4 bytes, even if you used a different #[repr(...)] for the enum. All hope is not lost, though, because varint encoding is currently proposed to be added to bincode. If this PR is accepted, it means that we could directly use an enum, and get an optimal serialization for free!
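You can check the overhead yourself by looking at the length of the serialized buffer. The numbers in the comments below are what I would expect from bincode 1's default configuration, not something guaranteed:

fn main() {
    let bytes = bincode::serialize(&Enum::Pair(b'x', b'y')).unwrap();
    // Only 2 bytes of data, but the buffer should be around 6 bytes:
    // 4 bytes of discriminant + 2 bytes of data.
    println!("{} bytes: {:02x?}", bytes.len(), bytes);
}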
When encoding a discriminant as a varint, we first encode the number of bytes needed (if discriminant = 5, you need 1 byte, while if discriminant = 1000, it will be 2 bytes), then the value of the discriminant. And there is a clever trick: if the value is less than 250 (so it fits in a single byte), the size doesn't even need to be encoded!
EDIT: If you are in a no_std environment, you may be interested in the postcard crate. It already uses LEB128 for enums and lengths, which makes it a great candidate.
Another possible space optimization to store numbers that have a high chance of being close to zero (like a discriminant) would have been to use LEB128. The idea is relatively simple: the number is written in base 128 (which means that you write 7 bits at a time), and you set the highest bit of each byte to 1 if there are additional bits to read in the next bytes.
thirty   = 00000000 00011110 -> 00011110          (fits in one byte)
           -------- ---abcde    0--abcde
thousand = 00000011 11101000 -> 11101000 00000111 (fits in two bytes)
           ------HI Jabcdefg    1abcdefg 0----HIJ
As you can see, since thirty fits in seven bits, the output is a single byte. However, thousand is bigger than 127 (the biggest 7-bit number), so the output will be two bytes. The three highest bits of thousand have moved to the second byte (H, I and J). The highest bit of the first output byte is a 1, because we need to read the next output byte to get the full number, while the highest bit of the second output byte is a 0, since we don't need to read more bytes when deserializing.
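Here is a minimal sketch of the encoding side (my own toy version, not the implementation used by postcard or the bincode PR):

// Encode an unsigned integer as LEB128: 7 bits per byte, and the highest
// bit of each byte tells whether another byte follows.
fn leb128_encode(mut value: u64, output: &mut Vec<u8>) {
    loop {
        let byte = (value & 0x7f) as u8;
        value >>= 7;
        if value == 0 {
            output.push(byte); // highest bit is 0: this is the last byte
            return;
        }
        output.push(byte | 0x80); // highest bit is 1: more bytes follow
    }
}

fn main() {
    let mut buffer = Vec::new();
    leb128_encode(30, &mut buffer);   // thirty fits in 7 bits -> 1 byte
    leb128_encode(1000, &mut buffer); // thousand needs 10 bits -> 2 bytes
    println!("{:02x?}", buffer);      // [1e, e8, 07]
}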
And finally, if memory really is an issue, it is possible to take that output buffer and compress it. lz4 is probably a good bet, and even has a Rust binding.
Discuss it on reddit.