diff --git a/include/bitsery/ext/std_smart_ptr.h b/include/bitsery/ext/std_smart_ptr.h index 0cc1fd3..52a938f 100644 --- a/include/bitsery/ext/std_smart_ptr.h +++ b/include/bitsery/ext/std_smart_ptr.h @@ -147,6 +147,7 @@ struct SmartPtrOwnerManager [alloc, typeId](TElement* data) { alloc.deleteObject(data, typeId); }, pointer_utils::StdPolyAlloc(memResource)); state.obj = obj; + state.typeId = typeId; } static void createSharedPolymorphic( @@ -162,6 +163,7 @@ struct SmartPtrOwnerManager [alloc, handler](TElement* data) { handler->destroy(alloc, data); }, pointer_utils::StdPolyAlloc(memResource)); state.obj = obj; + state.typeId = handler->getDerivedTypeId(); } static void createShared(TSharedState& state, @@ -176,6 +178,7 @@ struct SmartPtrOwnerManager pointer_utils::StdPolyAlloc(memResource)); obj = res; state.obj = res; + state.typeId = typeId; } static void createSharedPolymorphic( @@ -191,6 +194,7 @@ struct SmartPtrOwnerManager pointer_utils::StdPolyAlloc(memResource)); obj = res; state.obj = res; + state.typeId = handler->getDerivedTypeId(); } static void saveToSharedState(TSharedState& state, T& obj) @@ -205,20 +209,17 @@ struct SmartPtrOwnerManager static void loadFromSharedState(TSharedState& state, T& obj) { - // reinterpret_pointer_cast is only since c++17 auto v = state.obj.get(); - auto p = reinterpret_cast(v); + auto p = static_cast(v); obj = std::shared_ptr(state.obj, p); } - static void loadFromSharedStatePolymorphic(TSharedState& state, T& obj) + static void loadFromSharedStatePolymorphic(TSharedState& state, + T& obj, + const PolymorphicHandlerBase&) { - // TODO Fix pointer addresses in case objects are deserialized using - // different bases - - // reinterpret_pointer_cast is only since c++17 auto v = state.obj.get(); - auto p = reinterpret_cast(v); + auto p = static_cast(v); obj = std::shared_ptr(state.obj, p); } }; diff --git a/include/bitsery/ext/utils/pointer_utils.h b/include/bitsery/ext/utils/pointer_utils.h index 6f548e6..0fa8465 100644 --- a/include/bitsery/ext/utils/pointer_utils.h +++ b/include/bitsery/ext/utils/pointer_utils.h @@ -56,6 +56,7 @@ namespace pointer_utils { // this class is used to store context for shared ptr owners struct PointerSharedStateBase { + size_t typeId{}; virtual ~PointerSharedStateBase() = default; }; @@ -541,10 +542,10 @@ private: } else { deserializedTypeId = ptrInfo.ownerTypeId; } - - if (canAssignToBase(baseTypeId, deserializedTypeId, ctx)) { + if (auto hndl = + ctx.getPolymorphicHandler(baseTypeId, ptrInfo.sharedState->typeId)) { TPtrManager::loadFromSharedStatePolymorphic( - getSharedStateObj(ptrInfo), obj); + getSharedStateObj(ptrInfo), obj, **hndl); processObserverListPolymorphic(des, ptrInfo, ctx); } else { des.adapter().error(ReaderError::InvalidPointer); @@ -643,9 +644,9 @@ private: RTTI::template get::TElement>(); void*(&ptr) = reinterpret_cast(TPtrManager::getPtrRef(obj)); if (ptrInfo.ownerPtr) { - if (canAssignToBase(baseTypeId, ptrInfo.ownerTypeId, ctx)) { - // TODO cast from one ptr to another - ptr = ptrInfo.ownerPtr; + if (auto hndl = + ctx.getPolymorphicHandler(baseTypeId, ptrInfo.ownerTypeId)) { + ptr = hndl->get()->fromDerivedToBasePtr(ptrInfo.ownerPtr); } else { des.adapter().error(ReaderError::InvalidPointer); } @@ -654,25 +655,6 @@ private: } } - // check if actual deserialized type can be assigned to the base type - // (statically typed) - bool canAssignToBase(size_t baseTypeId, - size_t deserializedTypeId, - const TPolymorphicContext& ctx) const - { - if (baseTypeId == deserializedTypeId) - return true; - auto bases = ctx.getDirectBases(deserializedTypeId); - if (bases) { - for (auto typeId : *bases) { - if (canAssignToBase(baseTypeId, typeId, ctx)) { - return true; - } - } - } - return false; - } - template void processObserverList(Des& des, PLCInfoDeserializer& ptrInfo) const { @@ -696,9 +678,9 @@ private: { assert(ptrInfo.ownershipType != PointerOwnershipType::Observer); for (auto& o : ptrInfo.observersList) { - if (canAssignToBase(o.baseTypeId, ptrInfo.ownerTypeId, ctx)) { - // TODO cast from one ptr to another - o.obj.get() = ptrInfo.ownerPtr; + if (auto hndl = + ctx.getPolymorphicHandler(o.baseTypeId, ptrInfo.ownerTypeId)) { + o.obj.get() = hndl->get()->fromDerivedToBasePtr(ptrInfo.ownerPtr); } else { des.adapter().error(ReaderError::InvalidPointer); } diff --git a/include/bitsery/ext/utils/polymorphism_utils.h b/include/bitsery/ext/utils/polymorphism_utils.h index 23fb79d..ef50d75 100644 --- a/include/bitsery/ext/utils/polymorphism_utils.h +++ b/include/bitsery/ext/utils/polymorphism_utils.h @@ -76,6 +76,8 @@ public: virtual void* getRootPtr(const void* obj) const = 0; + virtual void* fromDerivedToBasePtr(void* obj) const = 0; + virtual size_t getDerivedTypeId() const = 0; virtual ~PolymorphicHandlerBase() = default; @@ -111,6 +113,8 @@ public: static_cast(const_cast(obj))); } + void* fromDerivedToBasePtr(void* obj) const final { return toBase(obj); } + size_t getDerivedTypeId() const final { return RTTI::template get(); @@ -154,6 +158,12 @@ public: return nullptr; } + void* fromDerivedToBasePtr(void*) const final + { + assert(false); + return nullptr; + } + size_t getDerivedTypeId() const { return RTTI::template get(); }; }; @@ -188,12 +198,11 @@ private: typename TRoot, typename TBase, typename TDerived> - void add(size_t depth) + void add() { - addToMap(depth == 1, - std::is_abstract{}); + addToMap(std::is_abstract{}); addChilds( - depth + 1, typename THierarchy::Childs{}); + typename THierarchy::Childs{}); } template - void addChilds(size_t depth, PolymorphicClassesList) + void addChilds(PolymorphicClassesList) { static_assert(std::is_base_of::value, "PolymorphicBaseClass must derive a list of derived " "classes from TBase."); - add(depth); + add(); addChilds( - depth, PolymorphicClassesList{}); - add(0); + PolymorphicClassesList{}); + add(); } template - void addChilds(size_t, PolymorphicClassesList<>) + void addChilds(PolymorphicClassesList<>) { } @@ -229,7 +238,7 @@ private: typename TRoot, typename TBase, typename TDerived> - void addToMap(bool directBase, std::false_type) + void addToMap(std::false_type) { using THandler = PolymorphicHandler; @@ -256,25 +265,13 @@ private: } it->second.push_back(key.derivedHash); } - if (directBase) { - auto it = _derivedToBaseArray.find(key.derivedHash); - if (it == _derivedToBaseArray.end()) { - it = _derivedToBaseArray - .emplace(std::piecewise_construct, - std::forward_as_tuple(key.derivedHash), - std::forward_as_tuple( - pointer_utils::StdPolyAlloc{ _memResource })) - .first; - } - it->second.push_back(key.baseHash); - } } template - void addToMap(bool directBase, std::true_type) + void addToMap(std::true_type) { using THandler = AbstractPolymorphicHandler; BaseToDerivedKey key{ RTTI::template get(), @@ -288,19 +285,7 @@ private: alloc.deallocate(data, 1); }, alloc); - _baseToDerivedMap.emplace(key, std::move(handler)).second; - if (directBase) { - auto it = _derivedToBaseArray.find(key.derivedHash); - if (it == _derivedToBaseArray.end()) { - it = _derivedToBaseArray - .emplace(std::piecewise_construct, - std::forward_as_tuple(key.derivedHash), - std::forward_as_tuple( - pointer_utils::StdPolyAlloc{ _memResource })) - .first; - } - it->second.push_back(key.baseHash); - } + _baseToDerivedMap.emplace(key, std::move(handler)); } MemResourceBase* _memResource; @@ -328,17 +313,6 @@ private: std::vector>>>> _baseToDerivedArray; - // Used to iterate through hierarchy chain from most derived to the base(s) - std::unordered_map< - size_t, - std::vector>, - std::hash, - std::equal_to, - pointer_utils::StdPolyAlloc< - std::pair>>>> - _derivedToBaseArray; - public: explicit PolymorphicContext(MemResourceBase* memResource = nullptr) : _memResource{ memResource } @@ -349,11 +323,6 @@ public: std::pair>>>{ memResource } } - , _derivedToBaseArray{ pointer_utils::StdPolyAlloc< - std::pair>>>{ - memResource } } - { } @@ -379,7 +348,7 @@ public: typename... Tn> void registerBasesList(PolymorphicClassesList) { - add(0); + add(); registerBasesList(PolymorphicClassesList{}); } @@ -464,15 +433,16 @@ public: return it->second; } - const std::vector>* - getDirectBases(size_t derivedTypeId) const + const std::shared_ptr* getPolymorphicHandler( + size_t baseTypeId, + size_t derivedTypeId) const { - auto it = _derivedToBaseArray.find(derivedTypeId); - if (it != _derivedToBaseArray.end()) { - return &it->second; - } else { + auto it = + _baseToDerivedMap.find(BaseToDerivedKey{ baseTypeId, derivedTypeId }); + if (it == _baseToDerivedMap.end()) { return nullptr; } + return &it->second; } }; diff --git a/tests/serialization_ext_pointer_polymorphic_types.cpp b/tests/serialization_ext_pointer_polymorphic_types.cpp index 3052133..d60e755 100644 --- a/tests/serialization_ext_pointer_polymorphic_types.cpp +++ b/tests/serialization_ext_pointer_polymorphic_types.cpp @@ -149,12 +149,10 @@ struct PolymorphicBaseClass : PolymorphicDerivedClasses {}; -// this is commented on purpose, to test scenario when base class is registered -// (Base) but using instance of Derived1 which is not registered as base -// template<> -// struct PolymorphicBaseClass : -// PolymorphicDerivedClasses { -// }; +template<> +struct PolymorphicBaseClass + : PolymorphicDerivedClasses +{}; template<> struct PolymorphicBaseClass @@ -410,8 +408,7 @@ TEST_F(SerializeExtensionPointerPolymorphicTypes, Eq(bitsery::ReaderError::InvalidPointer)); } -TEST_F(SerializeExtensionPointerPolymorphicTypes, - OwnerIsStaticallyCastToObserverType) +TEST_F(SerializeExtensionPointerPolymorphicTypes, OwnerIsCastObserverType) { MultipleVirtualInheritance md{ 1, 2, 3, 4 };