diff --git a/Changelog.md b/Changelog.md index b4c031cc..f6e56541 100644 --- a/Changelog.md +++ b/Changelog.md @@ -6,6 +6,7 @@ _yyyy.mm.dd_ ### New features * Add `Data::PointerWrapper` class to simplify interracing of GEGELATI with primitive variables (non-array) data from a learning environment. * Add `TPG::ExecutionsStats` class to analyze and export execution statistics gathered using an instrumented TPGGraph. Statistics include averages on numbers of evaluated teams, programs, program lines and instructions, execution traces and various distributions based on execution traces. The class also provides a method to export these statistics to a JSon file, which can be used by other programs and scripts. +* Add a `File::TPGDotExporter::printSubgraph()` method to print only a subgraph from a TPG, starting from a specified `TPGVertex`. This method can notably be used to export the champion TPG throughout the training process, without having to remove other roots from the TPG. * Add a new `Learn::LearningAgent::evaluateOneRoot()` method to ease the evaluation of individual policies in a trained TPG. * Add a new `Learn::LearningAgent::getEnvironment()` method for convenience. diff --git a/gegelatilib/include/file/tpgGraphDotExporter.h b/gegelatilib/include/file/tpgGraphDotExporter.h index 0adbbe12..efbed769 100644 --- a/gegelatilib/include/file/tpgGraphDotExporter.h +++ b/gegelatilib/include/file/tpgGraphDotExporter.h @@ -215,6 +215,21 @@ namespace File { * TPGGraphDotExporter into a dot file. */ void print(); + + /** + * \brief Print a sub-tree of the TPGGraph given when constructing the + * TPGGraphDotExporter into a dot file. + * + * Contrary to the print() method, which prints the whole TPG, this + * method only prints the TPG stemming from the TPG::TPGVertex passed as + * a parameter. Hence, only vertices and programs connected to this + * TPGVertex will be printed in the file, and all others will be + * ignored. + * + * \param[in] root The vertex used as a starting point to print a + * connected TPG. + */ + void printSubGraph(const TPG::TPGVertex* root); }; }; // namespace File diff --git a/gegelatilib/include/file/tpgGraphDotImporter.h b/gegelatilib/include/file/tpgGraphDotImporter.h index aa3b0703..ef4ce8f3 100644 --- a/gegelatilib/include/file/tpgGraphDotImporter.h +++ b/gegelatilib/include/file/tpgGraphDotImporter.h @@ -381,8 +381,10 @@ namespace File { * \param[in] filePath initial path to the file where the dot content * will be written. * \param[in] environment the environment in which the tpg Graph should - * be built \param[in] tpgref a Reference to the TPGGraph to buiuld from - * the .dot file \throws std::runtime_error in case no file could be + * be built + * \param[in] tpgref a Reference to the TPGGraph to build from + * the .dot file + * \throws std::runtime_error in case no file could be * opened at the given filePath. */ TPGGraphDotImporter(const char* filePath, Environment environment, diff --git a/gegelatilib/src/file/tpgGraphDotExporter.cpp b/gegelatilib/src/file/tpgGraphDotExporter.cpp index 2530e37e..5018c95b 100644 --- a/gegelatilib/src/file/tpgGraphDotExporter.cpp +++ b/gegelatilib/src/file/tpgGraphDotExporter.cpp @@ -199,6 +199,10 @@ void File::TPGGraphDotExporter::print() } // Reset program ids + // This is done to ensure that a program without an ID is properly printed + // when first encountered. However, this ruins the original purpose of the + // ID, which should remain constant through multiple exports and + // generations. this->programID.erase(this->programID.begin(), this->programID.end()); // Print all edges @@ -213,3 +217,66 @@ void File::TPGGraphDotExporter::print() // flush file fflush(pFile); } + +void File::TPGGraphDotExporter::printSubGraph(const TPG::TPGVertex* root) +{ + // Print the graph header + this->printTPGGraphHeader(); + + // Reset program ids + // This is done to ensure that a program without an ID is properly printed + // when first encountered. However, this ruins the original purpose of the + // ID, which should remain constant through multiple exports and + // generations. + this->programID.erase(this->programID.begin(), this->programID.end()); + + // Print edges stemming from the given root + // Init a Breadth First scan + std::deque verticesToVisit; + verticesToVisit.push_back(root); + std::vector visitedVertices; + std::vector edgesToPrint; + + while (!verticesToVisit.empty()) { + // Get first vertex + const TPG::TPGVertex* vertex = verticesToVisit.front(); + verticesToVisit.pop_front(); + visitedVertices.push_back(vertex); + + // Print it if it is a team (actions are printed with edges) + const TPG::TPGTeam* team = nullptr; + if ((team = dynamic_cast(vertex)) != nullptr) { + this->printTPGTeam(*(const TPG::TPGTeam*)vertex); + + // Put its outgoing edge in the list for later print. + // Edges must be printed after their destination team has been + // written. + for (auto edge : team->getOutgoingEdges()) { + edgesToPrint.push_back(edge); + + // If the edge destination is a Team, put it in the list of + // vertex to be visited. + const TPG::TPGVertex* dest = edge->getDestination(); + if (dynamic_cast(dest) != nullptr && + std::find(visitedVertices.begin(), visitedVertices.end(), + dest) == visitedVertices.end() && + std::find(verticesToVisit.begin(), verticesToVisit.end(), + dest) == verticesToVisit.end()) { + verticesToVisit.push_back(dest); + } + } + } + } + + // Print edges + for (const TPG::TPGEdge* edge : edgesToPrint) { + this->printTPGEdge(*edge); + } + + // Print specific footer (no need for rank, since there is a single root) + this->offset = ""; + fprintf(pFile, "%s}\n", this->offset.c_str()); + + // flush file + fflush(pFile); +} diff --git a/test/dat/exported_subtpg_ref.dot b/test/dat/exported_subtpg_ref.dot new file mode 100644 index 00000000..154cf779 --- /dev/null +++ b/test/dat/exported_subtpg_ref.dot @@ -0,0 +1,43 @@ +~// File exported with GEGELATI vX.Y.Z +~// On the YYYY-MM-DD HH:MM:SS +~// With the +digraph{ + graph[pad = "0.212, 0.055" bgcolor = lightgray] + node[shape=circle style = filled label = ""] + T0 [fillcolor="#1199bb"] + T1 [fillcolor="#66ddff"] + T2 [fillcolor="#66ddff"] + P0 [fillcolor="#cccccc" shape=point] //-2|-1|0|1|2| + I0 [shape=box style=invis label="0|1&0|1#0|0\n0|1&0|1#0|0\n0|1&0|1#0|0\n"] + P0 -> I0[style=invis] + A0 [fillcolor="#ff3366" shape=box margin=0.03 width=0 height=0 label="0"] + T0 -> P0 -> A0 + P1 [fillcolor="#cccccc" shape=point] //-2|-1|0|1|2| + I1 [shape=box style=invis label=""] + P1 -> I1[style=invis] + T0 -> P1 -> T1 + P2 [fillcolor="#cccccc" shape=point] //-2|-1|0|1|2| + I2 [shape=box style=invis label=""] + P2 -> I2[style=invis] + A1 [fillcolor="#ff3366" shape=box margin=0.03 width=0 height=0 label="1"] + T1 -> P2 -> A1 + P3 [fillcolor="#cccccc" shape=point] //-2|-1|0|1|2| + I3 [shape=box style=invis label=""] + P3 -> I3[style=invis] + T1 -> P3 -> T2 + T1 -> P0 + P4 [fillcolor="#cccccc" shape=point] //-2|-1|0|1|2| + I4 [shape=box style=invis label=""] + P4 -> I4[style=invis] + A2 [fillcolor="#ff3366" shape=box margin=0.03 width=0 height=0 label="2"] + T1 -> P4 -> A2 + P5 [fillcolor="#cccccc" shape=point] //-2|-1|0|1|2| + I5 [shape=box style=invis label=""] + P5 -> I5[style=invis] + A3 [fillcolor="#ff3366" shape=box margin=0.03 width=0 height=0 label="2"] + T2 -> P5 -> A3 + P6 [fillcolor="#cccccc" shape=point] //-2|-1|0|1|2| + I6 [shape=box style=invis label=""] + P6 -> I6[style=invis] + T2 -> P6 -> T1 +} diff --git a/test/exporterTest.cpp b/test/exporterTest.cpp index b78c74db..35cad219 100644 --- a/test/exporterTest.cpp +++ b/test/exporterTest.cpp @@ -199,6 +199,20 @@ TEST_F(ExporterTest, print) << "File export was executed without error."; } +TEST_F(ExporterTest, printSubGraph) +{ + File::TPGGraphDotExporter dotExporter("exported_subtpg.dot", *tpg); + + ASSERT_NO_THROW(dotExporter.printSubGraph(tpg->getVertices().at(0))) + << "File export was executed without error."; + + // Compare the file with a golden ref + ASSERT_TRUE(compare_files("exported_subtpg.dot", + TESTS_DAT_PATH "exported_subtpg_ref.dot")) + << "Differences between reference file and exported " + "file were detected."; +} + TEST_F(ExporterTest, FileContentVerification) { // This Test checks the content of the exported file against a golden diff --git a/test/pointerWrapperTest.cpp b/test/pointerWrapperTest.cpp index 44c74a61..5e602fcb 100644 --- a/test/pointerWrapperTest.cpp +++ b/test/pointerWrapperTest.cpp @@ -101,7 +101,7 @@ TEST(PointerWrapperTest, GetDataAtNativeType) #else ASSERT_THROW( d->getDataAt(typeid(double), 0).getSharedPointer(), - std::out_of_range) + std::runtime_error) << "In NDEBUG mode, a pointer with invalid type will be returned when " "requesting a non-handled type, even at a valid location."; #endif