LCOV - code coverage report
Current view: top level - cltools - ShowGraph.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 177 177 100.0 %
Date: 2024-10-18 14:00:25 Functions: 13 13 100.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2012-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
#include "CLTool.h"
#include "core/CLToolRegister.h"
#include "tools/Tools.h"
#include "config/Config.h"
#include "core/PlumedMain.h"
#include "core/ActionSet.h"
#include "core/ActionRegister.h"
#include "core/ActionShortcut.h"
#include "core/ActionToPutData.h"
#include "core/ActionWithVirtualAtom.h"
#include "core/ActionWithVector.h"
#include <cstdio>
#include <iostream>
#include <set>
#include <string>
#include <vector>
      36             : 
      37             : namespace PLMD {
      38             : namespace cltools {
      39             : 
      40             : //+PLUMEDOC TOOLS show_graph
      41             : /*
      42             : show_graph is a tool that takes a plumed input and generates a graph showing how
      43             : data flows through the action set involved.
      44             : 
      45             : If this tool is invoked without the --force keyword then the way data is passed through the code during the forward pass
      46             : through the action is shown.
      47             : 
      48             : When the --force keyword is used then the way forces are passed from biases through actions is shown.
      49             : 
      50             : \par Examples
      51             : 
      52             : The following generates the mermaid file for the input in plumed.dat
      53             : \verbatim
      54             : plumed show_graph --plumed plumed.dat
      55             : \endverbatim
      56             : 
      57             : */
      58             : //+ENDPLUMEDOC
      59             : 
      60             : class ShowGraph :
      61             :   public CLTool
      62             : {
      63             : public:
      64             :   static void registerKeywords( Keywords& keys );
      65             :   explicit ShowGraph(const CLToolOptions& co );
      66             :   int main(FILE* in, FILE*out,Communicator& pc);
      67           4 :   std::string description()const {
      68           4 :     return "generate a graph showing how data flows through a PLUMED action set";
      69             :   }
      70             :   std::string getLabel(const Action* a, const bool& amp=false);
      71             :   std::string getLabel(const std::string& s, const bool& amp=false );
      72             :   void printStyle( const unsigned& linkcount, const Value* v, OFile& ofile );
      73             :   void printArgumentConnections( const ActionWithArguments* a, unsigned& linkcount, const bool& force, OFile& ofile );
      74             :   void printAtomConnections( const ActionAtomistic* a, unsigned& linkcount, const bool& force, OFile& ofile );
      75             :   void drawActionWithVectorNode( OFile& ofile, PlumedMain& p, Action* ag, const std::vector<std::string>& mychain, std::vector<bool>& printed );
      76             : };
      77             : 
      78       15960 : PLUMED_REGISTER_CLTOOL(ShowGraph,"show_graph")
      79             : 
      80        5316 : void ShowGraph::registerKeywords( Keywords& keys ) {
      81        5316 :   CLTool::registerKeywords( keys );
      82       10632 :   keys.add("compulsory","--plumed","plumed.dat","the plumed input that we are generating the graph for");
      83       10632 :   keys.add("compulsory","--out","graph.md","the dot file containing the graph that has been generated");
      84       10632 :   keys.addFlag("--force",false,"print a graph that shows how forces are passed through the actions");
      85        5316 : }
      86             : 
      87          12 : ShowGraph::ShowGraph(const CLToolOptions& co ):
      88          12 :   CLTool(co)
      89             : {
      90          12 :   inputdata=commandline;
      91          12 : }
      92             : 
      93         377 : std::string ShowGraph::getLabel(const Action* a, const bool& amp) {
      94         377 :   return getLabel( a->getLabel(), amp );
      95             : }
      96             : 
      97         453 : std::string ShowGraph::getLabel( const std::string& s, const bool& amp ) {
      98         453 :   if( s.find("@")==std::string::npos ) return s;
      99          48 :   std::size_t p=s.find_first_of("@");
     100          63 :   if( amp ) return "#64;" + s.substr(p+1);
     101          33 :   return s.substr(p+1);
     102             : }
     103             : 
     104          85 : void ShowGraph::printStyle( const unsigned& linkcount, const Value* v, OFile& ofile ) {
     105          85 :   if( v->getRank()>0 && v->hasDerivatives() ) ofile.printf("linkStyle %d stroke:green,color:green;\n", linkcount);
     106          85 :   else if( v->getRank()==1 ) ofile.printf("linkStyle %d stroke:blue,color:blue;\n", linkcount);
     107          52 :   else if ( v->getRank()==2 ) ofile.printf("linkStyle %d stroke:red,color:red;\n", linkcount);
     108          85 : }
     109             : 
     110          63 : void ShowGraph::printArgumentConnections( const ActionWithArguments* a, unsigned& linkcount, const bool& force, OFile& ofile ) {
     111          63 :   if( !a ) return;
     112         101 :   for(const auto & v : a->getArguments() ) {
     113          55 :     if( force && v->forcesWereAdded() ) {
     114          28 :       ofile.printf("%s -- %s --> %s\n", getLabel(a).c_str(), v->getName().c_str(), getLabel(v->getPntrToAction()).c_str() );
     115          14 :       printStyle( linkcount, v, ofile ); linkcount++;
     116          41 :     } else if( !force ) {
     117          66 :       ofile.printf("%s -- %s --> %s\n", getLabel(v->getPntrToAction()).c_str(),v->getName().c_str(),getLabel(a).c_str() );
     118          33 :       printStyle( linkcount, v, ofile ); linkcount++;
     119             :     }
     120             :   }
     121             : }
     122             : 
     123          55 : void ShowGraph::printAtomConnections( const ActionAtomistic* a, unsigned& linkcount, const bool& force, OFile& ofile ) {
     124          55 :   if( !a ) return;
     125         179 :   for(const auto & d : a->getDependencies() ) {
     126         138 :     ActionToPutData* dp=dynamic_cast<ActionToPutData*>(d);
     127         138 :     if( dp && dp->getLabel()=="posx" ) {
     128          18 :       if( force && (dp->copyOutput(0))->forcesWereAdded() ) {
     129           8 :         ofile.printf("%s --> MD\n", getLabel(a).c_str() );
     130           8 :         ofile.printf("linkStyle %d stroke:violet,color:violet;\n", linkcount); linkcount++;
     131             :       } else {
     132          10 :         ofile.printf("MD --> %s\n", getLabel(a).c_str() );
     133          10 :         ofile.printf("linkStyle %d stroke:violet,color:violet;\n", linkcount); linkcount++;
     134             :       }
     135         120 :     } else if( dp && dp->getLabel()!="posy" && dp->getLabel()!="posz" && dp->getLabel()!="Masses" && dp->getLabel()!="Charges" ) {
     136          21 :       if( force && (dp->copyOutput(0))->forcesWereAdded() ) {
     137          18 :         ofile.printf("%s -- %s --> %s\n",getLabel(a).c_str(), getLabel(d).c_str(), getLabel(d).c_str() );
     138           9 :         printStyle( linkcount, dp->copyOutput(0), ofile ); linkcount++;
     139             :       } else {
     140          24 :         ofile.printf("%s -- %s --> %s\n", getLabel(d).c_str(),getLabel(d).c_str(),getLabel(a).c_str() );
     141          12 :         printStyle( linkcount, dp->copyOutput(0), ofile ); linkcount++;
     142             :       }
     143          21 :       continue;
     144             :     }
     145         117 :     ActionWithVirtualAtom* dv=dynamic_cast<ActionWithVirtualAtom*>(d);
     146         117 :     if( dv ) {
     147           4 :       if( force && (dv->copyOutput(0))->forcesWereAdded() ) {
     148           2 :         ofile.printf("%s -- %s --> %s\n", getLabel(a).c_str(),getLabel(d).c_str(),getLabel(d).c_str() );
     149           1 :         ofile.printf("linkStyle %d stroke:violet,color:violet;\n", linkcount); linkcount++;
     150             :       } else {
     151           6 :         ofile.printf("%s -- %s --> %s\n", getLabel(d).c_str(),getLabel(d).c_str(),getLabel(a).c_str() );
     152           3 :         ofile.printf("linkStyle %d stroke:violet,color:violet;\n", linkcount); linkcount++;
     153             :       }
     154             :     }
     155             :   }
     156             : }
     157             : 
     158          30 : void ShowGraph::drawActionWithVectorNode( OFile& ofile, PlumedMain& p, Action* ag, const std::vector<std::string>& mychain, std::vector<bool>& printed ) {
     159          30 :   ActionWithVector* agg=dynamic_cast<ActionWithVector*>(ag);
     160          30 :   std::vector<std::string> matchain; agg->getAllActionLabelsInMatrixChain( matchain );
     161          30 :   if( matchain.size()>0 ) {
     162          16 :     ofile.printf("subgraph sub%s_mat [%s]\n",getLabel(agg).c_str(), getLabel(agg).c_str());
     163          24 :     for(unsigned j=0; j<matchain.size(); ++j ) {
     164          16 :       Action* agm=p.getActionSet().selectWithLabel<Action*>(matchain[j]);
     165          60 :       for(unsigned k=0; k<mychain.size(); ++k ) {
     166          76 :         if( mychain[k]==matchain[j] ) { printed[k]=true; break; }
     167             :       }
     168          32 :       ofile.printf("%s([\"label=%s \n %s \n\"])\n", getLabel(matchain[j]).c_str(), getLabel(matchain[j],true).c_str(), agm->writeInGraph().c_str() );
     169             :     }
     170           8 :     ofile.printf("end\n");
     171          16 :     ofile.printf("style sub%s_mat fill:lightblue\n",getLabel(ag).c_str());
     172          44 :   } else ofile.printf("%s([\"label=%s \n %s \n\"])\n", getLabel(ag->getLabel()).c_str(), getLabel(ag->getLabel(),true).c_str(), ag->writeInGraph().c_str() );
     173          30 : }
     174             : 
     175           8 : int ShowGraph::main(FILE* in, FILE*out,Communicator& pc) {
     176             : 
     177          16 :   std::string inpt; parse("--plumed",inpt);
     178           8 :   std::string outp; parse("--out",outp);
     179           8 :   bool forces; parseFlag("--force",forces);
     180             : 
     181             :   // Create a plumed main object and initilize
     182           8 :   PlumedMain p; int rr=sizeof(double);
     183           8 :   p.cmd("setRealPrecision",&rr);
     184           8 :   double lunit=1.0; p.cmd("setMDLengthUnits",&lunit);
     185           8 :   double cunit=1.0; p.cmd("setMDChargeUnits",&cunit);
     186           8 :   double munit=1.0; p.cmd("setMDMassUnits",&munit);
     187           8 :   p.cmd("setPlumedDat",inpt.c_str());
     188           8 :   p.cmd("setLog",out);
     189          24 :   int natoms=1000000; p.cmd("setNatoms",&natoms);
     190           8 :   p.cmd("init");
     191             : 
     192           8 :   unsigned linkcount=0; OFile ofile; ofile.open(outp);
     193           8 :   if( forces ) {
     194           4 :     unsigned step=1; p.cmd("setStep",step);
     195           4 :     p.cmd("prepareCalc");
     196           4 :     ofile.printf("flowchart BT \n"); std::vector<std::string> drawn_nodes; std::set<std::string> atom_force_set;
     197         103 :     for(auto pp=p.getActionSet().rbegin(); pp!=p.getActionSet().rend(); ++pp) {
     198             :       const auto & a(pp->get());
     199         534 :       if( a->getName()=="DOMAIN_DECOMPOSITION" || a->getLabel()=="posx" || a->getLabel()=="posy" || a->getLabel()=="posz" || a->getLabel()=="Masses" || a->getLabel()=="Charges" ) continue;
     200             : 
     201          75 :       if(a->isActive()) {
     202          44 :         ActionToPutData* ap=dynamic_cast<ActionToPutData*>(a);
     203          44 :         if( ap ) {
     204           8 :           ofile.printf("%s(\"label=%s \n %s \n\")\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
     205           4 :           continue;
     206             :         }
     207          40 :         ActionWithValue* av=dynamic_cast<ActionWithValue*>(a);
     208          40 :         if( !av ) continue ;
     209             :         // Now apply the force if there is one
     210          35 :         a->apply();
     211             :         bool hasforce=false;
     212          67 :         for(int i=0; i<av->getNumberOfComponents(); ++i) {
     213          42 :           if( (av->copyOutput(i))->forcesWereAdded() ) { hasforce=true; break; }
     214             :         }
     215             :         //Check if there are forces here
     216          35 :         ActionWithArguments* aaa=dynamic_cast<ActionWithArguments*>(a);
     217          35 :         if( aaa ) {
     218          46 :           for(const auto & v : aaa->getArguments() ) {
     219          30 :             if( v->forcesWereAdded() ) { hasforce=true; break; }
     220             :           }
     221             :         }
     222          35 :         if( !hasforce ) continue;
     223          21 :         ActionWithVector* avec=dynamic_cast<ActionWithVector*>(a);
     224          21 :         if( avec ) {
     225           8 :           ActionWithVector* head=avec->getFirstActionInChain();
     226           8 :           std::vector<std::string> mychain; head->getAllActionLabelsInChain( mychain ); std::vector<bool> printed(mychain.size(),false);
     227          16 :           ofile.printf("subgraph sub%s [%s]\n",getLabel(head).c_str(),getLabel(head).c_str());
     228          70 :           for(unsigned i=0; i<mychain.size(); ++i) {
     229             :             bool drawn=false;
     230         314 :             for(unsigned j=0; j<drawn_nodes.size(); ++j ) {
     231         294 :               if( drawn_nodes[j]==mychain[i] ) { drawn=true; break; }
     232             :             }
     233          62 :             if( drawn ) continue;
     234          20 :             ActionWithVector* ag=p.getActionSet().selectWithLabel<ActionWithVector*>(mychain[i]); plumed_assert( ag ); drawn_nodes.push_back( mychain[i] );
     235          20 :             if( !printed[i] ) { drawActionWithVectorNode( ofile, p, ag, mychain, printed ); printed[i]=true; }
     236          41 :             for(const auto & v : ag->getArguments() ) {
     237             :               bool chain_conn=false;
     238         109 :               for(unsigned j=0; j<mychain.size(); ++j) {
     239         105 :                 if( (v->getPntrToAction())->getLabel()==mychain[j] ) { chain_conn=true; break; }
     240             :               }
     241          21 :               if( !chain_conn ) continue;
     242          34 :               ofile.printf("%s -. %s .-> %s\n", getLabel(v->getPntrToAction()).c_str(),v->getName().c_str(),getLabel(ag).c_str() );
     243          17 :               printStyle( linkcount, v, ofile ); linkcount++;
     244             :             }
     245             :           }
     246           8 :           ofile.printf("end\n");
     247           8 :           if( avec!=head ) {
     248          70 :             for(unsigned i=0; i<mychain.size(); ++i) {
     249          62 :               ActionWithVector* c = p.getActionSet().selectWithLabel<ActionWithVector*>( mychain[i] ); plumed_assert(c);
     250          62 :               if( c->getNumberOfAtoms()>0 || c->hasStoredArguments() ) {
     251          60 :                 for(unsigned j=0; j<avec->getNumberOfComponents(); ++j ) {
     252          30 :                   if( avec->copyOutput(j)->getRank()>0 ) continue;
     253          20 :                   ofile.printf("%s == %s ==> %s\n", getLabel(avec).c_str(), avec->copyOutput(j)->getName().c_str(), getLabel(c).c_str() );
     254          10 :                   linkcount++;
     255             :                 }
     256          30 :                 if( c->getNumberOfAtoms()>0 ) atom_force_set.insert( c->getLabel() );
     257             :               }
     258             :             }
     259             :           }
     260           8 :         } else {
     261             :           // Print out the node if we have force on it
     262          26 :           ofile.printf("%s([\"label=%s \n %s \n\"])\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
     263             :         }
     264             :         // Check where this force is being added
     265          21 :         printArgumentConnections( aaa, linkcount, true, ofile );
     266             :       }
     267             :     }
     268             :     // Now draw connections from action atomistic to relevant actions
     269           4 :     std::vector<ActionAtomistic*> all_atoms = p.getActionSet().select<ActionAtomistic*>();
     270          33 :     for(const auto & at : all_atoms ) {
     271          29 :       ActionWithValue* av=dynamic_cast<ActionWithValue*>(at); bool hasforce=false;
     272          29 :       if( av ) {
     273          44 :         for(unsigned i=0; i<av->getNumberOfComponents(); ++i ) {
     274          26 :           if( av->copyOutput(i)->forcesWereAdded() ) {
     275           8 :             printAtomConnections( at, linkcount, true, ofile );
     276           8 :             atom_force_set.erase( av->getLabel() ); break;
     277             :           }
     278             :         }
     279             :       }
     280             :     }
     281           9 :     for(const auto & l : atom_force_set ) {
     282           5 :       ActionAtomistic* at = p.getActionSet().selectWithLabel<ActionAtomistic*>(l);
     283           5 :       plumed_assert(at); printAtomConnections( at, linkcount, true, ofile );
     284             :     }
     285           4 :     ofile.printf("MD(positions from MD)\n");
     286             :     return 0;
     287           4 :   }
     288             : 
     289           4 :   ofile.printf("flowchart TB \n"); ofile.printf("MD(positions from MD)\n");
     290          98 :   for(const auto & aa : p.getActionSet() ) {
     291             :     Action* a(aa.get());
     292         504 :     if( a->getName()=="DOMAIN_DECOMPOSITION" || a->getLabel()=="posx" || a->getLabel()=="posy" || a->getLabel()=="posz" || a->getLabel()=="Masses" || a->getLabel()=="Charges" ) continue;
     293          70 :     ActionToPutData* ap=dynamic_cast<ActionToPutData*>(a);
     294          70 :     if( ap ) {
     295           8 :       ofile.printf("%s(\"label=%s \n %s \n\")\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
     296           4 :       continue;
     297             :     }
     298          66 :     ActionShortcut* as=dynamic_cast<ActionShortcut*>(a); if( as ) continue ;
     299          42 :     ActionWithValue* av=dynamic_cast<ActionWithValue*>(a);
     300          42 :     ActionWithArguments* aaa=dynamic_cast<ActionWithArguments*>(a);
     301          42 :     ActionAtomistic* at=dynamic_cast<ActionAtomistic*>(a);
     302          42 :     ActionWithVector* avec=dynamic_cast<ActionWithVector*>(a);
     303             :     // Print out the connections between nodes
     304          42 :     printAtomConnections( at, linkcount, false, ofile );
     305          42 :     printArgumentConnections( aaa, linkcount, false, ofile );
     306             :     // Print out the nodes
     307          42 :     if( avec && !avec->actionInChain() ) {
     308           6 :       ofile.printf("subgraph sub%s [%s]\n",getLabel(a).c_str(),getLabel(a).c_str());
     309           3 :       std::vector<std::string> mychain; avec->getAllActionLabelsInChain( mychain ); std::vector<bool> printed(mychain.size(),false);
     310          21 :       for(unsigned i=0; i<mychain.size(); ++i) {
     311          18 :         Action* ag=p.getActionSet().selectWithLabel<Action*>(mychain[i]);
     312          18 :         if( !printed[i] ) { drawActionWithVectorNode( ofile, p, ag, mychain, printed ); printed[i]=true; }
     313             :       }
     314           3 :       ofile.printf("end\n");
     315          42 :     } else if( !av ) {
     316          22 :       ofile.printf("%s(\"label=%s \n %s \n\")\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
     317          28 :     } else if( !avec ) {
     318          26 :       ofile.printf("%s([\"label=%s \n %s \n\"])\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
     319             :     }
     320             :   }
     321           4 :   ofile.close();
     322             : 
     323             :   return 0;
     324           8 : }
     325             : 
     326             : } // End of namespace
     327             : }

Generated by: LCOV version 1.16