LCOV - code coverage report
Current view: top level - metatensor - metatensor.cpp (source / functions)
Test: plumed test coverage                      Hit   Total   Coverage
Date: 2025-04-08 21:11:17      Lines:            16      24     66.7 %
                               Functions:         1       5     20.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             : Copyright (c) 2024 The METATENSOR code team
       3             : (see the PEOPLE-METATENSOR file at the root of this folder for a list of names)
       4             : 
       5             : See https://docs.metatensor.org/latest/ for more information about the
       6             : metatensor package that this module allows you to call from PLUMED.
       7             : 
       8             : This file is part of METATENSOR-PLUMED module.
       9             : 
      10             : The METATENSOR-PLUMED module is free software: you can redistribute it and/or modify
      11             : it under the terms of the GNU Lesser General Public License as published by
      12             : the Free Software Foundation, either version 3 of the License, or
      13             : (at your option) any later version.
      14             : 
      15             : The METATENSOR-PLUMED module is distributed in the hope that it will be useful,
      16             : but WITHOUT ANY WARRANTY; without even the implied warranty of
      17             : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      18             : GNU Lesser General Public License for more details.
      19             : 
      20             : You should have received a copy of the GNU Lesser General Public License
      21             : along with the METATENSOR-PLUMED module. If not, see <http://www.gnu.org/licenses/>.
      22             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      23             : 
      24             : #include "core/ActionAtomistic.h"
      25             : #include "core/ActionWithValue.h"
      26             : #include "core/ActionRegister.h"
      27             : #include "core/PlumedMain.h"
      28             : 
      29             : //+PLUMEDOC METATENSORMOD_COLVAR METATENSOR
      30             : /*
      31             : Use arbitrary machine learning models as collective variables.
      32             : 
      33             : Note that this action requires the metatensor-torch library. Check the
      34             : instructions in the \ref METATENSORMOD page to enable this module.
      35             : 
      36             : This action enables the use of fully custom machine learning models — based on
      37             : the [metatensor atomistic models][mts_models] interface — as collective
      38             : variables in PLUMED. Such machine learning models are typically written and
      39             : customized using Python code, and then exported to run within PLUMED as
      40             : [TorchScript], which is a subset of Python that can be executed by the C++ torch
      41             : library.
      42             : 
      43             : Metatensor offers a way to define such models and pass data from PLUMED (or any
      44             : other simulation engine) to the model and back. For more information on how to
      45             : define such models, have a look at the [corresponding tutorials][mts_tutorials],
      46             : or at the code in `regtest/metatensor/`. Each of the Python scripts in this
      47             : directory defines a custom machine learning CV that can be used with PLUMED.
      48             : 
      49             : \par Examples
      50             : 
      51             : The following input shows how you can call metatensor and evaluate the model
      52             : that is described in the file `custom_cv.pt` from PLUMED.
      53             : 
      54             : \plumedfile
      55             : metatensor_cv: METATENSOR ...
      56             :     MODEL=custom_cv.pt
      57             :     SPECIES1=1-26
      58             :     SPECIES2=27-62
      59             :     SPECIES3=63-76
      60             :     SPECIES_TO_TYPES=6,1,8
      61             : ...
      62             : \endplumedfile
      62             : 
      63             : The numbered `SPECIES` labels are used to indicate the list of atoms that belong
      64             : to each atomic species in the system. The `SPECIES_TO_TYPES` keyword then
      65             : provides information on the atom type for each species. The first number here is
      66             : the atomic type of the atoms that have been specified using the `SPECIES1` flag,
      67             : the second number is the atomic type of the atoms that have been specified
      68             : using the `SPECIES2` flag, and so on.
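The species-to-type assignment described above can be sketched in plain Python. This is a hypothetical illustration built from the example input (it is not PLUMED's actual parsing code):

```python
# Atom indices (1-based, as in the PLUMED input) for each numbered SPECIES
# flag, taken from the example above.
species_atoms = {
    1: range(1, 27),   # SPECIES1=1-26
    2: range(27, 63),  # SPECIES2=27-62
    3: range(63, 77),  # SPECIES3=63-76
}
# SPECIES_TO_TYPES=6,1,8: the i-th entry gives the atomic type for SPECIESi
species_to_types = [6, 1, 8]

# Per-atom type assignment that ends up being passed to the model
atomic_types = {
    atom: species_to_types[species - 1]
    for species, atoms in species_atoms.items()
    for atom in atoms
}
```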
      69             : 
      70             : The `METATENSOR` action also accepts the following options:
      71             : 
      72             : - `EXTENSIONS_DIRECTORY` should be the path to a directory containing
      73             :   TorchScript extensions (as shared libraries) that are required to load and
      74             :   execute the model. This matches the `collect_extensions` argument to
      75             :   `MetatensorAtomisticModel.export` in Python.
      76             : - `CHECK_CONSISTENCY` can be used to enable internal consistency checks;
      77             : - `SELECTED_ATOMS` can be used to signal to the metatensor model that it
      78             :   should only run its calculation for the selected subset of atoms. The model
      79             :   still needs to know about all the atoms in the system (through the `SPECIES`
      80             :   keyword), but this can be used to reduce the calculation cost. Note that the
      81             :   indices of the selected atoms should start at 1 in the PLUMED input file, but
      82             :   they will be translated to start at 0 when given to the model (i.e. in
      83             :   Python/TorchScript, the `forward` method will receive a `selected_atoms`
      84             :   which starts at 0).
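The index translation in the last bullet can be made concrete with a tiny, hypothetical Python sketch (the names below are illustrative only):

```python
# SELECTED_ATOMS=11-20 in the PLUMED input uses 1-based indices...
selected_atoms_plumed = list(range(11, 21))

# ...but by the time the model's forward() sees them, they are 0-based:
selected_atoms_model = [index - 1 for index in selected_atoms_plumed]
```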
      85             : 
      86             : Here is another example with all the possible keywords:
      87             : 
      88             : \plumedfile
      89             : soap: METATENSOR ...
      90             :     MODEL=soap.pt
      91             :     EXTENSIONS_DIRECTORY=extensions
      92             :     CHECK_CONSISTENCY
      93             : 
      94             :     SPECIES1=1-10
      95             :     SPECIES2=11-20
      96             :     SPECIES_TO_TYPES=8,13
      97             : 
      98             :     # only run the calculation for the Aluminium (type 13) atoms, but
      99             :     # include the Oxygen (type 8) as potential neighbors.
     100             :     SELECTED_ATOMS=11-20
     101             : ...
     102             : \endplumedfile
     100             : 
     101             : \par Collective variables and metatensor models
     102             : 
     103             : PLUMED can use the [`"features"` output][features_output] of metatensor
     104             : atomistic models as collective variables. Alternatively, the code also accepts
     105             : an output named `"plumed::cv"`, with the same metadata structure as the
     106             : `"features"` output.
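The preference between the two outputs can be summarized with a small Python sketch (names only; real models expose full metatensor `ModelOutput` objects, which are omitted here):

```python
def pick_output(output_names):
    # 'features' is preferred, and also wins when both outputs are present
    if "features" in output_names:
        return "features"
    # 'plumed::cv' is accepted as a deprecated fallback
    if "plumed::cv" in output_names:
        return "plumed::cv"
    raise ValueError("model must expose a 'features' or 'plumed::cv' output")
```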
     107             : 
     108             : */ /*
     109             : 
     110             : [TorchScript]: https://pytorch.org/docs/stable/jit.html
     111             : [mts_models]: https://docs.metatensor.org/latest/atomistic/index.html
     112             : [mts_tutorials]: https://docs.metatensor.org/latest/examples/atomistic/index.html
     113             : [mts_block]: https://docs.metatensor.org/latest/torch/reference/block.html
     114             : [features_output]: https://docs.metatensor.org/latest/examples/atomistic/outputs/features.html
     115             : */
     116             : //+ENDPLUMEDOC
     117             : 
     118             : /*INDENT-OFF*/
     119             : #if !defined(__PLUMED_HAS_METATENSOR) || !defined(__PLUMED_HAS_LIBTORCH)
     120             : 
     121             : namespace PLMD { namespace metatensor {
     122             : class MetatensorPlumedAction: public ActionAtomistic, public ActionWithValue {
     123             : public:
     124             :     static void registerKeywords(Keywords& keys);
     125           0 :     explicit MetatensorPlumedAction(const ActionOptions& options):
     126             :         Action(options),
     127             :         ActionAtomistic(options),
     128           0 :         ActionWithValue(options)
     129             :     {
     130           0 :         throw std::runtime_error(
     131             :             "Cannot use the metatensor action without the corresponding libraries.\n"
     132             :             "Make sure to configure with `--enable-metatensor --enable-libtorch` "
     133             :             "and that the corresponding libraries are found"
     134           0 :         );
     135           0 :     }
     136             : 
     137           0 :     void calculate() override {}
     138           0 :     void apply() override {}
     139           0 :     unsigned getNumberOfDerivatives() override {return 0;}
     140             : };
     141             : 
     142             : }} // namespace PLMD::metatensor
     143             : 
     144             : #else
     145             : 
     146             : #include <type_traits>
     147             : 
     148             : #pragma GCC diagnostic push
     149             : #pragma GCC diagnostic ignored "-Wpedantic"
     150             : #pragma GCC diagnostic ignored "-Wunused-parameter"
     151             : #pragma GCC diagnostic ignored "-Wfloat-equal"
     152             : #pragma GCC diagnostic ignored "-Wfloat-conversion"
     153             : #pragma GCC diagnostic ignored "-Wimplicit-float-conversion"
     154             : #pragma GCC diagnostic ignored "-Wimplicit-int-conversion"
     155             : #pragma GCC diagnostic ignored "-Wshorten-64-to-32"
     156             : #pragma GCC diagnostic ignored "-Wsign-conversion"
     157             : #pragma GCC diagnostic ignored "-Wold-style-cast"
     158             : 
     159             : #include <torch/script.h>
     160             : #include <torch/version.h>
     161             : #include <torch/cuda.h>
     162             : #if TORCH_VERSION_MAJOR >= 2
     163             : #include <torch/mps.h>
     164             : #endif
     165             : 
     166             : #pragma GCC diagnostic pop
     167             : 
     168             : #include <metatensor/torch.hpp>
     169             : #include <metatensor/torch/atomistic.hpp>
     170             : 
     171             : #include "vesin.h"
     172             : 
     173             : 
     174             : namespace PLMD {
     175             : namespace metatensor {
     176             : 
     177             : // We will cast Vector/Tensor to pointers to arrays and doubles, so let's make
     178             : // sure this is legal to do
     179             : static_assert(std::is_standard_layout<PLMD::Vector>::value);
     180             : static_assert(sizeof(PLMD::Vector) == sizeof(std::array<double, 3>));
     181             : static_assert(alignof(PLMD::Vector) == alignof(std::array<double, 3>));
     182             : 
     183             : static_assert(std::is_standard_layout<PLMD::Tensor>::value);
     184             : static_assert(sizeof(PLMD::Tensor) == sizeof(std::array<std::array<double, 3>, 3>));
     185             : static_assert(alignof(PLMD::Tensor) == alignof(std::array<std::array<double, 3>, 3>));
     186             : 
     187             : class MetatensorPlumedAction: public ActionAtomistic, public ActionWithValue {
     188             : public:
     189             :     static void registerKeywords(Keywords& keys);
     190             :     explicit MetatensorPlumedAction(const ActionOptions&);
     191             : 
     192             :     void calculate() override;
     193             :     void apply() override;
     194             :     unsigned getNumberOfDerivatives() override;
     195             : 
     196             : private:
     197             :     // fill this->system_ according to the current PLUMED data
     198             :     void createSystem();
     199             :     // compute a neighbor list following metatensor format, using data from PLUMED
     200             :     metatensor_torch::TensorBlock computeNeighbors(
     201             :         metatensor_torch::NeighborListOptions request,
     202             :         const std::vector<PLMD::Vector>& positions,
     203             :         const PLMD::Tensor& cell,
     204             :         bool periodic
     205             :     );
     206             : 
     207             :     // execute the model for the given system
     208             :     metatensor_torch::TensorBlock executeModel(metatensor_torch::System system);
     209             : 
     210             :     torch::jit::Module model_;
     211             : 
     212             :     metatensor_torch::ModelCapabilities capabilities_;
     213             :     std::string model_output_;
     214             : 
     215             :     // neighbor lists requests made by the model
     216             :     std::vector<metatensor_torch::NeighborListOptions> nl_requests_;
     217             : 
     218             :     // dtype/device to use to execute the model
     219             :     torch::ScalarType dtype_;
     220             :     torch::Device device_;
     221             : 
     222             :     torch::Tensor atomic_types_;
     223             :     // store the strain to be able to compute the virial with autograd
     224             :     torch::Tensor strain_;
     225             : 
     226             :     metatensor_torch::System system_;
     227             :     metatensor_torch::ModelEvaluationOptions evaluations_options_;
     228             :     bool check_consistency_;
     229             : 
     230             :     metatensor_torch::TensorMap output_;
     231             :     // shape of the output of this model
     232             :     unsigned n_samples_;
     233             :     unsigned n_properties_;
     234             : };
     235             : 
     236             : 
     237             : MetatensorPlumedAction::MetatensorPlumedAction(const ActionOptions& options):
     238             :     Action(options),
     239             :     ActionAtomistic(options),
     240             :     ActionWithValue(options),
     241             :     device_(torch::kCPU)
     242             : {
     243             :     if (metatensor_torch::version().find("0.7.") != 0) {
     244             :         this->error(
     245             :             "this code requires version 0.7.x of metatensor-torch, got version " +
     246             :             metatensor_torch::version()
     247             :         );
     248             :     }
     249             : 
     250             :     // first, load the model
     251             :     std::string extensions_directory_str;
     252             :     this->parse("EXTENSIONS_DIRECTORY", extensions_directory_str);
     253             : 
     254             :     torch::optional<std::string> extensions_directory = torch::nullopt;
     255             :     if (!extensions_directory_str.empty()) {
     256             :         extensions_directory = std::move(extensions_directory_str);
     257             :     }
     258             : 
     259             :     std::string model_path;
     260             :     this->parse("MODEL", model_path);
     261             : 
     262             :     try {
     263             :         this->model_ = metatensor_torch::load_atomistic_model(model_path, extensions_directory);
     264             :     } catch (const std::exception& e) {
     265             :         this->error("failed to load model at '" + model_path + "': " + e.what());
     266             :     }
     267             : 
     268             :     // extract information from the model
     269             :     auto metadata = this->model_.run_method("metadata").toCustomClass<metatensor_torch::ModelMetadataHolder>();
     270             :     this->capabilities_ = this->model_.run_method("capabilities").toCustomClass<metatensor_torch::ModelCapabilitiesHolder>();
     271             :     auto requests_ivalue = this->model_.run_method("requested_neighbor_lists");
     272             :     for (auto request_ivalue: requests_ivalue.toList()) {
     273             :         auto request = request_ivalue.get().toCustomClass<metatensor_torch::NeighborListOptionsHolder>();
     274             :         this->nl_requests_.push_back(request);
     275             :     }
     276             : 
     277             :     log.printf("\n%s\n", metadata->print().c_str());
     278             :     // add the model references to PLUMED citation handling mechanism
     279             :     for (const auto& it: metadata->references) {
     280             :         for (const auto& ref: it.value()) {
     281             :             this->cite(ref);
     282             :         }
     283             :     }
     284             : 
     285             :     // parse the atomic types from the input file
     286             :     std::vector<int32_t> atomic_types;
     287             :     std::vector<int32_t> species_to_types;
     288             :     this->parseVector("SPECIES_TO_TYPES", species_to_types);
     289             :     bool has_custom_types = !species_to_types.empty();
     290             : 
     291             :     std::vector<AtomNumber> all_atoms;
     292             :     this->parseAtomList("SPECIES", all_atoms);
     293             : 
     294             :     size_t n_species = 0;
     295             :     if (all_atoms.empty()) {
     296             :         std::vector<AtomNumber> t;
     297             :         int i = 0;
     298             :         while (true) {
     299             :             i += 1;
     300             :             this->parseAtomList("SPECIES", i, t);
     301             :             if (t.empty()) {
     302             :                 break;
     303             :             }
     304             : 
     305             :             int32_t type = i;
     306             :             if (has_custom_types) {
     307             :                 if (species_to_types.size() < static_cast<size_t>(i)) {
     308             :                     this->error(
     309             :                         "SPECIES_TO_TYPES is too small, it should have one entry "
     310             :                         "for each species (we have at least " + std::to_string(i) +
     311             :                         " species and " + std::to_string(species_to_types.size()) +
     312             :                         " entries in SPECIES_TO_TYPES)"
     313             :                     );
     314             :                 }
     315             : 
     316             :                 type = species_to_types[static_cast<size_t>(i - 1)];
     317             :             }
     318             : 
     319             :             log.printf("  atoms with type %d are: ", type);
     320             :             for(unsigned j=0; j<t.size(); j++) {
     321             :                 log.printf("%d ", t[j]);
     322             :                 all_atoms.push_back(t[j]);
     323             :                 atomic_types.push_back(type);
     324             :             }
     325             :             log.printf("\n"); t.resize(0);
     326             : 
     327             :             n_species += 1;
     328             :         }
     329             :     } else {
     330             :         n_species = 1;
     331             : 
     332             :         int32_t type = 1;
     333             :         if (has_custom_types) {
     334             :             type = species_to_types[0];
     335             :         }
     336             :         atomic_types.resize(all_atoms.size(), type);
     337             :     }
     338             : 
     339             :     if (has_custom_types && species_to_types.size() != n_species) {
     340             :         this->warning(
     341             :             "SPECIES_TO_TYPES contains more entries (" +
     342             :             std::to_string(species_to_types.size()) +
     343             :             ") than there were species (" + std::to_string(n_species) + ")"
     344             :         );
     345             :     }
     346             : 
     347             :     this->atomic_types_ = torch::tensor(std::move(atomic_types));
     348             :     this->requestAtoms(all_atoms);
     349             : 
     350             :     this->check_consistency_ = false;
     351             :     this->parseFlag("CHECK_CONSISTENCY", this->check_consistency_);
     352             :     if (this->check_consistency_) {
     353             :         log.printf("  checking for internal consistency of the model\n");
     354             :     }
     355             : 
     356             :     // create evaluation options for the model. These won't change during the
     357             :     // simulation, so we initialize them once here.
     358             :     evaluations_options_ = torch::make_intrusive<metatensor_torch::ModelEvaluationOptionsHolder>();
     359             :     evaluations_options_->set_length_unit(getUnits().getLengthString());
     360             : 
     361             :     auto outputs = this->capabilities_->outputs();
     362             :     if (outputs.contains("features")) {
     363             :         this->model_output_ = "features";
     364             :     }
     365             : 
     366             :     if (outputs.contains("plumed::cv")) {
     367             :         if (outputs.contains("features")) {
     368             :             this->warning(
     369             :                 "this model exposes both 'features' and 'plumed::cv' outputs, "
     370             :                 "we will use 'features'. 'plumed::cv' is deprecated, please "
     371             :                 "remove it from your models"
     372             :             );
     373             :         } else {
     374             :             this->warning(
     375             :                 "this model is using 'plumed::cv' output, which is deprecated. "
     376             :                 "Please replace it with a 'features' output"
     377             :             );
     378             :             this->model_output_ = "plumed::cv";
     379             :         }
     380             :     }
     381             : 
     382             : 
     383             :     if (this->model_output_.empty()) {
     384             :         auto existing_outputs = std::vector<std::string>();
     385             :         for (const auto& it: this->capabilities_->outputs()) {
     386             :             existing_outputs.push_back(it.key());
     387             :         }
     388             : 
     389             :         this->error(
     390             :             "expected 'features' or 'plumed::cv' in the capabilities of the model, "
     391             :             "could not find it. the following outputs exist: " + torch::str(existing_outputs)
     392             :         );
     393             :     }
     394             : 
     395             :     auto output = torch::make_intrusive<metatensor_torch::ModelOutputHolder>();
     396             :     // this output has no quantity or unit to set
     397             : 
     398             :     output->per_atom = this->capabilities_->outputs().at(this->model_output_)->per_atom;
     399             :     // we are using torch autograd system to compute gradients,
     400             :     // so we don't need any explicit gradients.
     401             :     output->explicit_gradients = {};
     402             :     evaluations_options_->outputs.insert(this->model_output_, output);
     403             : 
     404             :     // Determine which device we should use based on user input, what the model
     405             :     // supports and what's available
     406             :     auto available_devices = std::vector<torch::Device>();
     407             :     for (const auto& device: this->capabilities_->supported_devices) {
     408             :         if (device == "cpu") {
     409             :             available_devices.push_back(torch::kCPU);
     410             :         } else if (device == "cuda") {
     411             :             if (torch::cuda::is_available()) {
     412             :                 available_devices.push_back(torch::Device("cuda"));
     413             :             }
     414             :         } else if (device == "mps") {
     415             :             #if TORCH_VERSION_MAJOR >= 2
     416             :             if (torch::mps::is_available()) {
     417             :                 available_devices.push_back(torch::Device("mps"));
     418             :             }
     419             :             #endif
     420             :         } else {
     421             :             this->warning(
     422             :                 "the model declared support for unknown device '" + device +
     423             :                 "', it will be ignored"
     424             :             );
     425             :         }
     426             :     }
     427             : 
     428             :     if (available_devices.empty()) {
     429             :         this->error(
     430             :             "failed to find a valid device for the model at '" + model_path + "': "
     431             :             "the model supports " + torch::str(this->capabilities_->supported_devices) +
     432             :             ", none of these were available"
     433             :         );
     434             :     }
     435             : 
     436             :     std::string requested_device;
     437             :     this->parse("DEVICE", requested_device);
     438             :     if (requested_device.empty()) {
     439             :         // no user request, pick the device the model prefers
     440             :         this->device_ = available_devices[0];
     441             :     } else {
     442             :         bool found_requested_device = false;
     443             :         for (const auto& device: available_devices) {
     444             :             if (device.is_cpu() && requested_device == "cpu") {
     445             :                 this->device_ = device;
     446             :                 found_requested_device = true;
     447             :                 break;
     448             :             } else if (device.is_cuda() && requested_device == "cuda") {
     449             :                 this->device_ = device;
     450             :                 found_requested_device = true;
     451             :                 break;
     452             :             } else if (device.is_mps() && requested_device == "mps") {
     453             :                 this->device_ = device;
     454             :                 found_requested_device = true;
     455             :                 break;
     456             :             }
     457             :         }
     458             : 
     459             :         if (!found_requested_device) {
     460             :             this->error(
     461             :                 "failed to find requested device (" + requested_device + "): it is either "
     462             :                 "not supported by this model or not available on this machine"
     463             :             );
     464             :         }
     465             :     }
     466             : 
     467             :     this->model_.to(this->device_);
     468             :     this->atomic_types_ = this->atomic_types_.to(this->device_);
     469             : 
     470             :     log.printf(
     471             :         "  running model on %s device with %s data\n",
     472             :         this->device_.str().c_str(),
     473             :         this->capabilities_->dtype().c_str()
     474             :     );
     475             : 
     476             :     if (this->capabilities_->dtype() == "float64") {
     477             :         this->dtype_ = torch::kFloat64;
     478             :     } else if (this->capabilities_->dtype() == "float32") {
     479             :         this->dtype_ = torch::kFloat32;
     480             :     } else {
     481             :         this->error(
     482             :             "the model requested an unsupported dtype '" + this->capabilities_->dtype() + "'"
     483             :         );
     484             :     }
     485             : 
     486             :     auto tensor_options = torch::TensorOptions().dtype(this->dtype_).device(this->device_);
     487             :     this->strain_ = torch::eye(3, tensor_options.requires_grad(true));
     488             : 
     489             :     // determine how many properties there will be in the output by running the
     490             :     // model once on a dummy system
     491             :     auto dummy_system = torch::make_intrusive<metatensor_torch::SystemHolder>(
     492             :         /*types = */ torch::zeros({0}, tensor_options.dtype(torch::kInt32)),
     493             :         /*positions = */ torch::zeros({0, 3}, tensor_options),
     494             :         /*cell = */ torch::zeros({3, 3}, tensor_options),
     495             :         /*pbc = */ torch::zeros({3}, tensor_options.dtype(torch::kBool))
     496             :     );
     497             : 
     498             :     log.printf("  the following neighbor lists have been requested:\n");
     499             :     auto length_unit = this->getUnits().getLengthString();
     500             :     auto model_length_unit = this->capabilities_->length_unit();
     501             :     for (auto request: this->nl_requests_) {
     502             :         log.printf("    - %s list, %g %s cutoff (requested %g %s)\n",
     503             :             request->full_list() ? "full" : "half",
     504             :             request->engine_cutoff(length_unit),
     505             :             length_unit.c_str(),
     506             :             request->cutoff(),
     507             :             model_length_unit.c_str()
     508             :         );
     509             : 
     510             :         auto neighbors = this->computeNeighbors(
     511             :             request,
     512             :             {PLMD::Vector(0, 0, 0)},
     513             :             PLMD::Tensor(0, 0, 0, 0, 0, 0, 0, 0, 0),
     514             :             false
     515             :         );
     516             :         metatensor_torch::register_autograd_neighbors(dummy_system, neighbors, this->check_consistency_);
     517             :         dummy_system->add_neighbor_list(request, neighbors);
     518             :     }
     519             : 
     520             :     this->n_properties_ = static_cast<unsigned>(
     521             :         this->executeModel(dummy_system)->properties()->count()
     522             :     );
     523             : 
     524             :     // parse and handle atom sub-selection. This is done AFTER determining the
     525             :     // output size, since the selection might not be valid for the dummy system
     526             :     std::vector<int32_t> selected_atoms;
     527             :     this->parseVector("SELECTED_ATOMS", selected_atoms);
     528             :     if (!selected_atoms.empty()) {
     529             :         auto selection_value = torch::zeros(
     530             :             {static_cast<int64_t>(selected_atoms.size()), 2},
     531             :             torch::TensorOptions().dtype(torch::kInt32).device(this->device_)
     532             :         );
     533             : 
     534             :         for (unsigned i=0; i<selected_atoms.size(); i++) {
     535             :             auto n_atoms = static_cast<int32_t>(this->atomic_types_.size(0));
     536             :             if (selected_atoms[i] <= 0 || selected_atoms[i] > n_atoms) {
     537             :                 this->error(
     538             :                     "Values in metatensor's SELECTED_ATOMS should be between 1 "
     539             :                     "and the number of atoms (" + std::to_string(n_atoms) + "), "
     540             :                     "got " + std::to_string(selected_atoms[i]));
     541             :             }
     542             :             // PLUMED input uses 1-based indexes, but metatensor wants 0-based
     543             :             selection_value[i][1] = selected_atoms[i] - 1;
     544             :         }
     545             : 
     546             :         evaluations_options_->set_selected_atoms(
     547             :             torch::make_intrusive<metatensor_torch::LabelsHolder>(
     548             :                 std::vector<std::string>{"system", "atom"}, selection_value
     549             :             )
     550             :         );
     551             :     }
     552             : 
      553             :     // Now that we know both n_samples and n_properties, we can set up the
     554             :     // PLUMED-side storage for the computed CV
     555             :     if (output->per_atom) {
     556             :         if (selected_atoms.empty()) {
     557             :             this->n_samples_ = static_cast<unsigned>(this->atomic_types_.size(0));
     558             :         } else {
     559             :             this->n_samples_ = static_cast<unsigned>(selected_atoms.size());
     560             :         }
     561             :     } else {
     562             :         this->n_samples_ = 1;
     563             :     }
     564             : 
     565             :     if (n_samples_ == 1 && n_properties_ == 1) {
     566             :         log.printf("  the output of this model is a scalar\n");
     567             : 
     568             :         this->addValue();
     569             :     } else if (n_samples_ == 1) {
      570             :         log.printf("  the output of this model is a 1x%d vector\n", n_properties_);
     571             : 
     572             :         this->addValue({this->n_properties_});
     573             :         this->getPntrToComponent(0)->buildDataStore();
     574             :     } else if (n_properties_ == 1) {
      575             :         log.printf("  the output of this model is a %dx1 vector\n", n_samples_);
     576             : 
     577             :         this->addValue({this->n_samples_});
     578             :         this->getPntrToComponent(0)->buildDataStore();
     579             :     } else {
     580             :         log.printf("  the output of this model is a %dx%d matrix\n", n_samples_, n_properties_);
     581             : 
     582             :         this->addValue({this->n_samples_, this->n_properties_});
     583             :         this->getPntrToComponent(0)->buildDataStore();
     584             :         this->getPntrToComponent(0)->reshapeMatrixStore(n_properties_);
     585             :     }
     586             : 
     587             :     this->setNotPeriodic();
     588             : }
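The scalar/vector/matrix dispatch above can be summarized in a small standalone helper. This is a sketch for illustration only; `output_kind` is a name made up here and is not part of the module:

```cpp
#include <cassert>
#include <string>

// Classify the PLUMED value created for a model output of shape
// (n_samples x n_properties), mirroring the dispatch above: a 1x1
// output becomes a scalar, a 1xP or Nx1 output a vector, and an
// NxP output a matrix.
std::string output_kind(unsigned n_samples, unsigned n_properties) {
    if (n_samples == 1 && n_properties == 1) {
        return "scalar";
    } else if (n_samples == 1 || n_properties == 1) {
        return "vector";
    } else {
        return "matrix";
    }
}
```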
     589             : 
     590             : unsigned MetatensorPlumedAction::getNumberOfDerivatives() {
     591             :     // gradients w.r.t. positions (3 x N values) + gradients w.r.t. strain (9 values)
     592             :     return 3 * this->getNumberOfAtoms() + 9;
     593             : }
     594             : 
     595             : 
     596             : void MetatensorPlumedAction::createSystem() {
     597             :     if (this->getTotAtoms() != static_cast<unsigned>(this->atomic_types_.size(0))) {
     598             :         std::ostringstream oss;
     599             :         oss << "METATENSOR action needs to know about all atoms in the system. ";
     600             :         oss << "There are " << this->getTotAtoms() << " atoms overall, ";
     601             :         oss << "but we only have atomic types for " << this->atomic_types_.size(0) << " of them.";
     602             :         plumed_merror(oss.str());
     603             :     }
     604             : 
     606             : 
     607             :     const auto& cell = this->getPbc().getBox();
     608             : 
     609             :     auto cpu_f64_tensor = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU);
     610             :     auto torch_cell = torch::zeros({3, 3}, cpu_f64_tensor);
     611             : 
     612             :     torch_cell[0][0] = cell(0, 0);
     613             :     torch_cell[0][1] = cell(0, 1);
     614             :     torch_cell[0][2] = cell(0, 2);
     615             : 
     616             :     torch_cell[1][0] = cell(1, 0);
     617             :     torch_cell[1][1] = cell(1, 1);
     618             :     torch_cell[1][2] = cell(1, 2);
     619             : 
     620             :     torch_cell[2][0] = cell(2, 0);
     621             :     torch_cell[2][1] = cell(2, 1);
     622             :     torch_cell[2][2] = cell(2, 2);
     623             : 
     624             :     using torch::indexing::Slice;
     625             : 
     626             :     auto pbc_a = torch_cell.index({0, Slice()}).norm().abs().item<double>() > 1e-9;
     627             :     auto pbc_b = torch_cell.index({1, Slice()}).norm().abs().item<double>() > 1e-9;
     628             :     auto pbc_c = torch_cell.index({2, Slice()}).norm().abs().item<double>() > 1e-9;
     629             :     auto torch_pbc = torch::tensor({pbc_a, pbc_b, pbc_c});
     630             : 
     631             :     const auto& positions = this->getPositions();
     632             : 
     633             :     auto torch_positions = torch::from_blob(
     634             :         const_cast<PLMD::Vector*>(positions.data()),
     635             :         {static_cast<int64_t>(positions.size()), 3},
     636             :         cpu_f64_tensor
     637             :     );
     638             : 
     639             :     torch_positions = torch_positions.to(this->dtype_).to(this->device_);
     640             :     torch_cell = torch_cell.to(this->dtype_).to(this->device_);
     641             :     torch_pbc = torch_pbc.to(this->device_);
     642             : 
      643             :     // set up torch's automatic gradient tracking
     644             :     if (!this->doNotCalculateDerivatives()) {
     645             :         torch_positions.requires_grad_(true);
     646             : 
     647             :         // pretend to scale positions/cell by the strain so that it enters the
     648             :         // computational graph.
     649             :         torch_positions = torch_positions.matmul(this->strain_);
     650             :         torch_positions.retain_grad();
     651             : 
     652             :         torch_cell = torch_cell.matmul(this->strain_);
     653             :     }
     654             : 
     655             :     this->system_ = torch::make_intrusive<metatensor_torch::SystemHolder>(
     656             :         this->atomic_types_,
     657             :         torch_positions,
     658             :         torch_cell,
     659             :         torch_pbc
     660             :     );
     661             : 
     662             :     auto periodic = torch::all(torch_pbc).item<bool>();
     663             :     if (!periodic && torch::any(torch_pbc).item<bool>()) {
     664             :         std::string periodic_directions;
     665             :         std::string non_periodic_directions;
     666             :         if (pbc_a) {
     667             :             periodic_directions += "A";
     668             :         } else {
     669             :             non_periodic_directions += "A";
     670             :         }
     671             : 
     672             :         if (pbc_b) {
     673             :             periodic_directions += "B";
     674             :         } else {
     675             :             non_periodic_directions += "B";
     676             :         }
     677             : 
     678             :         if (pbc_c) {
     679             :             periodic_directions += "C";
     680             :         } else {
     681             :             non_periodic_directions += "C";
     682             :         }
     683             : 
     684             :         plumed_merror(
     685             :             "mixed periodic boundary conditions are not supported, this system "
     686             :             "is periodic along the " + periodic_directions + " cell vector(s), "
     687             :             "but not along the " + non_periodic_directions + " cell vector(s)."
     688             :         );
     689             :     }
     690             : 
     691             :     // compute the neighbors list requested by the model, and register them with
     692             :     // the system
     693             :     for (auto request: this->nl_requests_) {
     694             :         auto neighbors = this->computeNeighbors(request, positions, cell, periodic);
     695             :         metatensor_torch::register_autograd_neighbors(this->system_, neighbors, this->check_consistency_);
     696             :         this->system_->add_neighbor_list(request, neighbors);
     697             :     }
     698             : }
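The strain trick used in `createSystem()` can be illustrated without libtorch. The sketch below only demonstrates the numerical identity property (multiplying positions by an identity strain leaves them unchanged); in the real code the matmul additionally inserts the strain tensor into torch's autograd graph, so that `d(output)/d(strain)` later yields the virial. The type aliases and function name are inventions of this sketch:

```cpp
#include <array>
#include <cassert>
#include <cmath>

using Vec3 = std::array<double, 3>;
using Mat3 = std::array<std::array<double, 3>, 3>;

// Apply a 3x3 strain matrix to a position, treating the position as a
// row vector, matching `positions.matmul(strain)` in the code above.
Vec3 apply_strain(const Vec3& r, const Mat3& strain) {
    Vec3 out{0.0, 0.0, 0.0};
    for (int i = 0; i < 3; i++) {
        for (int j = 0; j < 3; j++) {
            out[j] += r[i] * strain[i][j];
        }
    }
    return out;
}

// The strain starts as the identity, so positions are numerically unchanged.
const Mat3 IDENTITY = {{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}};
```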
     699             : 
     700             : 
     701             : metatensor_torch::TensorBlock MetatensorPlumedAction::computeNeighbors(
     702             :     metatensor_torch::NeighborListOptions request,
     703             :     const std::vector<PLMD::Vector>& positions,
     704             :     const PLMD::Tensor& cell,
     705             :     bool periodic
     706             : ) {
     707             :     auto labels_options = torch::TensorOptions().dtype(torch::kInt32).device(this->device_);
     708             :     auto neighbor_component = torch::make_intrusive<metatensor_torch::LabelsHolder>(
     709             :         "xyz",
     710             :         torch::tensor({0, 1, 2}, labels_options).reshape({3, 1})
     711             :     );
     712             :     auto neighbor_properties = torch::make_intrusive<metatensor_torch::LabelsHolder>(
     713             :         "distance", torch::zeros({1, 1}, labels_options)
     714             :     );
     715             : 
     716             :     auto cutoff = request->engine_cutoff(this->getUnits().getLengthString());
     717             : 
     718             :     // use https://github.com/Luthaf/vesin to compute the requested neighbor
      719             :     // lists, since we cannot get these from PLUMED
     720             :     vesin::VesinOptions options;
     721             :     options.cutoff = cutoff;
     722             :     options.full = request->full_list();
     723             :     options.return_shifts = true;
     724             :     options.return_distances = false;
     725             :     options.return_vectors = true;
     726             : 
     727             :     vesin::VesinNeighborList* vesin_neighbor_list = new vesin::VesinNeighborList();
     728             :     memset(vesin_neighbor_list, 0, sizeof(vesin::VesinNeighborList));
     729             : 
     730             :     const char* error_message = NULL;
     731             :     int status = vesin_neighbors(
     732             :         reinterpret_cast<const double (*)[3]>(positions.data()),
     733             :         positions.size(),
     734             :         reinterpret_cast<const double (*)[3]>(&cell(0, 0)),
     735             :         periodic,
     736             :         vesin::VesinCPU,
     737             :         options,
     738             :         vesin_neighbor_list,
     739             :         &error_message
     740             :     );
     741             : 
     742             :     if (status != EXIT_SUCCESS) {
     743             :         plumed_merror(
     744             :             "failed to compute neighbor list (cutoff=" + std::to_string(cutoff) +
     745             :             ", full=" + (request->full_list() ? "true" : "false") + "): " + error_message
     746             :         );
     747             :     }
     748             : 
     749             :     // transform from vesin to metatensor format
     750             :     auto n_pairs = static_cast<int64_t>(vesin_neighbor_list->length);
     751             : 
     752             :     auto pair_vectors = torch::from_blob(
     753             :         vesin_neighbor_list->vectors,
     754             :         {n_pairs, 3, 1},
     755             :         /*deleter*/ [=](void*) {
     756             :             vesin_free(vesin_neighbor_list);
     757             :             delete vesin_neighbor_list;
     758             :         },
     759             :         torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU)
     760             :     );
     761             : 
     762             :     auto pair_samples_values = torch::zeros({n_pairs, 5}, labels_options.device(torch::kCPU));
      763             :     for (int64_t i=0; i<n_pairs; i++) {
     764             :         pair_samples_values[i][0] = static_cast<int32_t>(vesin_neighbor_list->pairs[i][0]);
     765             :         pair_samples_values[i][1] = static_cast<int32_t>(vesin_neighbor_list->pairs[i][1]);
     766             :         pair_samples_values[i][2] = vesin_neighbor_list->shifts[i][0];
     767             :         pair_samples_values[i][3] = vesin_neighbor_list->shifts[i][1];
     768             :         pair_samples_values[i][4] = vesin_neighbor_list->shifts[i][2];
     769             :     }
     770             : 
     771             :     auto neighbor_samples = torch::make_intrusive<metatensor_torch::LabelsHolder>(
     772             :         std::vector<std::string>{"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"},
     773             :         pair_samples_values.to(this->device_)
     774             :     );
     775             : 
     776             :     auto neighbors = torch::make_intrusive<metatensor_torch::TensorBlockHolder>(
     777             :         pair_vectors.to(this->dtype_).to(this->device_),
     778             :         neighbor_samples,
     779             :         std::vector<metatensor_torch::Labels>{neighbor_component},
     780             :         neighbor_properties
     781             :     );
     782             : 
     783             :     return neighbors;
     784             : }
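The role vesin plays in the function above can be sketched with a brute-force reference implementation. This covers only the non-periodic case (no cell shifts) and is O(N^2); the `Pair` struct and function name are made up for this sketch, not part of vesin's API:

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <vector>

struct Pair {
    size_t first;
    size_t second;
    double distance;
};

// Compute a "half" neighbor list: every pair (i, j) with i < j whose
// distance is below the cutoff, stored once. A "full" list would also
// contain the swapped (j, i) entry for every pair.
std::vector<Pair> half_neighbor_list(
    const std::vector<std::array<double, 3>>& positions,
    double cutoff
) {
    std::vector<Pair> pairs;
    for (size_t i = 0; i < positions.size(); i++) {
        for (size_t j = i + 1; j < positions.size(); j++) {
            double d2 = 0.0;
            for (int k = 0; k < 3; k++) {
                double diff = positions[j][k] - positions[i][k];
                d2 += diff * diff;
            }
            if (d2 < cutoff * cutoff) {
                pairs.push_back({i, j, std::sqrt(d2)});
            }
        }
    }
    return pairs;
}
```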
     785             : 
     786             : metatensor_torch::TensorBlock MetatensorPlumedAction::executeModel(metatensor_torch::System system) {
     787             :     try {
     788             :         auto ivalue_output = this->model_.forward({
     789             :             std::vector<metatensor_torch::System>{system},
     790             :             evaluations_options_,
     791             :             this->check_consistency_,
     792             :         });
     793             : 
     794             :         auto dict_output = ivalue_output.toGenericDict();
     795             :         auto cv = dict_output.at(this->model_output_);
     796             :         this->output_ = cv.toCustomClass<metatensor_torch::TensorMapHolder>();
     797             :     } catch (const std::exception& e) {
     798             :         plumed_merror("failed to evaluate the model: " + std::string(e.what()));
     799             :     }
     800             : 
     801             :     plumed_massert(this->output_->keys()->count() == 1, "output should have a single block");
     802             :     auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0);
     803             :     plumed_massert(block->components().empty(), "components are not yet supported in the output");
     804             : 
     805             :     return block;
     806             : }
     807             : 
     808             : 
     809             : void MetatensorPlumedAction::calculate() {
     810             :     this->createSystem();
     811             : 
     812             :     auto block = this->executeModel(this->system_);
     813             :     auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64);
     814             : 
     815             :     if (static_cast<unsigned>(torch_values.size(0)) != this->n_samples_) {
     816             :         plumed_merror(
     817             :             "expected the model to return a TensorBlock with " +
     818             :             std::to_string(this->n_samples_) + " samples, got " +
     819             :             std::to_string(torch_values.size(0)) + " instead"
     820             :         );
     821             :     } else if (static_cast<unsigned>(torch_values.size(1)) != this->n_properties_) {
     822             :         plumed_merror(
     823             :             "expected the model to return a TensorBlock with " +
     824             :             std::to_string(this->n_properties_) + " properties, got " +
     825             :             std::to_string(torch_values.size(1)) + " instead"
     826             :         );
     827             :     }
     828             : 
     829             :     Value* value = this->getPntrToComponent(0);
      830             :     // copy the data returned by the model into the PLUMED `Value`
     831             :     if (n_samples_ == 1) {
     832             :         if (n_properties_ == 1) {
     833             :             value->set(torch_values.item<double>());
     834             :         } else {
      835             :             // we have multiple CVs describing a single thing (atom or full system)
     836             :             for (unsigned i=0; i<n_properties_; i++) {
     837             :                 value->set(i, torch_values[0][i].item<double>());
     838             :             }
     839             :         }
     840             :     } else {
     841             :         auto samples = block->samples();
     842             :         plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));
     843             : 
     844             :         auto samples_values = samples->values().to(torch::kCPU);
     845             :         auto selected_atoms = this->evaluations_options_->get_selected_atoms();
     846             : 
     847             :         // handle the possibility that samples are returned in
     848             :         // a non-sorted order.
     849             :         auto get_output_location = [&](unsigned i) {
     850             :             if (selected_atoms.has_value()) {
     851             :                 // If the users picked some selected atoms, then we store the
     852             :                 // output in the same order as the selection was given
     853             :                 auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
     854             :                 auto position = selected_atoms.value()->position(sample);
     855             :                 plumed_assert(position.has_value());
     856             :                 return static_cast<unsigned>(position.value());
     857             :             } else {
     858             :                 return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
     859             :             }
     860             :         };
     861             : 
     862             :         if (n_properties_ == 1) {
     863             :             // we have a single CV describing multiple things (i.e. atoms)
     864             :             for (unsigned i=0; i<n_samples_; i++) {
     865             :                 auto output_i = get_output_location(i);
     866             :                 value->set(output_i, torch_values[i][0].item<double>());
     867             :             }
     868             :         } else {
     869             :             // the CV is a matrix
     870             :             for (unsigned i=0; i<n_samples_; i++) {
     871             :                 auto output_i = get_output_location(i);
     872             :                 for (unsigned j=0; j<n_properties_; j++) {
     873             :                     value->set(output_i * n_properties_ + j, torch_values[i][j].item<double>());
     874             :                 }
     875             :             }
     876             :         }
     877             :     }
     878             : }
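The `get_output_location` logic in `calculate()` can be sketched in isolation. The real code looks up a (system, atom) sample row inside a metatensor `Labels` object; the simplified version below collapses the sample to a bare atom index and uses an empty vector to stand in for "no SELECTED_ATOMS". Names here are inventions of this sketch:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Map an atom index from a returned sample to a stable position in the
// PLUMED output. With SELECTED_ATOMS, the position is the atom's rank in
// the (0-based) selection, so outputs follow the order the user gave;
// otherwise the atom index itself is the position.
unsigned output_location(int32_t atom, const std::vector<int32_t>& selection) {
    if (selection.empty()) {
        return static_cast<unsigned>(atom);
    }
    for (unsigned pos = 0; pos < selection.size(); pos++) {
        if (selection[pos] == atom) {
            return pos;
        }
    }
    return static_cast<unsigned>(-1); // atom not part of the selection
}
```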
     879             : 
     880             : 
     881             : void MetatensorPlumedAction::apply() {
     882             :     const auto* value = this->getPntrToComponent(0);
     883             :     if (!value->forcesWereAdded()) {
     884             :         return;
     885             :     }
     886             : 
     887             :     auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0);
     888             :     auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64);
     889             : 
     890             :     if (!torch_values.requires_grad()) {
     891             :         this->warning(
      892             :             "the output of the model does not require gradients, this might "
     893             :             "indicate a problem"
     894             :         );
     895             :         return;
     896             :     }
     897             : 
     898             :     auto output_grad = torch::zeros_like(torch_values);
     899             :     if (n_samples_ == 1) {
     900             :         if (n_properties_ == 1) {
     901             :             output_grad[0][0] = value->getForce();
     902             :         } else {
     903             :             for (unsigned i=0; i<n_properties_; i++) {
     904             :                 output_grad[0][i] = value->getForce(i);
     905             :             }
     906             :         }
     907             :     } else {
     908             :         auto samples = block->samples();
     909             :         plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));
     910             : 
     911             :         auto samples_values = samples->values().to(torch::kCPU);
     912             :         auto selected_atoms = this->evaluations_options_->get_selected_atoms();
     913             : 
     914             :         // see above for an explanation of why we use this function
     915             :         auto get_output_location = [&](unsigned i) {
     916             :             if (selected_atoms.has_value()) {
     917             :                 auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
     918             :                 auto position = selected_atoms.value()->position(sample);
     919             :                 plumed_assert(position.has_value());
     920             :                 return static_cast<unsigned>(position.value());
     921             :             } else {
     922             :                 return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
     923             :             }
     924             :         };
     925             : 
     926             :         if (n_properties_ == 1) {
     927             :             for (unsigned i=0; i<n_samples_; i++) {
     928             :                 auto output_i = get_output_location(i);
     929             :                 output_grad[i][0] = value->getForce(output_i);
     930             :             }
     931             :         } else {
     932             :             for (unsigned i=0; i<n_samples_; i++) {
     933             :                 auto output_i = get_output_location(i);
     934             :                 for (unsigned j=0; j<n_properties_; j++) {
     935             :                     output_grad[i][j] = value->getForce(output_i * n_properties_ + j);
     936             :                 }
     937             :             }
     938             :         }
     939             :     }
     940             : 
     941             :     this->system_->positions().mutable_grad() = torch::Tensor();
     942             :     this->strain_.mutable_grad() = torch::Tensor();
     943             : 
     944             :     torch_values.backward(output_grad);
     945             :     auto positions_grad = this->system_->positions().grad();
     946             :     auto strain_grad = this->strain_.grad();
     947             : 
     948             :     positions_grad = positions_grad.to(torch::kCPU).to(torch::kFloat64);
     949             :     strain_grad = strain_grad.to(torch::kCPU).to(torch::kFloat64);
     950             : 
     951             :     plumed_assert(positions_grad.sizes().size() == 2);
     952             :     plumed_assert(positions_grad.is_contiguous());
     953             : 
     954             :     plumed_assert(strain_grad.sizes().size() == 2);
     955             :     plumed_assert(strain_grad.is_contiguous());
     956             : 
     957             :     auto derivatives = std::vector<double>(
     958             :         positions_grad.data_ptr<double>(),
     959             :         positions_grad.data_ptr<double>() + 3 * this->system_->size()
     960             :     );
     961             : 
     962             :     // add virials to the derivatives
     963             :     derivatives.push_back(-strain_grad[0][0].item<double>());
     964             :     derivatives.push_back(-strain_grad[0][1].item<double>());
     965             :     derivatives.push_back(-strain_grad[0][2].item<double>());
     966             : 
     967             :     derivatives.push_back(-strain_grad[1][0].item<double>());
     968             :     derivatives.push_back(-strain_grad[1][1].item<double>());
     969             :     derivatives.push_back(-strain_grad[1][2].item<double>());
     970             : 
     971             :     derivatives.push_back(-strain_grad[2][0].item<double>());
     972             :     derivatives.push_back(-strain_grad[2][1].item<double>());
     973             :     derivatives.push_back(-strain_grad[2][2].item<double>());
     974             : 
     975             :     unsigned index = 0;
     976             :     this->setForcesOnAtoms(derivatives, index);
     977             : }
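The layout of the derivative vector handed to `setForcesOnAtoms()` at the end of `apply()` can be captured in a small helper: 3*N position gradients in row-major order, followed by the 9 components of the virial, which is minus the strain gradient. `pack_derivatives` is a name made up for this sketch:

```cpp
#include <array>
#include <cassert>
#include <vector>

using Mat3 = std::array<std::array<double, 3>, 3>;

// Pack d(CV)/d(position) (flattened, 3*N entries) followed by the virial,
// obtained by negating and flattening the 3x3 strain gradient, matching
// the push_back sequence in apply().
std::vector<double> pack_derivatives(
    const std::vector<double>& positions_grad,
    const Mat3& strain_grad
) {
    std::vector<double> derivatives = positions_grad;
    for (int i = 0; i < 3; i++) {
        for (int j = 0; j < 3; j++) {
            derivatives.push_back(-strain_grad[i][j]);
        }
    }
    return derivatives;
}
```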
     978             : 
     979             : } // namespace metatensor
     980             : } // namespace PLMD
     981             : 
     982             : 
     983             : #endif
     984             : 
     985             : 
     986             : namespace PLMD {
     987             : namespace metatensor {
     988             : 
     989             : // use the same implementation for both the actual action and the dummy one
     990             : // (when libtorch and libmetatensor could not be found).
     991           2 : void MetatensorPlumedAction::registerKeywords(Keywords& keys) {
     992           2 :     Action::registerKeywords(keys);
     993           2 :     ActionAtomistic::registerKeywords(keys);
     994           2 :     ActionWithValue::registerKeywords(keys);
     995             : 
     996           2 :     keys.add("compulsory", "MODEL", "path to the exported metatensor model");
     997           2 :     keys.add("optional", "EXTENSIONS_DIRECTORY", "path to the directory containing TorchScript extensions to load");
     998           2 :     keys.add("optional", "DEVICE", "Torch device to use for the calculation");
     999             : 
     1000           2 :     keys.addFlag("CHECK_CONSISTENCY", false, "Should we enable internal consistency checks of the model");
    1001             : 
    1002           2 :     keys.add("numbered", "SPECIES", "the atoms in each PLUMED species");
    1003           4 :     keys.reset_style("SPECIES", "atoms");
    1004             : 
    1005           2 :     keys.add("optional", "SELECTED_ATOMS", "subset of atoms that should be used for the calculation");
    1006           4 :     keys.reset_style("SELECTED_ATOMS", "atoms");
    1007             : 
    1008           2 :     keys.add("optional", "SPECIES_TO_TYPES", "mapping from PLUMED SPECIES to metatensor's atomic types");
    1009             : 
    1010           4 :     keys.addOutputComponent("outputs", "default", "scalar", "collective variable created by the metatensor model");
    1011           4 :     keys.setValueDescription("scalar/vector/matrix","collective variable created by the metatensor model");
    1012           2 : }
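The keywords registered above translate into PLUMED input along these lines. This is a hypothetical example: the model file name, atom ranges, type mapping, and label are all invented for illustration, and the exact SPECIES_TO_TYPES format should be checked against the module documentation:

```
# load an exported metatensor model and expose its output as a CV
cv: METATENSOR ...
    MODEL=model.pt
    EXTENSIONS_DIRECTORY=extensions/
    DEVICE=cpu
    SPECIES1=1-32
    SPECIES2=33-64
    SPECIES_TO_TYPES=8,1
    SELECTED_ATOMS=1-10
...
```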
    1013             : 
    1014             : PLUMED_REGISTER_ACTION(MetatensorPlumedAction, "METATENSOR")
    1015             : 
    1016             : } // namespace metatensor
    1017             : } // namespace PLMD

Generated by: LCOV version 1.16