diff --git a/src/analysis/lattices/bounded-conjunction.h b/src/analysis/lattices/bounded-conjunction.h new file mode 100644 index 00000000000..202614ccd21 --- /dev/null +++ b/src/analysis/lattices/bounded-conjunction.h @@ -0,0 +1,270 @@ +/* + * Copyright 2026 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_analysis_lattices_bounded_conjunction_h +#define wasm_analysis_lattices_bounded_conjunction_h + +#include +#include +#include +#include +#include + +#include "analysis/lattice.h" +#include "support/inplace_vector.h" + +namespace wasm::analysis { + +// BoundedConjunction represents a conjunction of up to N +// pairwise unrelated elements of the underlying lattice L. +// +// This is a semilattice (implements Lattice but not FullLattice). +// +// The elements are represented as a std::variant of wasm::inplace_vector and +// Bot. The top element (no constraints) is represented as an empty vector. +// +// Subclasses must implement: +// std::strong_ordering orderElements(const typename L::Element&, const +// typename L::Element&) const +// +// To maximize monotonicity of boundedMeet, the total order defined by +// orderElements must be a linear extension of the lattice partial order (i.e. x +// <_L y => x < y in total order). +// +// To enforce this and simplify subclass implementation, boundedMeet will ONLY +// call subclass.orderElements for elements that are unrelated in L. For related +// elements, the lattice order is used. +template +struct BoundedConjunction { + struct Bot { + bool operator==(const Bot&) const { return true; } + bool operator!=(const Bot&) const { return false; } + }; + + using Element = std::variant, Bot>; + + L lattice; + + BoundedConjunction(L&& lattice) : lattice(std::move(lattice)) { + static_assert( + requires(const Subclass& sub, + const typename L::Element& a, + const typename L::Element& b) { + { sub.orderElements(a, b) } -> std::same_as; + }, + "Subclass must implement orderElements(const L::Element&, const " + "L::Element&) -> std::strong_ordering"); + } + + Element getBottom() const noexcept { + return Element{std::in_place_type}; + } + + LatticeComparison compare(const Element& a, const Element& b) const noexcept { + if (std::holds_alternative(a) && std::holds_alternative(b)) + return EQUAL; + if (std::holds_alternative(a)) + return LESS; + if (std::holds_alternative(b)) + return GREATER; + + const auto& va = std::get>(a); + const auto& vb = std::get>(b); + + if (va.empty() && vb.empty()) + return EQUAL; + if (va.empty()) + return GREATER; + if (vb.empty()) + return LESS; + + // a <= b iff for all eb in b, there exists ea in a s.t. ea <= eb + bool a_le_b = std::all_of(vb.begin(), vb.end(), [&](const auto& eb) { + return std::any_of(va.begin(), va.end(), [&](const auto& ea) { + auto comp = lattice.compare(ea, eb); + return comp == LESS || comp == EQUAL; + }); + }); + + // b <= a iff for all ea in a, there exists eb in b s.t. eb <= ea + bool b_le_a = std::all_of(va.begin(), va.end(), [&](const auto& ea) { + return std::any_of(vb.begin(), vb.end(), [&](const auto& eb) { + auto comp = lattice.compare(eb, ea); + return comp == LESS || comp == EQUAL; + }); + }); + + if (a_le_b && b_le_a) + return EQUAL; + if (a_le_b) + return LESS; + if (b_le_a) + return GREATER; + return NO_RELATION; + } + + bool join(Element& joinee, const Element& joiner) const noexcept { + if (std::holds_alternative(joiner)) + return false; + if (std::holds_alternative(joinee)) { + joinee = joiner; + return true; + } + + auto& v_joinee = std::get>(joinee); + const auto& v_joiner = + std::get>(joiner); + + if (v_joinee.empty()) + return false; + if (v_joiner.empty()) { + v_joinee.clear(); + return true; + } + + std::vector temp_result; + for (const auto& ea : v_joinee) { + for (const auto& eb : v_joiner) { + auto joined = ea; + lattice.join(joined, eb); + addConstraint(temp_result, joined); + } + } + + if (temp_result.size() > N) { + const auto& subclass = *static_cast(this); + std::sort(temp_result.begin(), + temp_result.end(), + [&](const auto& a, const auto& b) { + auto comp = lattice.compare(a, b); + if (comp == LESS) + return true; + if (comp == GREATER) + return false; + if (comp == EQUAL) + return false; + return subclass.orderElements(a, b) == + std::strong_ordering::less; + }); + temp_result.erase(temp_result.begin() + N, temp_result.end()); + } + + inplace_vector result; + for (auto& e : temp_result) { + result.push_back(std::move(e)); + } + + if (v_joinee == result) + return false; + v_joinee = std::move(result); + return true; + } + + bool boundedMeet(Element& meetee, const Element& meeter) const noexcept { + if (std::holds_alternative(meetee)) + return false; + if (std::holds_alternative(meeter)) { + meetee = getBottom(); + return true; + } + + auto& v_meetee = std::get>(meetee); + const auto& v_meeter = + std::get>(meeter); + + if (v_meeter.empty()) + return false; + if (v_meetee.empty()) { + v_meetee = v_meeter; + return true; + } + + std::vector temp_result(v_meetee.begin(), + v_meetee.end()); + for (const auto& em : v_meeter) { + if (!addConstraint(temp_result, em)) { + meetee = getBottom(); + return true; + } + } + + if (temp_result.size() > N) { + const auto& subclass = *static_cast(this); + std::sort(temp_result.begin(), + temp_result.end(), + [&](const auto& a, const auto& b) { + auto comp = lattice.compare(a, b); + if (comp == LESS) + return true; + if (comp == GREATER) + return false; + if (comp == EQUAL) + return false; + return subclass.orderElements(a, b) == + std::strong_ordering::less; + }); + temp_result.erase(temp_result.begin() + N, temp_result.end()); + } + + inplace_vector result; + for (auto& e : temp_result) { + result.push_back(std::move(e)); + } + + if (v_meetee == result) + return false; + v_meetee = std::move(result); + return true; + } + +private: + // Helper to add constraint and simplify. + // Returns false if meet results in Bottom. + bool addConstraint(std::vector& vec, + const typename L::Element& e) const noexcept { + if (lattice.compare(e, lattice.getBottom()) == EQUAL) { + return false; + } + if constexpr (requires { + { lattice.getTop() } -> std::same_as; + }) { + if (lattice.compare(e, lattice.getTop()) == EQUAL) { + return true; + } + } + + bool add = true; + for (auto it = vec.begin(); it != vec.end();) { + auto comp = lattice.compare(e, *it); + if (comp == EQUAL || comp == LESS) { + it = vec.erase(it); + } else if (comp == GREATER) { + add = false; + ++it; + } else { + ++it; + } + } + if (add) { + vec.push_back(e); + } + return true; + } +}; + +} // namespace wasm::analysis + +#endif // wasm_analysis_lattices_bounded_conjunction_h diff --git a/src/tools/wasm-fuzz-lattices.cpp b/src/tools/wasm-fuzz-lattices.cpp index ecc23b151bb..20bb1091f6a 100644 --- a/src/tools/wasm-fuzz-lattices.cpp +++ b/src/tools/wasm-fuzz-lattices.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -23,6 +24,7 @@ #include "analysis/lattice.h" #include "analysis/lattices/array.h" #include "analysis/lattices/bool.h" +#include "analysis/lattices/bounded-conjunction.h" #include "analysis/lattices/flat.h" #include "analysis/lattices/int.h" #include "analysis/lattices/inverted.h" @@ -56,25 +58,39 @@ uint64_t getSeed() { return std::uniform_int_distribution{}(rand); } +// Functor deleter to make RandomElement default-constructible, which is +// required by std::array (used in inplace_vector). +struct RandomElementDeleter { + void (*deleter)(void*) = nullptr; + void operator()(void* p) const { + if (deleter) { + deleter(p); + } + } +}; + // Actually a pointer to `L::ElementImpl`, but we erase the type to avoid // getting into a situation where `L` satisfying `Lattice` or `FullLattice` // circularly requires that `L` satisfies `Lattice` or `FullLattice`. C++ does // not allow concepts to depend on themselves. Also make the pointer copyable to // satisfy that constraint on lattice elements. template -struct RandomElement : std::unique_ptr { +struct RandomElement : std::unique_ptr { RandomElement() = default; RandomElement(typename L::ElementImpl&& other) - : std::unique_ptr( + : std::unique_ptr( new typename L::ElementImpl(std::move(other)), - [](void* e) { delete static_cast(e); }) {} + RandomElementDeleter{ + [](void* e) { delete static_cast(e); }}) {} RandomElement(const RandomElement& other) - : RandomElement([&]() { - auto copy = *other; - return copy; - }()) {} + : std::unique_ptr( + other.get() ? new + typename L::ElementImpl( + *static_cast(other.get())) + : nullptr, + other.get_deleter()) {} RandomElement(RandomElement&& other) = default; RandomElement& operator=(const RandomElement& other) { @@ -87,16 +103,28 @@ struct RandomElement : std::unique_ptr { RandomElement& operator=(RandomElement&& other) = default; typename L::ElementImpl& operator*() { - return *static_cast(get()); + return *static_cast(this->get()); } const typename L::ElementImpl& operator*() const { - return *static_cast(get()); + return *static_cast(this->get()); } typename L::ElementImpl* operator->() { return &*(*this); } const typename L::ElementImpl* operator->() const { return &*(*this); } + + bool operator==(const RandomElement& other) const { + if (this->get() == other.get()) + return true; + if (!this->get() || !other.get()) + return false; + return **this == *other; + } + + bool operator!=(const RandomElement& other) const { + return !(*this == other); + } }; struct RandomFullLattice { @@ -186,6 +214,45 @@ using FullLatticeElementVariant = struct RandomFullLattice::ElementImpl : FullLatticeElementVariant {}; +struct FuzzBoundedConjunction; + +void printElement(std::ostream& os, + const typename RandomLattice::Element& elem, + int depth = 0); + +struct FuzzBoundedConjunction + : analysis::BoundedConjunction { + using Base = + analysis::BoundedConjunction; + FuzzBoundedConjunction(RandomLattice&& lattice) : Base(std::move(lattice)) {} + + std::strong_ordering orderElements(const RandomLattice::Element& a, + const RandomLattice::Element& b) const { + std::stringstream ss1, ss2; + printElement(ss1, a, 0); + printElement(ss2, b, 0); + return ss1.str() <=> ss2.str(); + } +}; + +FuzzBoundedConjunction::Element makeFuzzBoundedConjunctionElement( + const FuzzBoundedConjunction& lattice, + const std::vector& elements) { + FuzzBoundedConjunction::Element result{ + inplace_vector{}}; + + for (const auto& e : elements) { + if (lattice.lattice.compare(e, lattice.lattice.getBottom()) == EQUAL) { + return lattice.getBottom(); + } + inplace_vector vec; + vec.push_back(e); + FuzzBoundedConjunction::Element meeter{vec}; + lattice.boundedMeet(result, meeter); + } + return result; +} + using LatticeVariant = std::variant, Lift, @@ -193,7 +260,8 @@ using LatticeVariant = std::variant, TupleLattice, SharedPath, - OneOfLattice>; + OneOfLattice, + FuzzBoundedConjunction>; struct RandomLattice::LatticeImpl : LatticeVariant {}; @@ -205,7 +273,8 @@ using LatticeElementVariant = typename Vector::Element, typename TupleLattice::Element, typename SharedPath::Element, - typename OneOfLattice::Element>; + typename OneOfLattice::Element, + typename FuzzBoundedConjunction::Element>; struct RandomLattice::ElementImpl : LatticeElementVariant {}; @@ -216,7 +285,8 @@ RandomFullLattice::RandomFullLattice(Random& rand, std::optional maybePick) : rand(rand) { // TODO: Limit the depth once we get lattices with more fan-out. - uint32_t pick = maybePick ? *maybePick : rand.upTo(FullLatticePicks); + uint32_t maxPick = depth > 3 ? 3 : FullLatticePicks; + uint32_t pick = maybePick ? *maybePick : rand.upTo(maxPick); switch (pick) { case 0: lattice = std::make_unique(LatticeImpl{Bool{}}); @@ -255,7 +325,8 @@ RandomFullLattice::RandomFullLattice(Random& rand, RandomLattice::RandomLattice(Random& rand, size_t depth) : rand(rand) { // TODO: Limit the depth once we get lattices with more fan-out. - uint32_t pick = rand.upTo(FullLatticePicks + 7); + uint32_t maxPick = depth > 3 ? FullLatticePicks + 1 : FullLatticePicks + 8; + uint32_t pick = rand.upTo(maxPick); if (pick < FullLatticePicks) { lattice = std::make_unique( @@ -291,6 +362,10 @@ RandomLattice::RandomLattice(Random& rand, size_t depth) : rand(rand) { lattice = std::make_unique(LatticeImpl{OneOfLattice{ RandomLattice{rand, depth + 1}, RandomLattice{rand, depth + 1}}}); return; + case FullLatticePicks + 7: + lattice = std::make_unique( + LatticeImpl{FuzzBoundedConjunction{RandomLattice{rand, depth + 1}}}); + return; } WASM_UNREACHABLE("unexpected pick"); } @@ -424,6 +499,24 @@ RandomLattice::Element RandomLattice::makeElement() const noexcept { return ElementImpl{l->get<1>(std::get<1>(l->lattices).makeElement())}; } } + if (const auto* l = std::get_if(lattice.get())) { + auto pick = rand.upTo(4); + switch (pick) { + case 0: + return ElementImpl{l->getBottom()}; + case 1: + return ElementImpl{typename FuzzBoundedConjunction::Element{ + inplace_vector{}}}; + case 2: { + return ElementImpl{ + makeFuzzBoundedConjunctionElement(*l, {l->lattice.makeElement()})}; + } + case 3: { + return ElementImpl{makeFuzzBoundedConjunctionElement( + *l, {l->lattice.makeElement(), l->lattice.makeElement()})}; + } + } + } WASM_UNREACHABLE("unexpected lattice"); } @@ -467,6 +560,13 @@ void printFullElement(std::ostream& os, } indent(os, depth); os << "]\n"; + } else if (const auto* e = + std::get_if(&*elem)) { + os << "Tuple(\n"; + printFullElement(os, std::get<0>(*e), depth + 1); + printFullElement(os, std::get<1>(*e), depth + 1); + indent(os, depth); + os << ")\n"; } else if (const auto* e = std::get_if(&*elem)) { if (e->isBottom()) { @@ -493,7 +593,7 @@ void printFullElement(std::ostream& os, void printElement(std::ostream& os, const typename RandomLattice::Element& elem, - int depth = 0) { + int depth) { if (const auto* e = std::get_if(&*elem)) { printFullElement(os, *e, depth); @@ -569,6 +669,24 @@ void printElement(std::ostream& os, } else { WASM_UNREACHABLE("unexpected one-of element"); } + } else if (const auto* e = + std::get_if(&*elem)) { + if (std::holds_alternative(*e)) { + os << "bounded-conjunction bot\n"; + } else { + const auto& vec = + std::get>(*e); + if (vec.empty()) { + os << "bounded-conjunction top\n"; + } else { + os << "BoundedConjunction[\n"; + for (const auto& el : vec) { + printElement(os, el, depth + 1); + } + indent(os, depth); + os << "]\n"; + } + } } else { WASM_UNREACHABLE("unexpected element"); } diff --git a/test/gtest/lattices.cpp b/test/gtest/lattices.cpp index a6098f8de12..67f35779266 100644 --- a/test/gtest/lattices.cpp +++ b/test/gtest/lattices.cpp @@ -18,6 +18,7 @@ #include "analysis/lattices/abstraction.h" #include "analysis/lattices/array.h" #include "analysis/lattices/bool.h" +#include "analysis/lattices/bounded-conjunction.h" #include "analysis/lattices/conetype.h" #include "analysis/lattices/flat.h" #include "analysis/lattices/int.h" @@ -1512,3 +1513,159 @@ TEST(OneOfLattice, Meet) { test(b_false, i_0, bot); test(i_10, b_true, bot); } + +struct TestBoundedConjunction + : analysis:: + BoundedConjunction, 2> { + using Base = analysis:: + BoundedConjunction, 2>; + TestBoundedConjunction() : Base(analysis::Flat{}) {} + + std::strong_ordering + orderElements(const analysis::Flat::Element& a, + const analysis::Flat::Element& b) const { + assert(a.getVal() && b.getVal()); + return *a.getVal() <=> *b.getVal(); + } +}; + +TEST(BoundedConjunctionLattice, Compare) { + TestBoundedConjunction lattice; + auto flat = lattice.lattice; + + auto bot = lattice.getBottom(); + + auto make_elem = [&](std::initializer_list vals) { + wasm::inplace_vector::Element, 2> vec; + for (int val : vals) { + vec.push_back(flat.get(val)); + } + return TestBoundedConjunction::Element{vec}; + }; + + auto e_empty = make_elem({}); // Top + auto e_1 = make_elem({1}); + auto e_2 = make_elem({2}); + auto e_1_2 = make_elem({1, 2}); + auto e_2_3 = make_elem({2, 3}); + + // Bot comparison + EXPECT_EQ(lattice.compare(bot, bot), analysis::EQUAL); + EXPECT_EQ(lattice.compare(bot, e_empty), analysis::LESS); + EXPECT_EQ(lattice.compare(bot, e_1), analysis::LESS); + EXPECT_EQ(lattice.compare(e_1, bot), analysis::GREATER); + + // Top (empty) comparison + EXPECT_EQ(lattice.compare(e_empty, e_empty), analysis::EQUAL); + EXPECT_EQ(lattice.compare(e_empty, e_1), analysis::GREATER); + EXPECT_EQ(lattice.compare(e_1, e_empty), analysis::LESS); + + // Subset comparison (more constraints = smaller) + EXPECT_EQ(lattice.compare(e_1_2, e_1), analysis::LESS); + EXPECT_EQ(lattice.compare(e_1, e_1_2), analysis::GREATER); + + // Equal + EXPECT_EQ(lattice.compare(e_1, e_1), analysis::EQUAL); + EXPECT_EQ(lattice.compare(e_1_2, e_1_2), analysis::EQUAL); + + // Unrelated + EXPECT_EQ(lattice.compare(e_1, e_2), analysis::NO_RELATION); + EXPECT_EQ(lattice.compare(e_1_2, e_2_3), analysis::NO_RELATION); +} + +TEST(BoundedConjunctionLattice, Join) { + TestBoundedConjunction lattice; + auto flat = lattice.lattice; + + auto bot = lattice.getBottom(); + + auto make_elem = [&](std::initializer_list vals) { + wasm::inplace_vector::Element, 2> vec; + for (int val : vals) { + vec.push_back(flat.get(val)); + } + return TestBoundedConjunction::Element{vec}; + }; + + auto e_top = make_elem({}); + auto e_1 = make_elem({1}); + auto e_2 = make_elem({2}); + auto e_1_2 = make_elem({1, 2}); + auto e_2_3 = make_elem({2, 3}); + + auto test = + [&](const auto& joinee, const auto& joiner, const auto& expected) { + auto copy = joinee; + EXPECT_EQ(lattice.join(copy, joiner), joinee != expected); + EXPECT_EQ(copy, expected); + }; + + // Bot/Top joins + test(bot, bot, bot); + test(bot, e_1, e_1); + test(e_1, bot, e_1); + test(e_top, e_1, e_top); + test(e_1, e_top, e_top); + + // Same joins + test(e_1, e_1, e_1); + test(e_1_2, e_1_2, e_1_2); + + // Unrelated joins + test(e_1, e_2, e_top); + + // {1, 2} join {2, 3} -> {2} + test(e_1_2, e_2_3, e_2); +} + +TEST(BoundedConjunctionLattice, BoundedMeet) { + TestBoundedConjunction lattice; + auto flat = lattice.lattice; + + auto bot = lattice.getBottom(); + + auto make_elem = [&](std::initializer_list vals) { + wasm::inplace_vector::Element, 2> vec; + for (int val : vals) { + vec.push_back(flat.get(val)); + } + return TestBoundedConjunction::Element{vec}; + }; + + auto e_top = make_elem({}); + auto e_1 = make_elem({1}); + auto e_2 = make_elem({2}); + auto e_3 = make_elem({3}); + auto e_1_2 = make_elem({1, 2}); + auto e_1_3 = make_elem({1, 3}); + + auto test = + [&](const auto& meetee, const auto& meeter, const auto& expected) { + auto copy = meetee; + EXPECT_EQ(lattice.boundedMeet(copy, meeter), meetee != expected); + EXPECT_EQ(copy, expected); + }; + + // Bot/Top meets + test(bot, bot, bot); + test(bot, e_1, bot); + test(e_1, bot, bot); + test(e_top, e_1, e_1); + test(e_1, e_top, e_1); + + // Same meets + test(e_1, e_1, e_1); + + // Unrelated meets without overflow + test(e_1, e_2, e_1_2); + + // Unrelated meets with overflow (N=2) + // {1, 2} meet {3} = {1, 2, 3} -> keep 1, 2 + test(e_1_2, e_3, e_1_2); // returns false (no change) + + // {1, 3} meet {2} = {1, 2, 3} -> keep 1, 2 (changes from {1, 3}) + test(e_1_3, e_2, e_1_2); // returns true + + // Meet with redundancy + test(e_1, e_1_2, e_1_2); +}