Skip to content

Fix potential unicode conversion issues for *nix #7506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions include/dxc/WinAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -916,19 +916,35 @@ unsigned int SysStringLen(const BSTR bstrString);
// RAII style mechanism for setting/unsetting a locale for the specified Windows
// codepage
class ScopedLocale {
const char *m_prevLocale;
locale_t Utf8Locale = nullptr;
locale_t PrevLocale = nullptr;

public:
explicit ScopedLocale(uint32_t codePage)
: m_prevLocale(setlocale(LC_ALL, nullptr)) {
assert((codePage == CP_UTF8) &&
explicit ScopedLocale(uint32_t CodePage) {
assert((CodePage == CP_UTF8) &&
"Support for Linux only handles UTF8 code pages");
setlocale(LC_ALL, "en_US.UTF-8");
Utf8Locale = newlocale(LC_CTYPE_MASK, "C.UTF-8", NULL);
if (!Utf8Locale)
Utf8Locale = newlocale(LC_CTYPE_MASK, "C.utf8", NULL);
if (!Utf8Locale)
Utf8Locale = newlocale(LC_CTYPE_MASK, "en_US.UTF-8", NULL);
assert(Utf8Locale && "Failed to create UTF-8 locale");
if (!Utf8Locale)
return;
PrevLocale = uselocale(Utf8Locale);
assert(PrevLocale && "Failed to set locale to UTF-8");
if (!PrevLocale) {
freelocale(Utf8Locale);
Utf8Locale = nullptr;
}
}
~ScopedLocale() {
if (m_prevLocale != nullptr) {
setlocale(LC_ALL, m_prevLocale);
}
if (PrevLocale != nullptr)
uselocale(PrevLocale);
if (Utf8Locale)
freelocale(Utf8Locale);
PrevLocale = nullptr;
Utf8Locale = nullptr;
}
};

Expand Down
102 changes: 67 additions & 35 deletions lib/DxcSupport/Unicode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,17 @@ int MultiByteToWideChar(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
++cbMultiByte;
}
// If zero is given as the destination size, this function should
// return the required size (including the null-terminating character).
// return the required size (including or excluding the null-terminating
// character depending on whether the input included the null-terminator).
// This is the behavior of mbstowcs when the target is null.
if (cchWideChar == 0) {
lpWideCharStr = nullptr;
} else if (cchWideChar < cbMultiByte) {
SetLastError(ERROR_INSUFFICIENT_BUFFER);
return 0;
}

ScopedLocale utf8_locale_scope(CP_UTF8);

bool isNullTerminated = false;
size_t rv;
const char *prevLocale = setlocale(LC_ALL, nullptr);
setlocale(LC_ALL, "en_US.UTF-8");
if (lpMultiByteStr[cbMultiByte - 1] != '\0') {
char *srcStr = (char *)malloc((cbMultiByte + 1) * sizeof(char));
strncpy(srcStr, lpMultiByteStr, cbMultiByte);
Expand All @@ -62,14 +61,22 @@ int MultiByteToWideChar(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
free(srcStr);
} else {
rv = mbstowcs(lpWideCharStr, lpMultiByteStr, cchWideChar);
isNullTerminated = true;
}

if (prevLocale)
setlocale(LC_ALL, prevLocale);
if (rv == (size_t)-1) {
// mbstowcs returns -1 on error.
SetLastError(ERROR_INVALID_PARAMETER);
return 0;
}

if (rv == (size_t)cbMultiByte)
return rv;
return rv + 1; // mbstowcs excludes the terminating character
// Return value of mbstowcs (rv) excludes the terminating character.
// Matching MultiByteToWideChar requires returning the size written including
// the null terminator if the input was null-terminated, otherwise it
// returns the size written excluding the null terminator.
if (isNullTerminated)
return rv + 1;
return rv;
}

// WideCharToMultiByte is a Windows-specific method.
Expand Down Expand Up @@ -98,18 +105,17 @@ int WideCharToMultiByte(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
++cchWideChar;
}
// If zero is given as the destination size, this function should
// return the required size (including the null-terminating character).
// return the required size (including or excluding the null-terminating
// character depending on whether the input included the null-terminator).
// This is the behavior of wcstombs when the target is null.
if (cbMultiByte == 0) {
lpMultiByteStr = nullptr;
} else if (cbMultiByte < cchWideChar) {
SetLastError(ERROR_INSUFFICIENT_BUFFER);
return 0;
}

ScopedLocale utf8_locale_scope(CP_UTF8);

bool isNullTerminated = false;
size_t rv;
const char *prevLocale = setlocale(LC_ALL, nullptr);
setlocale(LC_ALL, "en_US.UTF-8");
if (lpWideCharStr[cchWideChar - 1] != L'\0') {
wchar_t *srcStr = (wchar_t *)malloc((cchWideChar + 1) * sizeof(wchar_t));
wcsncpy(srcStr, lpWideCharStr, cchWideChar);
Expand All @@ -118,21 +124,30 @@ int WideCharToMultiByte(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
free(srcStr);
} else {
rv = wcstombs(lpMultiByteStr, lpWideCharStr, cbMultiByte);
isNullTerminated = true;
}

if (prevLocale)
setlocale(LC_ALL, prevLocale);
if (rv == (size_t)-1) {
// wcstombs returns -1 on error.
SetLastError(ERROR_INVALID_PARAMETER);
return 0;
}

if (rv == (size_t)cchWideChar)
return rv;
return rv + 1; // mbstowcs excludes the terminating character
// Return value of wcstombs (rv) excludes the terminating character.
// Matching MultiByteToWideChar requires returning the size written including
// the null terminator if the input was null-terminated, otherwise it
// returns the size written excluding the null terminator.
if (isNullTerminated)
return rv + 1;
return rv;
}
#endif // _WIN32

namespace Unicode {

bool WideToEncodedString(const wchar_t *text, size_t cWide, DWORD cp,
DWORD flags, std::string *pValue, bool *lossy) {
DXASSERT_NOMSG(cWide == (size_t)-1 || cWide < INT32_MAX);
BOOL usedDefaultChar;
LPBOOL pUsedDefaultChar = (lossy == nullptr) ? nullptr : &usedDefaultChar;
if (lossy != nullptr)
Expand All @@ -147,16 +162,24 @@ bool WideToEncodedString(const wchar_t *text, size_t cWide, DWORD cp,
return true;
}

int cbUTF8 = ::WideCharToMultiByte(cp, flags, text, cWide, nullptr, 0,
nullptr, pUsedDefaultChar);
int cbUTF8 = ::WideCharToMultiByte(cp, flags, text, static_cast<int>(cWide),
nullptr, 0, nullptr, pUsedDefaultChar);
if (cbUTF8 == 0)
return false;

pValue->resize(cbUTF8);

cbUTF8 = ::WideCharToMultiByte(cp, flags, text, cWide, &(*pValue)[0],
pValue->size(), nullptr, pUsedDefaultChar);
cbUTF8 = ::WideCharToMultiByte(cp, flags, text, static_cast<int>(cWide),
&(*pValue)[0], pValue->size(), nullptr,
pUsedDefaultChar);
DXASSERT(cbUTF8 > 0, "otherwise contents have changed");
if ((cWide == (size_t)-1 || text[cWide - 1] == L'\0') &&
(*pValue)[pValue->size() - 1] == '\0') {
// When the input is null-terminated, the output includes the null
// terminator. Reduce the size by 1 to remove the embedded null terminator
// inside the string.
pValue->resize(cbUTF8 - 1);
}
DXASSERT((*pValue)[pValue->size()] == '\0',
"otherwise string didn't null-terminate after resize() call");

Expand All @@ -166,12 +189,12 @@ bool WideToEncodedString(const wchar_t *text, size_t cWide, DWORD cp,
}

bool UTF8ToWideString(const char *pUTF8, std::wstring *pWide) {
size_t cbUTF8 = (pUTF8 == nullptr) ? 0 : strlen(pUTF8);
return UTF8ToWideString(pUTF8, cbUTF8, pWide);
return UTF8ToWideString(pUTF8, -1, pWide);
}

bool UTF8ToWideString(const char *pUTF8, size_t cbUTF8, std::wstring *pWide) {
DXASSERT_NOMSG(pWide != nullptr);
DXASSERT_NOMSG(cbUTF8 == (size_t)-1 || cbUTF8 < INT32_MAX);

// Handle zero-length as a special case; it's a special value to indicate
// errors in MultiByteToWideChar.
Expand All @@ -181,15 +204,23 @@ bool UTF8ToWideString(const char *pUTF8, size_t cbUTF8, std::wstring *pWide) {
}

int cWide = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, pUTF8,
cbUTF8, nullptr, 0);
static_cast<int>(cbUTF8), nullptr, 0);
if (cWide == 0)
return false;

pWide->resize(cWide);

cWide = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, pUTF8, cbUTF8,
&(*pWide)[0], pWide->size());
cWide = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, pUTF8,
static_cast<int>(cbUTF8), &(*pWide)[0],
pWide->size());
DXASSERT(cWide > 0, "otherwise contents changed");
if ((cbUTF8 == (size_t)-1 || pUTF8[cbUTF8 - 1] == '\0') &&
(*pWide)[pWide->size() - 1] == '\0') {
// When the input is null-terminated, the output includes the null
// terminator. Reduce the size by 1 to remove the embedded null terminator
// inside the string.
pWide->resize(cWide - 1);
}
DXASSERT((*pWide)[pWide->size()] == L'\0',
"otherwise wstring didn't null-terminate after resize() call");
return true;
Expand All @@ -213,11 +244,12 @@ bool UTF8ToConsoleString(const char *text, size_t textLen, std::string *pValue,
if (!UTF8ToWideString(text, textLen, &text16)) {
return false;
}
return WideToConsoleString(text16.c_str(), text16.length(), pValue, lossy);
return WideToConsoleString(text16.c_str(), text16.length() + 1, pValue,
lossy);
}

bool UTF8ToConsoleString(const char *text, std::string *pValue, bool *lossy) {
return UTF8ToConsoleString(text, strlen(text), pValue, lossy);
return UTF8ToConsoleString(text, (size_t)-1, pValue, lossy);
}

bool WideToConsoleString(const wchar_t *text, size_t textLen,
Expand All @@ -230,7 +262,7 @@ bool WideToConsoleString(const wchar_t *text, size_t textLen,

bool WideToConsoleString(const wchar_t *text, std::string *pValue,
bool *lossy) {
return WideToConsoleString(text, wcslen(text), pValue, lossy);
return WideToConsoleString(text, (size_t)-1, pValue, lossy);
}

bool WideToUTF8String(const wchar_t *pWide, size_t cWide, std::string *pUTF8) {
Expand All @@ -242,7 +274,7 @@ bool WideToUTF8String(const wchar_t *pWide, size_t cWide, std::string *pUTF8) {
bool WideToUTF8String(const wchar_t *pWide, std::string *pUTF8) {
DXASSERT_NOMSG(pWide != nullptr);
DXASSERT_NOMSG(pUTF8 != nullptr);
return WideToEncodedString(pWide, wcslen(pWide), CP_UTF8, 0, pUTF8, nullptr);
return WideToEncodedString(pWide, (size_t)-1, CP_UTF8, 0, pUTF8, nullptr);
}

std::string WideToUTF8StringOrThrow(const wchar_t *pWide) {
Expand Down
66 changes: 33 additions & 33 deletions tools/clang/unittests/HLSL/CompilerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ class CompilerTest : public ::testing::Test {
void TestEncodingImpl(const void *sourceData, size_t sourceSize,
UINT32 codePage, const void *includedData,
size_t includedSize, const WCHAR *encoding = nullptr);
template <typename T1, typename T2>
void TestEncodingImpl(std::basic_string<T1> source, UINT32 codePage,
std::basic_string<T2> included,
const WCHAR *encoding = nullptr) {
TestEncodingImpl(source.data(), source.size() * sizeof(T1), codePage,
included.data(), included.size() * sizeof(T2), encoding);
}
TEST_METHOD(CompileWithEncodeFlagTestSource)

#if _ITERATOR_DEBUG_LEVEL == 0
Expand Down Expand Up @@ -3636,54 +3643,47 @@ void CompilerTest::TestEncodingImpl(const void *sourceData, size_t sourceSize,

TEST_F(CompilerTest, CompileWithEncodeFlagTestSource) {

std::string sourceUtf8 = "#include \"include.hlsl\"\r\n"
"float4 main() : SV_Target { return 0; }";
std::string includeUtf8 = "// Comment\n";
std::string SourceUtf8 = "#include \"include.hlsl\"\n"
"float4 main() : SV_Target { return Buf[0]; }";
std::string IncludeUtf8 = "Buffer<float4> Buf;\n";
std::string utf8BOM = "\xEF"
"\xBB"
"\xBF"; // UTF-8 BOM
std::string includeUtf8BOM = utf8BOM + includeUtf8;
std::string IncludeUtf8BOM = utf8BOM + IncludeUtf8;

std::wstring sourceWide = L"#include \"include.hlsl\"\r\n"
L"float4 main() : SV_Target { return 0; }";
std::wstring includeWide = L"// Comments\n";
std::wstring utf16BOM = L"\xFEFF"; // UTF-16 LE BOM
std::wstring includeUtf16BOM = utf16BOM + includeWide;
std::wstring SourceWide = L"#include \"include.hlsl\"\n"
L"float4 main() : SV_Target { return Buf[0]; }";
std::wstring IncludeWide = L"Buffer<float4> Buf;\n";

// Included files interpreted with encoding option if no BOM
TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_UTF8,
includeUtf8.data(), includeUtf8.size(), L"utf8");
// Windows: UTF-16 BOM is '\xFEFF'
// *nix: UTF-32 BOM is L'\x0000FEFF'
// Thus, BOM wide character value is identical for UTF-16 and UTF-32.
// Endianess will be native, since we are using wide strings directly.
std::wstring WideBOM = L"\xFEFF";

std::wstring IncludeWideBOM = WideBOM + IncludeWide;

TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'),
DXC_CP_WIDE, includeWide.data(),
includeWide.size() * sizeof(L'A'), L"wide");
// Included files interpreted with encoding option if no BOM
TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeUtf8, L"utf8");
TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeWide, L"wide");

// Encoding option ignored if BOM present
TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_UTF8,
includeUtf8BOM.data(), includeUtf8BOM.size(), L"wide");
TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeUtf8BOM, L"wide");
TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeWideBOM, L"utf8");

TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'),
DXC_CP_WIDE, includeUtf16BOM.data(),
includeUtf16BOM.size() * sizeof(L'A'), L"utf8");
// Encoding option ignored if BOM present - different encoding for source
TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeUtf8BOM, L"wide");
TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeWideBOM, L"utf8");

// Source file interpreted according to DxcBuffer encoding if not CP_ACP
// Included files interpreted with encoding option if no BOM
TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_UTF8,
includeWide.data(), includeWide.size() * sizeof(L'A'),
L"wide");

TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'),
DXC_CP_WIDE, includeUtf8.data(), includeUtf8.size(),
L"utf8");
TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeWide, L"wide");
TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeUtf8, L"utf8");

// Source file interpreted by encoding option if source DxcBuffer encoding =
// CP_ACP (default)
TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_ACP,
includeUtf8.data(), includeUtf8.size(), L"utf8");

TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'),
DXC_CP_ACP, includeWide.data(),
includeWide.size() * sizeof(L'A'), L"wide");
TestEncodingImpl(SourceUtf8, DXC_CP_ACP, IncludeUtf8, L"utf8");
TestEncodingImpl(SourceWide, DXC_CP_ACP, IncludeWide, L"wide");
}

TEST_F(CompilerTest, CompileWhenODumpThenOptimizerMatch) {
Expand Down