diff --git a/include/bitsery/ext/utils/pointer_utils.h b/include/bitsery/ext/utils/pointer_utils.h index 5057bce..6f548e6 100644 --- a/include/bitsery/ext/utils/pointer_utils.h +++ b/include/bitsery/ext/utils/pointer_utils.h @@ -117,18 +117,23 @@ struct PLCInfoSerializer : PLCInfo size_t id; }; +struct ObserverRef +{ + std::reference_wrapper obj; + size_t baseTypeId; +}; + struct PLCInfoDeserializer : PLCInfo { PLCInfoDeserializer(void* ptr, - size_t sharedTypeId_, + size_t ownerTypeId_, PointerOwnershipType ownershipType_, MemResourceBase* memResource_) : PLCInfo(ownershipType_) , ownerPtr{ ptr } - , sharedTypeId{ sharedTypeId_ } + , ownerTypeId{ ownerTypeId_ } , memResource{ memResource_ } - , observersList{ StdPolyAlloc>{ - memResource_ } } {}; + , observersList{ StdPolyAlloc{ memResource_ } } {}; // need to override these explicitly because we have pointer member PLCInfoDeserializer(const PLCInfoDeserializer&) = delete; @@ -139,33 +144,12 @@ struct PLCInfoDeserializer : PLCInfo PLCInfoDeserializer& operator=(PLCInfoDeserializer&&) = default; - void processOwner(void* ptr) - { - ownerPtr = ptr; - assert(ownershipType != PointerOwnershipType::Observer); - for (auto& o : observersList) - o.get() = ptr; - observersList.clear(); - observersList.shrink_to_fit(); - } - - void processObserver(void*(&ptr)) - { - if (ownerPtr) { - ptr = ownerPtr; - } else { - observersList.emplace_back(ptr); - } - } - void* ownerPtr; // used for polymorphic types in order to identify // if shared objects can be assigned - size_t sharedTypeId; + size_t ownerTypeId; MemResourceBase* memResource; - std::vector, - StdPolyAlloc>> - observersList; + std::vector> observersList; std::unique_ptr sharedState{}; }; @@ -483,7 +467,13 @@ private: memResource](const std::shared_ptr& handler) { TPtrManager::destroyPolymorphic(obj, memResource, handler); }); - ptrInfo.processOwner(TPtrManager::getPtr(obj)); + auto ptr = TPtrManager::getPtr(obj); + // might be null in case data pointer is not valid + if (ptr) { + ptrInfo.ownerPtr = ptr; + ptrInfo.ownerTypeId = ctx.getPolymorphicHandler(ptr)->getDerivedTypeId(); + processObserverListPolymorphic(des, ptrInfo, ctx); + } } template @@ -496,17 +486,18 @@ private: OwnershipType) const { auto ptr = TPtrManager::getPtr(obj); - if (ptr) { - fnc(des, *ptr); - } else { + if (!ptr) { TPtrManager::create( obj, memResource, RTTI::template get::TElement>()); ptr = TPtrManager::getPtr(obj); - fnc(des, *ptr); } - ptrInfo.processOwner(ptr); + fnc(des, *ptr); + ptrInfo.ownerPtr = ptr; + ptrInfo.ownerTypeId = + RTTI::template get::TElement>(); + processObserverList(des, ptrInfo); } template @@ -540,44 +531,26 @@ private: TPtrManager::saveToSharedStatePolymorphic( createAndGetSharedStateObj(ptrInfo), obj); } - ptrInfo.sharedTypeId = + ptrInfo.ownerPtr = TPtrManager::getPtr(obj); + ptrInfo.ownerTypeId = ctx.getPolymorphicHandler(TPtrManager::getPtr(obj)) ->getDerivedTypeId(); // since we just deserialized an object, we can skip checking hierarchy // chain by assigning baseType id instead of derived type id deserializedTypeId = baseTypeId; } else { - deserializedTypeId = ptrInfo.sharedTypeId; + deserializedTypeId = ptrInfo.ownerTypeId; } - if (canAssignToShared(baseTypeId, deserializedTypeId, ctx)) { + if (canAssignToBase(baseTypeId, deserializedTypeId, ctx)) { TPtrManager::loadFromSharedStatePolymorphic( getSharedStateObj(ptrInfo), obj); - ptrInfo.processOwner(TPtrManager::getPtr(obj)); + processObserverListPolymorphic(des, ptrInfo, ctx); } else { des.adapter().error(ReaderError::InvalidPointer); } } - // check if actual deserialized type can be assigned to the base type - // (statically typed) - bool canAssignToShared(size_t baseTypeId, - size_t deserializedTypeId, - const TPolymorphicContext& ctx) const - { - if (baseTypeId == deserializedTypeId) - return true; - auto bases = ctx.getDirectBases(deserializedTypeId); - if (bases) { - for (auto typeId : *bases) { - if (canAssignToShared(baseTypeId, typeId, ctx)) { - return true; - } - } - } - return false; - } - template void deserializeImpl(MemResourceBase* memResource, PLCInfoDeserializer& ptrInfo, @@ -603,12 +576,13 @@ private: ptr = TPtrManager::getPtr(obj); } fnc(des, *ptr); - ptrInfo.sharedTypeId = + ptrInfo.ownerTypeId = RTTI::template get::TElement>(); + ptrInfo.ownerPtr = TPtrManager::getPtr(obj); } - if (baseTypeId == ptrInfo.sharedTypeId) { + if (baseTypeId == ptrInfo.ownerTypeId) { TPtrManager::loadFromSharedState(getSharedStateObj(ptrInfo), obj); - ptrInfo.processOwner(TPtrManager::getPtr(obj)); + processObserverList(des, ptrInfo); } else { des.adapter().error(ReaderError::InvalidPointer); } @@ -633,17 +607,104 @@ private: OwnershipType{}); } - template + template void deserializeImpl(MemResourceBase*, PLCInfoDeserializer& ptrInfo, - Des&, + Des& des, T& obj, Fnc&&, - isPolymorphic, + std::false_type, OwnershipType) const { - ptrInfo.processObserver( - reinterpret_cast(TPtrManager::getPtrRef(obj))); + auto baseTypeId = RTTI::template get::TElement>(); + void*(&ptr) = reinterpret_cast(TPtrManager::getPtrRef(obj)); + if (ptrInfo.ownerPtr) { + if (ptrInfo.ownerTypeId == baseTypeId) { + ptr = ptrInfo.ownerPtr; + } else { + des.adapter().error(ReaderError::InvalidPointer); + } + } else { + ptrInfo.observersList.emplace_back(ObserverRef{ ptr, baseTypeId }); + } + } + + template + void deserializeImpl(MemResourceBase*, + PLCInfoDeserializer& ptrInfo, + Des& des, + T& obj, + Fnc&&, + std::true_type, + OwnershipType) const + { + const auto& ctx = des.template context>(); + const size_t baseTypeId = + RTTI::template get::TElement>(); + void*(&ptr) = reinterpret_cast(TPtrManager::getPtrRef(obj)); + if (ptrInfo.ownerPtr) { + if (canAssignToBase(baseTypeId, ptrInfo.ownerTypeId, ctx)) { + // TODO cast from one ptr to another + ptr = ptrInfo.ownerPtr; + } else { + des.adapter().error(ReaderError::InvalidPointer); + } + } else { + ptrInfo.observersList.emplace_back(ObserverRef{ ptr, baseTypeId }); + } + } + + // check if actual deserialized type can be assigned to the base type + // (statically typed) + bool canAssignToBase(size_t baseTypeId, + size_t deserializedTypeId, + const TPolymorphicContext& ctx) const + { + if (baseTypeId == deserializedTypeId) + return true; + auto bases = ctx.getDirectBases(deserializedTypeId); + if (bases) { + for (auto typeId : *bases) { + if (canAssignToBase(baseTypeId, typeId, ctx)) { + return true; + } + } + } + return false; + } + + template + void processObserverList(Des& des, PLCInfoDeserializer& ptrInfo) const + { + assert(ptrInfo.ownershipType != PointerOwnershipType::Observer); + for (auto& o : ptrInfo.observersList) { + if (ptrInfo.ownerTypeId == o.baseTypeId) { + o.obj.get() = ptrInfo.ownerPtr; + } else { + des.adapter().error(ReaderError::InvalidPointer); + } + } + ptrInfo.observersList.clear(); + ptrInfo.observersList.shrink_to_fit(); + } + + template + void processObserverListPolymorphic( + Des& des, + PLCInfoDeserializer& ptrInfo, + const TPolymorphicContext& ctx) const + { + assert(ptrInfo.ownershipType != PointerOwnershipType::Observer); + for (auto& o : ptrInfo.observersList) { + if (canAssignToBase(o.baseTypeId, ptrInfo.ownerTypeId, ctx)) { + // TODO cast from one ptr to another + o.obj.get() = ptrInfo.ownerPtr; + } else { + des.adapter().error(ReaderError::InvalidPointer); + } + } + ptrInfo.observersList.clear(); + ptrInfo.observersList.shrink_to_fit(); } template diff --git a/include/bitsery/ext/utils/polymorphism_utils.h b/include/bitsery/ext/utils/polymorphism_utils.h index 7b9476b..23fb79d 100644 --- a/include/bitsery/ext/utils/polymorphism_utils.h +++ b/include/bitsery/ext/utils/polymorphism_utils.h @@ -135,20 +135,20 @@ template class AbstractPolymorphicHandler : public PolymorphicHandlerBase { public: - void* create(const pointer_utils::PolyAllocWithTypeId& alloc) const + void* create(const pointer_utils::PolyAllocWithTypeId&) const { assert(false); return nullptr; } - void destroy(const pointer_utils::PolyAllocWithTypeId& alloc, void* ptr) const + void destroy(const pointer_utils::PolyAllocWithTypeId&, void*) const { assert(false); }; - void process(void* ser, void* obj) const { assert(false); } + void process(void*, void*) const { assert(false); } - void* getRootPtr(const void* obj) const + void* getRootPtr(const void*) const { assert(false); return nullptr; diff --git a/tests/serialization_ext_pointer.cpp b/tests/serialization_ext_pointer.cpp index 72dace1..250f06c 100644 --- a/tests/serialization_ext_pointer.cpp +++ b/tests/serialization_ext_pointer.cpp @@ -458,6 +458,23 @@ TEST_F(SerializeExtensionPointerDeserialization, PointerObserver) EXPECT_THAT(pr3, Eq(&r3)); } +TEST_F(SerializeExtensionPointerDeserialization, + PointerObserverAndOwnerTypeMustBeTheSame) +{ + // serialize as if we have two same objects + auto& ser = createSerializer(); + ser.ext2b(d1, ReferencedByPointer{}); + ser.ext2b(pd1, PointerObserver{}); + auto& des = createDeserializer(); + // but actual implementation expects distinct objects + des.ext2b(r1, ReferencedByPointer{}); + des.ext4b(pr2, PointerObserver{}); + + EXPECT_THAT(isPointerContextValid(), Eq(true)); + EXPECT_THAT(sctx1.des->adapter().error(), + Eq(bitsery::ReaderError::InvalidPointer)); +} + struct Test1Data { std::vector vdata; diff --git a/tests/serialization_ext_pointer_polymorphic_types.cpp b/tests/serialization_ext_pointer_polymorphic_types.cpp index 200ed3c..3052133 100644 --- a/tests/serialization_ext_pointer_polymorphic_types.cpp +++ b/tests/serialization_ext_pointer_polymorphic_types.cpp @@ -371,3 +371,63 @@ TEST_F(SerializeExtensionPointerPolymorphicTypes, EXPECT_THAT(sctx.des->adapter().error(), Eq(bitsery::ReaderError::InvalidPointer)); } + +TEST_F(SerializeExtensionPointerPolymorphicTypes, + SameObjectIsCorrectlyIdentifiedEvenIfObserverHasDifferentBase) +{ + + MultipleVirtualInheritance md; + Derived2* derivedData = &md; + EXPECT_THAT(static_cast(&md), + ::testing::Ne(static_cast(derivedData))); + + auto& ser = createSerializer(); + ser.ext(md, ReferencedByPointer{}); + ser.ext(derivedData, PointerObserver{}); + EXPECT_THAT(isPointerContextValid(), Eq(true)); +} + +TEST_F(SerializeExtensionPointerPolymorphicTypes, + CheckIfOwnerTypeIsAssignableToObserverType) +{ + + MultipleVirtualInheritance md; + Derived2* derivedData = &md; + + auto& ser = createSerializer(); + ser.ext(&md, PointerOwner{}); + ser.ext(derivedData, PointerObserver{}); + + MultipleVirtualInheritance* res1 = nullptr; + NoRelationshipSpecifiedDerived* res2 = nullptr; + auto& des = createDeserializer(); + des.ext(res1, PointerOwner{}); + des.ext(res2, PointerObserver{}); + + EXPECT_THAT(res1, ::testing::NotNull()); + EXPECT_THAT(res2, ::testing::IsNull()); + EXPECT_THAT(sctx.des->adapter().error(), + Eq(bitsery::ReaderError::InvalidPointer)); +} + +TEST_F(SerializeExtensionPointerPolymorphicTypes, + OwnerIsStaticallyCastToObserverType) +{ + + MultipleVirtualInheritance md{ 1, 2, 3, 4 }; + Derived2* derivedData = &md; + + auto& ser = createSerializer(); + ser.ext(&md, PointerOwner{}); + ser.ext(derivedData, PointerObserver{}); + + MultipleVirtualInheritance* res1 = nullptr; + Base* res2 = nullptr; + auto& des = createDeserializer(); + des.ext(res1, PointerOwner{}); + des.ext(res2, PointerObserver{}); + + EXPECT_THAT(res1, ::testing::NotNull()); + EXPECT_THAT(res2, ::testing::NotNull()); + EXPECT_THAT(res2->x, Eq(1)); +}