Skip to content

Commit 9c3279b

Browse files
committed
Introduce Structural Hashing
1 parent b8c78bc commit 9c3279b

File tree

17 files changed

+822
-388
lines changed

17 files changed

+822
-388
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15)
22

33
project(
44
mlc
5-
VERSION 0.0.10
5+
VERSION 0.0.11
66
DESCRIPTION "MLC-Python"
77
LANGUAGES C CXX
88
)

README.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class Let(Expr):
9292
body: Expr
9393
```
9494

95-
**Structural equality**. Method eq_s is ready to use to compare the structural equality (alpha equivalence) of two IRs.
95+
**Structural equality**. Member method `eq_s` compares the structural equality (alpha equivalence) of two IRs represented by MLC's structured dataclass.
9696

9797
```python
9898
"""
@@ -110,7 +110,13 @@ True
110110
ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent binding. RHS has been bound to a different node while LHS is not bound
111111
```
112112

113-
**Structural hashing**. TBD
113+
**Structural hashing**. The structure of MLC dataclasses can be hashed via `hash_s`, which guarantees if two dataclasses are alpha-equivalent, they will share the same structural hash:
114+
115+
```python
116+
>>> L1_hash, L2_hash, L3_hash = L1.hash_s(), L2.hash_s(), L3.hash_s()
117+
>>> assert L1_hash == L2_hash
118+
>>> assert L1_hash != L3_hash
119+
```
114120

115121
### :snake: Text Formats in Python AST
116122

cpp/c_api.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ MLC_REGISTER_FUNC("mlc.core.JSONDeserialize").set_body([](AnyView json_str) {
3737
}
3838
});
3939
MLC_REGISTER_FUNC("mlc.core.StructuralEqual").set_body(::mlc::core::StructuralEqual);
40+
MLC_REGISTER_FUNC("mlc.core.StructuralHash").set_body([](::mlc::Object *obj) -> int64_t {
41+
uint64_t ret = ::mlc::core::StructuralHash(obj);
42+
return static_cast<int64_t>(ret);
43+
});
4044
} // namespace
4145

4246
MLC_API MLCAny MLCGetLastError() {

cpp/registry.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ struct TypeTable {
205205
MLCTypeInfo *info = &wrapper->info;
206206
info->type_index = type_index;
207207
info->type_key = this->NewArray(type_key);
208+
info->type_key_hash = ::mlc::base::StrHash(type_key, std::strlen(type_key));
208209
info->type_depth = (parent == nullptr) ? 0 : (parent->type_depth + 1);
209210
info->type_ancestors = this->NewArray<int32_t>(info->type_depth);
210211
if (parent) {

include/mlc/base/utils.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#endif
1717
#include "./base_traits.h"
1818
#include <cstdlib>
19+
#include <cstring>
1920
#include <memory>
2021
#include <sstream>
2122
#include <type_traits>
@@ -297,6 +298,54 @@ inline int64_t StrToInt(const std::string &str, size_t start_pos = 0) {
297298
}
298299
return result;
299300
}
301+
302+
MLC_INLINE uint64_t HashCombine(uint64_t seed, uint64_t value) {
303+
return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
304+
}
305+
306+
MLC_INLINE int32_t StrCompare(const char *a, const char *b, int64_t a_len, int64_t b_len) {
307+
if (a_len != b_len) {
308+
return static_cast<int32_t>(a_len - b_len);
309+
}
310+
return std::strncmp(a, b, a_len);
311+
}
312+
313+
inline uint64_t StrHash(const char *str, int64_t length) {
314+
const char *it = str;
315+
const char *end = str + length;
316+
uint64_t result = 0;
317+
for (; it + 8 <= end; it += 8) {
318+
uint64_t b = (static_cast<uint64_t>(it[0]) << 56) | (static_cast<uint64_t>(it[1]) << 48) |
319+
(static_cast<uint64_t>(it[2]) << 40) | (static_cast<uint64_t>(it[3]) << 32) |
320+
(static_cast<uint64_t>(it[4]) << 24) | (static_cast<uint64_t>(it[5]) << 16) |
321+
(static_cast<uint64_t>(it[6]) << 8) | static_cast<uint64_t>(it[7]);
322+
result = HashCombine(result, b);
323+
}
324+
if (it < end) {
325+
uint64_t b = 0;
326+
if (it + 4 <= end) {
327+
b = (static_cast<uint64_t>(it[0]) << 24) | (static_cast<uint64_t>(it[1]) << 16) |
328+
(static_cast<uint64_t>(it[2]) << 8) | static_cast<uint64_t>(it[3]);
329+
it += 4;
330+
}
331+
if (it + 2 <= end) {
332+
b = (b << 16) | (static_cast<uint64_t>(it[0]) << 8) | static_cast<uint64_t>(it[1]);
333+
it += 2;
334+
}
335+
if (it + 1 <= end) {
336+
b = (b << 8) | static_cast<uint64_t>(it[0]);
337+
it += 1;
338+
}
339+
result = HashCombine(result, b);
340+
}
341+
return result;
342+
}
343+
344+
inline uint64_t StrHash(const char *str) {
345+
int64_t length = static_cast<int64_t>(std::strlen(str));
346+
return StrHash(str, length);
347+
}
348+
300349
} // namespace base
301350
} // namespace mlc
302351

include/mlc/c_api.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ typedef struct {
200200
typedef struct MLCTypeInfo {
201201
int32_t type_index;
202202
const char *type_key;
203+
uint64_t type_key_hash;
203204
int32_t type_depth;
204205
int32_t *type_ancestors; // Range: [0, type_depth)
205206
MLCTypeField *fields; // Ends with a field with name == nullptr

include/mlc/core/str.h

Lines changed: 19 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ struct StrObj : public MLCStr {
4343
MLC_INLINE const char *data() const { return this->MLCStr::data; }
4444
MLC_INLINE int64_t length() const { return this->MLCStr::length; }
4545
MLC_INLINE int64_t size() const { return this->MLCStr::length; }
46-
MLC_INLINE uint64_t Hash() const;
4746
MLC_INLINE bool StartsWith(const std::string &prefix) {
4847
int64_t N = static_cast<int64_t>(prefix.length());
4948
return N <= MLCStr::length && strncmp(MLCStr::data, prefix.data(), prefix.length()) == 0;
@@ -53,9 +52,21 @@ struct StrObj : public MLCStr {
5352
return N <= MLCStr::length && strncmp(MLCStr::data + MLCStr::length - N, suffix.data(), N) == 0;
5453
}
5554
MLC_INLINE void PrintEscape(std::ostream &os) const;
56-
MLC_INLINE int Compare(const StrObj *other) const { return std::strncmp(c_str(), other->c_str(), this->size() + 1); }
57-
MLC_INLINE int Compare(const std::string &other) const { return std::strncmp(c_str(), other.c_str(), size() + 1); }
58-
MLC_INLINE int Compare(const char *other) const { return std::strncmp(this->c_str(), other, this->size() + 1); }
55+
MLC_INLINE int32_t Compare(const char *rhs_str, int64_t rhs_len) const {
56+
return ::mlc::base::StrCompare(this->MLCStr::data, rhs_str, this->MLCStr::length, rhs_len);
57+
}
58+
MLC_INLINE int32_t Compare(const StrObj *other) const {
59+
return this->Compare(other->c_str(), other->MLCStr::length); //
60+
}
61+
MLC_INLINE int32_t Compare(const std::string &other) const {
62+
return this->Compare(other.data(), static_cast<int64_t>(other.length()));
63+
}
64+
MLC_INLINE int32_t Compare(const char *other) const {
65+
return this->Compare(other, static_cast<int64_t>(std::strlen(other)));
66+
}
67+
MLC_INLINE uint64_t Hash() const {
68+
return ::mlc::base::StrHash(this->MLCStr::data, this->MLCStr::length); //
69+
}
5970
MLC_DEF_STATIC_TYPE(StrObj, Object, MLCTypeIndex::kMLCStr, "object.Str")
6071
.FieldReadOnly("length", &MLCStr::length)
6172
.FieldReadOnly("data", &MLCStr::data)
@@ -247,7 +258,7 @@ inline std::ostream &operator<<(std::ostream &os, const Object &src) {
247258
return os;
248259
}
249260

250-
void StrObj::PrintEscape(std::ostream &oss) const {
261+
inline void StrObj::PrintEscape(std::ostream &oss) const {
251262
const char *data = this->MLCStr::data;
252263
int64_t length = this->MLCStr::length;
253264
oss << '"';
@@ -322,53 +333,7 @@ void StrObj::PrintEscape(std::ostream &oss) const {
322333
oss << '"';
323334
}
324335

325-
} // namespace mlc
326-
327-
namespace mlc {
328-
namespace core {
329-
330-
MLC_INLINE int32_t StrCompare(const MLCStr *a, const MLCStr *b) {
331-
if (a->length != b->length) {
332-
return static_cast<int32_t>(a->length - b->length);
333-
}
334-
return std::strncmp(a->data, b->data, a->length);
335-
}
336-
337-
MLC_INLINE uint64_t StrHash(const MLCStr *str) {
338-
const constexpr uint64_t kMultiplier = 1099511628211ULL;
339-
const constexpr uint64_t kMod = 2147483647ULL;
340-
const char *it = str->data;
341-
const char *end = it + str->length;
342-
uint64_t result = 0;
343-
for (; it + 8 <= end; it += 8) {
344-
uint64_t b = (static_cast<uint64_t>(it[0]) << 56) | (static_cast<uint64_t>(it[1]) << 48) |
345-
(static_cast<uint64_t>(it[2]) << 40) | (static_cast<uint64_t>(it[3]) << 32) |
346-
(static_cast<uint64_t>(it[4]) << 24) | (static_cast<uint64_t>(it[5]) << 16) |
347-
(static_cast<uint64_t>(it[6]) << 8) | static_cast<uint64_t>(it[7]);
348-
result = (result * kMultiplier + b) % kMod;
349-
}
350-
if (it < end) {
351-
uint64_t b = 0;
352-
if (it + 4 <= end) {
353-
b = (static_cast<uint64_t>(it[0]) << 24) | (static_cast<uint64_t>(it[1]) << 16) |
354-
(static_cast<uint64_t>(it[2]) << 8) | static_cast<uint64_t>(it[3]);
355-
it += 4;
356-
}
357-
if (it + 2 <= end) {
358-
b = (b << 16) | (static_cast<uint64_t>(it[0]) << 8) | static_cast<uint64_t>(it[1]);
359-
it += 2;
360-
}
361-
if (it + 1 <= end) {
362-
b = (b << 8) | static_cast<uint64_t>(it[0]);
363-
it += 1;
364-
}
365-
result = (result * kMultiplier + b) % kMod;
366-
}
367-
return result;
368-
}
369-
} // namespace core
370-
371-
MLC_INLINE Str Str::FromEscaped(int64_t N, const char *str) {
336+
inline Str Str::FromEscaped(int64_t N, const char *str) {
372337
std::ostringstream oss;
373338
if (N < 2 || str[0] != '\"' || str[N - 1] != '\"') {
374339
MLC_THROW(ValueError) << "Invalid escaped string: " << str;
@@ -443,13 +408,14 @@ MLC_INLINE Str Str::FromEscaped(int64_t N, const char *str) {
443408
}
444409
return Str(oss.str());
445410
}
411+
} // namespace mlc
446412

413+
namespace mlc {
447414
namespace base {
448415
MLC_INLINE StrObj *StrCopyFromCharArray(const char *source, size_t length) {
449416
return StrObj::Allocator::New(source, length + 1);
450417
}
451418
} // namespace base
452-
MLC_INLINE uint64_t StrObj::Hash() const { return ::mlc::core::StrHash(reinterpret_cast<const MLCStr *>(this)); }
453419
} // namespace mlc
454420

455421
#endif // MLC_CORE_STR_H_

0 commit comments

Comments
 (0)