Avoid double allocation when passing strings via IntoParam #1713

Merged (10 commits) on Apr 28, 2022
44 changes: 33 additions & 11 deletions crates/libs/windows/src/core/heap.rs
@@ -32,20 +32,42 @@ pub unsafe fn heap_free(ptr: RawPtr) {
     }
 }

-/// Copy a slice of `T` into a freshly allocated buffer with an additional default `T` at the end.
+/// Copy an iterator of `T` into a freshly allocated buffer with an additional default `T` at the end.
 ///
-/// Returns a pointer to the beginning of the buffer
+/// Returns a pointer to the beginning of the buffer. This pointer must be freed when done using `heap_free`.
 ///
 /// # Panics
 ///
-/// This function panics if the heap allocation fails or if the pointer returned from
-/// the heap allocation is not properly aligned to `T`.
-pub fn heap_string<T: Copy + Default + Sized>(slice: &[T]) -> *const T {
-    unsafe {
-        let buffer = heap_alloc((slice.len() + 1) * std::mem::size_of::<T>()).expect("could not allocate string") as *mut T;
-        assert!(buffer.align_offset(std::mem::align_of::<T>()) == 0, "heap allocated buffer is not properly aligned");
-        buffer.copy_from_nonoverlapping(slice.as_ptr(), slice.len());
-        buffer.add(slice.len()).write(T::default());
-        buffer
+/// This function panics if the heap allocation fails, the alignment requirements of 'T' surpass
+/// 8 (HeapAlloc's alignment) or if len is less than the number of items in the iterator.
+pub fn string_from_iter<I, T>(iter: I, len: usize) -> *const T
Contributor

nit: it always bothered me a bit that we're using the term string here when this function is more general than that. It might be nice to name it in a way that describes more closely what's actually happening.

In fact, it might make sense for this to only copy an iterator and the caller is responsible for adding the trailing null byte. This function would then lose the Default bound and the caller would call it like so:

copy_from_iterator(self.as_bytes().iter().copied().chain(core::iter::once(0)), self.len() + 1);

The caller is a bit more verbose, but it's way clearer what's actually happening.
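As a rough illustration of the suggestion above, here is a minimal sketch of a helper that only copies what the iterator yields, with the caller appending the terminator. It assumes a `Vec` in place of the `HeapAlloc` buffer so it can run anywhere; `copy_from_iterator` is the name proposed in the comment, while `to_null_terminated_bytes` is a hypothetical caller added for the example.

```rust
/// Hypothetical stand-in for the proposed helper: it copies at most `len`
/// items from the iterator and adds nothing on its own, so no `Default`
/// bound is needed. A `Vec` replaces the `HeapAlloc` buffer (assumption).
fn copy_from_iterator<I, T>(iter: I, len: usize) -> Vec<T>
where
    I: Iterator<Item = T>,
    T: Copy,
{
    let mut buffer = Vec::with_capacity(len);
    // `take(len)` keeps the copy within the requested size.
    buffer.extend(iter.take(len));
    buffer
}

/// Hypothetical caller: the null terminator is chained on explicitly and
/// the requested length accounts for it.
fn to_null_terminated_bytes(s: &str) -> Vec<u8> {
    copy_from_iterator(
        s.as_bytes().iter().copied().chain(core::iter::once(0)),
        s.len() + 1,
    )
}
```

With this shape, the trailing zero is visible at the call site instead of being implied by a `Default` bound inside the helper.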

+where
+    I: Iterator<Item = T>,
+    T: Copy + Default,
+{
+    // alignment of memory returned by HeapAlloc is at least 8
+    // Source: https://docs.microsoft.com/en-us/windows/win32/api/heapapi/nf-heapapi-heapalloc
+    // Ensure that T has sufficient alignment requirements
+    assert!(std::mem::align_of::<T>() <= 8, "T alignment surpasses HeapAlloc alignment");
+
+    let len = len + 1;
+    let ptr = heap_alloc(len * std::mem::size_of::<T>()).expect("could not allocate string") as *mut T;
+    let mut encoder = iter.chain(core::iter::once(T::default()));
+
+    for i in 0..len {
+        // SAFETY: ptr points to an allocation object of size `len`, indices accessed are always lower than `len`
+        unsafe {
+            core::ptr::write(
+                ptr.add(i),
Contributor

nit: can you use ptr.add(i).write(...) instead?

Contributor Author

Of course! My bad.

Contributor Author

> You can store the zipped iterator into a local variable and still do the check afterwards if you want.

Pardon, but how do I do that? The zip iterator consumes both inputs and only yields elements while both iterators have one. What can I get from the zipped iterator afterwards besides None?

+                match encoder.next() {
+                    Some(encoded) => encoded,
+                    None => break,
Contributor

It might be useful to do the following here instead of the assert a few lines later:

debug_assert!(i == len - 1);

Essentially, while this code is always safe (i.e., we'll never try to write to unallocated memory), if the iterator's length and len don't match, we end up allocating either too little memory or too much. It seems reasonable to help users of this function avoid that mistake.

Contributor

Do you mean debug_assert!(i < len) and not ==?

Contributor

I meant ==, because I believe we expect that length will always be equal to the number of elements in the iterator.

+                },
+            );
+        }
Contributor

Make this a bit clearer perhaps:

for (offset, c) in (0..len).zip(encoder) {
  // SAFETY: ptr points to an allocation object of size `len`, indices accessed are always lower than `len`
  unsafe { core::ptr::write(ptr.add(offset), c); }
}

Contributor Author

This looks much better! Unfortunately putting the encoder into .zip(self) would consume it, which wouldn't allow our assertion afterwards:

assert!(encoder.next().is_none(), "encoder returned more characters than expected");

Your version does look much better and I'd take it in a heartbeat, but we'd need a non-consuming version of zip unfortunately.

Contributor

It would be unusual, but is it invalid to request encoding fewer characters than there are?

Contributor Author

Technically, in my opinion, it is impossible, but since there was uncertainty in the reviews about silent truncation, I kept it in. An invalid length could somehow get passed in the future, so it would be helpful to find out via a panic rather than through silent truncation. I'm not hugely opinionated on this issue; we can go for the more elegant .zip at the expense of not catching potential bugs in the encoder or in what gets passed to the function.

Contributor

I wasn't worried about truncation, I was worried about writing past the allocated memory with the unsafe ptr.write(). Your outer loop and/or the .zip() makes that a non-issue. No strong opinion on this either, I can see it both ways. We're not going to overwrite anything no matter what.

Contributor

You can store the zipped iterator into a local variable and still do the check afterwards if you want.

Contributor

The Zip iterator instance will return None because the range iterator is finished. You can never get to the original encoder iterator again because it's inside the Zip iterator instance.
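One way to keep the `zip` form and still inspect the encoder afterwards is to lend it to `zip` through `Iterator::by_ref()`. The sketch below is only an illustration under that assumption: it uses a `Vec` instead of the raw `HeapAlloc` buffer, and `copy_with_leftover_check` is a hypothetical name, not something from this PR.

```rust
/// Hypothetical helper: writes at most `len` items from `encoder` into a
/// buffer and reports whether the encoder still had items left afterwards.
fn copy_with_leftover_check<I>(mut encoder: I, len: usize) -> (Vec<u16>, bool)
where
    I: Iterator<Item = u16>,
{
    let mut buffer = vec![0u16; len];
    let mut written = 0;
    // `by_ref()` lends the encoder to `zip`, so it stays usable after the loop.
    // The range is the first half of the zip: once it runs out, `zip` stops
    // without pulling another item from the encoder.
    for (offset, c) in (0..len).zip(encoder.by_ref()) {
        buffer[offset] = c;
        written = offset + 1;
    }
    // Drop the unused tail if the encoder was shorter than `len`.
    buffer.truncate(written);
    // The encoder was only borrowed, so the leftover check is still possible.
    let leftover = encoder.next().is_some();
    (buffer, leftover)
}
```

The order of the two halves matters here: with the encoder first and the range second, `zip` would pull one extra item from the encoder before noticing the range is exhausted, and the leftover check would miss it.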

     }
+
+    assert!(encoder.next().is_none(), "encoder returned more characters than expected");
Contributor

Won't this be fused because the chaining of once() is guaranteed to be fused? iter might not be fused, but the implementation will never call next() on it again because the chain moved on to the next one.

Contributor Author

Nice catch, I didn't consider that. Once is indeed guaranteed to be fused, so that issue is solved!
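The point about fusing can be checked with a small experiment. The sketch below uses a deliberately non-fused iterator (`Flaky`, a hypothetical type made up for this example) chained with `once`. It shows the behaviour of the standard library's `Chain` as currently implemented; note that `Chain` is only documented as a `FusedIterator` when both halves are fused, so treat this as an observation rather than a guarantee.

```rust
/// A deliberately non-fused iterator (hypothetical, for illustration only):
/// it alternates between `Some(n)` and `None` on successive calls.
struct Flaky {
    calls: u32,
}

impl Iterator for Flaky {
    type Item = u32;

    fn next(&mut self) -> Option<u32> {
        self.calls += 1;
        if self.calls % 2 == 1 {
            Some(self.calls)
        } else {
            None
        }
    }
}

/// Drains the chain until the first `None`, then polls it one more time.
fn drain_then_poll() -> (Vec<u32>, Option<u32>) {
    let mut chained = Flaky { calls: 0 }.chain(core::iter::once(0));
    let mut items = Vec::new();
    while let Some(item) = chained.next() {
        items.push(item);
    }
    // After the chain has moved on to `once`, `Flaky` is not polled again,
    // and an exhausted `once` keeps returning `None`.
    (items, chained.next())
}
```

Here `Flaky` would happily yield `Some(3)` on its third call, but the chain never makes that call once it has switched to the `once` half.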

Contributor

Asserting here seems wrong. I'm not sure we want to panic if the encoder contains more than len items. If anything, a debug_assert! might be appropriate (though I think silently ignoring is likely OK).

Collaborator

I'd also like to avoid panics in general. I have customers for whom this is inappropriate. If it means we have to embed this inside PCWSTR and PCSTR directly to avoid the issues with generality, then so be it.

Contributor (@ryancerium, Apr 27, 2022)

Should this panic? string_from_iter("hello world".encode_utf16(), 5)

Collaborator

That should never happen. My concern is that we're trying to harden a function that is only ever used internally. If there's a general concern about the safety of this function, then we can either mark it unsafe or just get rid of the function entirely.
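For readers following the panic discussion, the sketch below shows one possible non-panicking shape: report how the iterator's length compared to `len` and leave the decision to the caller. This is not what the PR does. `LenCheck` and `fill_checked` are hypothetical names, and a `Vec` stands in for the heap buffer.

```rust
/// Hypothetical result type describing how the iterator compared to `len`.
#[derive(Debug, PartialEq)]
enum LenCheck {
    /// The iterator yielded exactly `len` items.
    Exact,
    /// The iterator ran out early; only `written` items were copied.
    IteratorShorter { written: usize },
    /// The iterator still had items after `len` were copied (truncation).
    IteratorLonger,
}

/// Hypothetical helper: copies at most `len` items into `out` and reports
/// the outcome instead of panicking.
fn fill_checked<I>(mut iter: I, len: usize, out: &mut Vec<u16>) -> LenCheck
where
    I: Iterator<Item = u16>,
{
    for i in 0..len {
        match iter.next() {
            Some(unit) => out.push(unit),
            None => return LenCheck::IteratorShorter { written: i },
        }
    }
    if iter.next().is_some() {
        LenCheck::IteratorLonger
    } else {
        LenCheck::Exact
    }
}
```

Under this sketch, the call from the question above, `fill_checked("hello world".encode_utf16(), 5, &mut out)`, copies the five code units of "hello" and returns `LenCheck::IteratorLonger` rather than panicking.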


+    ptr
 }
2 changes: 1 addition & 1 deletion crates/libs/windows/src/core/pcstr.rs
@@ -45,7 +45,7 @@ unsafe impl Abi for PCSTR {
 #[cfg(feature = "alloc")]
 impl<'a> IntoParam<'a, PCSTR> for &str {
     fn into_param(self) -> Param<'a, PCSTR> {
-        Param::Boxed(PCSTR(heap_string(self.as_bytes())))
+        Param::Boxed(PCSTR(string_from_iter(self.as_bytes().iter().copied(), self.len())))
     }
 }
 #[cfg(feature = "alloc")]
4 changes: 2 additions & 2 deletions crates/libs/windows/src/core/pcwstr.rs
@@ -45,7 +45,7 @@ unsafe impl Abi for PCWSTR {
 #[cfg(feature = "alloc")]
 impl<'a> IntoParam<'a, PCWSTR> for &str {
     fn into_param(self) -> Param<'a, PCWSTR> {
-        Param::Boxed(PCWSTR(heap_string(&self.encode_utf16().collect::<alloc::vec::Vec<u16>>())))
+        Param::Boxed(PCWSTR(string_from_iter(self.encode_utf16(), self.len())))
     }
 }
 #[cfg(feature = "alloc")]
@@ -58,7 +58,7 @@ impl<'a> IntoParam<'a, PCWSTR> for alloc::string::String {
 impl<'a> IntoParam<'a, PCWSTR> for &::std::ffi::OsStr {
     fn into_param(self) -> Param<'a, PCWSTR> {
         use ::std::os::windows::ffi::OsStrExt;
-        Param::Boxed(PCWSTR(heap_string(&self.encode_wide().collect::<alloc::vec::Vec<u16>>())))
+        Param::Boxed(PCWSTR(string_from_iter(self.encode_wide(), self.len())))
     }
 }
 #[cfg(feature = "alloc")]
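A note on the `len` argument in the `&str` to `PCWSTR` call above: `self.len()` is the UTF-8 byte length of the string, which is an upper bound on the number of UTF-16 code units rather than an exact count. The small check below illustrates the relationship; `lengths` is a hypothetical helper written for this example.

```rust
/// Hypothetical helper: returns (UTF-8 byte length, UTF-16 code unit count).
fn lengths(s: &str) -> (usize, usize) {
    (s.len(), s.encode_utf16().count())
}

/// True when the byte length is large enough to hold every UTF-16 code unit.
fn byte_len_is_upper_bound(s: &str) -> bool {
    let (bytes, units) = lengths(s);
    bytes >= units
}
```

For ASCII text the two numbers coincide. A two-byte character such as U+00E9 takes one code unit, and a four-byte character such as U+1F600 takes two, so for non-ASCII input a buffer sized from the byte length is larger than strictly needed but never too small.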