Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2024 The METATENSOR code team
3 : (see the PEOPLE-METATENSOR file at the root of this folder for a list of names)
4 :
5 : See https://docs.metatensor.org/latest/ for more information about the
6 : metatensor package that this module allows you to call from PLUMED.
7 :
8 : This file is part of METATENSOR-PLUMED module.
9 :
10 : The METATENSOR-PLUMED module is free software: you can redistribute it and/or modify
11 : it under the terms of the GNU Lesser General Public License as published by
12 : the Free Software Foundation, either version 3 of the License, or
13 : (at your option) any later version.
14 :
15 : The METATENSOR-PLUMED module is distributed in the hope that it will be useful,
16 : but WITHOUT ANY WARRANTY; without even the implied warranty of
17 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 : GNU Lesser General Public License for more details.
19 :
20 : You should have received a copy of the GNU Lesser General Public License
21 : along with the METATENSOR-PLUMED module. If not, see <http://www.gnu.org/licenses/>.
22 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
23 :
24 : #include "core/ActionAtomistic.h"
25 : #include "core/ActionWithValue.h"
26 : #include "core/ActionRegister.h"
27 : #include "core/PlumedMain.h"
28 :
29 : //+PLUMEDOC METATENSORMOD_COLVAR METATENSOR
30 : /*
31 : Use arbitrary machine learning models as collective variables.
32 :
33 : Note that this action requires the metatensor-torch library. Check the
34 : instructions in the \ref METATENSORMOD page to enable this module.
35 :
36 : This action enables the use of fully custom machine learning models — based on
37 : the [metatensor atomistic models][mts_models] interface — as collective
variables in PLUMED. Such machine learning models are typically written and
39 : customized using Python code, and then exported to run within PLUMED as
40 : [TorchScript], which is a subset of Python that can be executed by the C++ torch
41 : library.
42 :
43 : Metatensor offers a way to define such models and pass data from PLUMED (or any
44 : other simulation engine) to the model and back. For more information on how to
45 : define such model, have a look at the [corresponding tutorials][mts_tutorials],
46 : or at the code in `regtest/metatensor/`. Each of the Python scripts in this
47 : directory defines a custom machine learning CV that can be used with PLUMED.
48 :
49 : \par Examples
50 :
51 : The following input shows how you can call metatensor and evaluate the model
52 : that is described in the file `custom_cv.pt` from PLUMED.
53 :
\plumedfile
metatensor_cv: METATENSOR ...
    MODEL=custom_cv.pt
55 :
56 : SPECIES1=1-26
57 : SPECIES2=27-62
58 : SPECIES3=63-76
59 : SPECIES_TO_TYPES=6,1,8
60 : ...
61 : \endplumedfile
62 :
The numbered `SPECIES` labels are used to indicate the list of atoms that belong
to each atomic species in the system. The `SPECIES_TO_TYPES` keyword then
provides information on the atom type for each species. The first number here is
the atomic type of the atoms that have been specified using the `SPECIES1` flag,
the second number is the atomic type of the atoms that have been specified
using the `SPECIES2` flag and so on.
69 :
70 : `METATENSOR` action also accepts the following options:
71 :
72 : - `EXTENSIONS_DIRECTORY` should be the path to a directory containing
73 : TorchScript extensions (as shared libraries) that are required to load and
74 : execute the model. This matches the `collect_extensions` argument to
75 : `MetatensorAtomisticModel.export` in Python.
76 : - `CHECK_CONSISTENCY` can be used to enable internal consistency checks;
- `SELECTED_ATOMS` can be used to signal the metatensor model that it should
78 : only run its calculation for the selected subset of atoms. The model still
79 : need to know about all the atoms in the system (through the `SPECIES`
80 : keyword); but this can be used to reduce the calculation cost. Note that the
81 : indices of the selected atoms should start at 1 in the PLUMED input file, but
82 : they will be translated to start at 0 when given to the model (i.e. in
83 : Python/TorchScript, the `forward` method will receive a `selected_atoms` which
84 : starts at 0)
85 :
86 : Here is another example with all the possible keywords:
87 :
\plumedfile
soap: METATENSOR ...
    MODEL=soap.pt
    EXTENSIONS_DIRECTORY=extensions
89 : CHECK_CONSISTENCY
90 :
91 : SPECIES1=1-10
92 : SPECIES2=11-20
93 : SPECIES_TO_TYPES=8,13
94 :
95 : # only run the calculation for the Aluminium (type 13) atoms, but
96 : # include the Oxygen (type 8) as potential neighbors.
97 : SELECTED_ATOMS=11-20
98 : ...
99 : \endplumedfile
100 :
101 : \par Collective variables and metatensor models
102 :
103 : PLUMED can use the [`"features"` output][features_output] of metatensor
atomistic models as collective variables. Alternatively, the code also accepts
105 : an output named `"plumed::cv"`, with the same metadata structure as the
106 : `"features"` output.
107 :
108 : */ /*
109 :
110 : [TorchScript]: https://pytorch.org/docs/stable/jit.html
111 : [mts_models]: https://docs.metatensor.org/latest/atomistic/index.html
112 : [mts_tutorials]: https://docs.metatensor.org/latest/examples/atomistic/index.html
113 : [mts_block]: https://docs.metatensor.org/latest/torch/reference/block.html
114 : [features_output]: https://docs.metatensor.org/latest/examples/atomistic/outputs/features.html
115 : */
116 : //+ENDPLUMEDOC
117 :
118 : /*INDENT-OFF*/
119 : #if !defined(__PLUMED_HAS_METATENSOR) || !defined(__PLUMED_HAS_LIBTORCH)
120 :
namespace PLMD { namespace metatensor {

// Fallback stub compiled when PLUMED was configured without metatensor
// and/or libtorch: it keeps the METATENSOR action registered so users get a
// clear error message instead of an "unknown action" failure.
class MetatensorPlumedAction: public ActionAtomistic, public ActionWithValue {
public:
    static void registerKeywords(Keywords& keys);
    // Always throws: the required libraries are not available in this build.
    explicit MetatensorPlumedAction(const ActionOptions& options):
        Action(options),
        ActionAtomistic(options),
        ActionWithValue(options)
    {
        throw std::runtime_error(
            "Can not use metatensor action without the corresponding libraries. \n"
            "Make sure to configure with `--enable-metatensor --enable-libtorch` "
            "and that the corresponding libraries are found"
        );
    }

    // Never reached: the constructor above always throws first.
    void calculate() override {}
    void apply() override {}
    unsigned getNumberOfDerivatives() override {return 0;}
};

}} // namespace PLMD::metatensor
143 :
144 : #else
145 :
146 : #include <type_traits>
147 :
148 : #pragma GCC diagnostic push
149 : #pragma GCC diagnostic ignored "-Wpedantic"
150 : #pragma GCC diagnostic ignored "-Wunused-parameter"
151 : #pragma GCC diagnostic ignored "-Wfloat-equal"
152 : #pragma GCC diagnostic ignored "-Wfloat-conversion"
153 : #pragma GCC diagnostic ignored "-Wimplicit-float-conversion"
154 : #pragma GCC diagnostic ignored "-Wimplicit-int-conversion"
155 : #pragma GCC diagnostic ignored "-Wshorten-64-to-32"
156 : #pragma GCC diagnostic ignored "-Wsign-conversion"
157 : #pragma GCC diagnostic ignored "-Wold-style-cast"
158 :
159 : #include <torch/script.h>
160 : #include <torch/version.h>
161 : #include <torch/cuda.h>
162 : #if TORCH_VERSION_MAJOR >= 2
163 : #include <torch/mps.h>
164 : #endif
165 :
166 : #pragma GCC diagnostic pop
167 :
168 : #include <metatensor/torch.hpp>
169 : #include <metatensor/torch/atomistic.hpp>
170 :
171 : #include "vesin.h"
172 :
173 :
174 : namespace PLMD {
175 : namespace metatensor {
176 :
// We will cast Vector/Tensor to pointers to arrays and doubles, so let's make
// sure this is legal to do: the PLMD types must have standard layout and the
// exact size/alignment of the corresponding std::array types, otherwise the
// reinterpret_casts used below (e.g. when calling vesin or torch::from_blob)
// would be undefined behavior.
static_assert(std::is_standard_layout<PLMD::Vector>::value);
static_assert(sizeof(PLMD::Vector) == sizeof(std::array<double, 3>));
static_assert(alignof(PLMD::Vector) == alignof(std::array<double, 3>));

static_assert(std::is_standard_layout<PLMD::Tensor>::value);
static_assert(sizeof(PLMD::Tensor) == sizeof(std::array<std::array<double, 3>, 3>));
static_assert(alignof(PLMD::Tensor) == alignof(std::array<std::array<double, 3>, 3>));
186 :
// Main implementation: exposes an exported metatensor atomistic model
// (a TorchScript module) as a PLUMED collective variable.
class MetatensorPlumedAction: public ActionAtomistic, public ActionWithValue {
public:
    static void registerKeywords(Keywords& keys);
    explicit MetatensorPlumedAction(const ActionOptions&);

    void calculate() override;
    void apply() override;
    unsigned getNumberOfDerivatives() override;

private:
    // fill this->system_ according to the current PLUMED data
    void createSystem();
    // compute a neighbor list following metatensor format, using data from PLUMED
    metatensor_torch::TensorBlock computeNeighbors(
        metatensor_torch::NeighborListOptions request,
        const std::vector<PLMD::Vector>& positions,
        const PLMD::Tensor& cell,
        bool periodic
    );

    // execute the model for the given system
    metatensor_torch::TensorBlock executeModel(metatensor_torch::System system);

    // TorchScript model loaded from the MODEL keyword
    torch::jit::Module model_;

    // capabilities declared by the model (outputs, supported devices, dtype)
    metatensor_torch::ModelCapabilities capabilities_;
    // name of the model output used as the CV ("features" or legacy "plumed::cv")
    std::string model_output_;

    // neighbor lists requests made by the model
    std::vector<metatensor_torch::NeighborListOptions> nl_requests_;

    // dtype/device to use to execute the model
    torch::ScalarType dtype_;
    torch::Device device_;

    // per-atom types, built from the SPECIES/SPECIES_TO_TYPES keywords
    torch::Tensor atomic_types_;
    // store the strain to be able to compute the virial with autograd
    torch::Tensor strain_;

    // system rebuilt from the PLUMED data at every step
    metatensor_torch::System system_;
    metatensor_torch::ModelEvaluationOptions evaluations_options_;
    bool check_consistency_;

    // last output of the model
    metatensor_torch::TensorMap output_;
    // shape of the output of this model
    unsigned n_samples_;
    unsigned n_properties_;
};
235 :
236 :
237 : MetatensorPlumedAction::MetatensorPlumedAction(const ActionOptions& options):
238 : Action(options),
239 : ActionAtomistic(options),
240 : ActionWithValue(options),
241 : device_(torch::kCPU)
242 : {
243 : if (metatensor_torch::version().find("0.7.") != 0) {
244 : this->error(
245 : "this code requires version 0.7.x of metatensor-torch, got version " +
246 : metatensor_torch::version()
247 : );
248 : }
249 :
250 : // first, load the model
251 : std::string extensions_directory_str;
252 : this->parse("EXTENSIONS_DIRECTORY", extensions_directory_str);
253 :
254 : torch::optional<std::string> extensions_directory = torch::nullopt;
255 : if (!extensions_directory_str.empty()) {
256 : extensions_directory = std::move(extensions_directory_str);
257 : }
258 :
259 : std::string model_path;
260 : this->parse("MODEL", model_path);
261 :
262 : try {
263 : this->model_ = metatensor_torch::load_atomistic_model(model_path, extensions_directory);
264 : } catch (const std::exception& e) {
265 : this->error("failed to load model at '" + model_path + "': " + e.what());
266 : }
267 :
268 : // extract information from the model
269 : auto metadata = this->model_.run_method("metadata").toCustomClass<metatensor_torch::ModelMetadataHolder>();
270 : this->capabilities_ = this->model_.run_method("capabilities").toCustomClass<metatensor_torch::ModelCapabilitiesHolder>();
271 : auto requests_ivalue = this->model_.run_method("requested_neighbor_lists");
272 : for (auto request_ivalue: requests_ivalue.toList()) {
273 : auto request = request_ivalue.get().toCustomClass<metatensor_torch::NeighborListOptionsHolder>();
274 : this->nl_requests_.push_back(request);
275 : }
276 :
277 : log.printf("\n%s\n", metadata->print().c_str());
278 : // add the model references to PLUMED citation handling mechanism
279 : for (const auto& it: metadata->references) {
280 : for (const auto& ref: it.value()) {
281 : this->cite(ref);
282 : }
283 : }
284 :
285 : // parse the atomic types from the input file
286 : std::vector<int32_t> atomic_types;
287 : std::vector<int32_t> species_to_types;
288 : this->parseVector("SPECIES_TO_TYPES", species_to_types);
289 : bool has_custom_types = !species_to_types.empty();
290 :
291 : std::vector<AtomNumber> all_atoms;
292 : this->parseAtomList("SPECIES", all_atoms);
293 :
294 : size_t n_species = 0;
295 : if (all_atoms.empty()) {
296 : std::vector<AtomNumber> t;
297 : int i = 0;
298 : while (true) {
299 : i += 1;
300 : this->parseAtomList("SPECIES", i, t);
301 : if (t.empty()) {
302 : break;
303 : }
304 :
305 : int32_t type = i;
306 : if (has_custom_types) {
307 : if (species_to_types.size() < static_cast<size_t>(i)) {
308 : this->error(
309 : "SPECIES_TO_TYPES is too small, it should have one entry "
310 : "for each species (we have at least " + std::to_string(i) +
311 : " species and " + std::to_string(species_to_types.size()) +
312 : "entries in SPECIES_TO_TYPES)"
313 : );
314 : }
315 :
316 : type = species_to_types[static_cast<size_t>(i - 1)];
317 : }
318 :
319 : log.printf(" atoms with type %d are: ", type);
320 : for(unsigned j=0; j<t.size(); j++) {
321 : log.printf("%d ", t[j]);
322 : all_atoms.push_back(t[j]);
323 : atomic_types.push_back(type);
324 : }
325 : log.printf("\n"); t.resize(0);
326 :
327 : n_species += 1;
328 : }
329 : } else {
330 : n_species = 1;
331 :
332 : int32_t type = 1;
333 : if (has_custom_types) {
334 : type = species_to_types[0];
335 : }
336 : atomic_types.resize(all_atoms.size(), type);
337 : }
338 :
339 : if (has_custom_types && species_to_types.size() != n_species) {
340 : this->warning(
341 : "SPECIES_TO_TYPES contains more entries (" +
342 : std::to_string(species_to_types.size()) +
343 : ") than there where species (" + std::to_string(n_species) + ")"
344 : );
345 : }
346 :
347 : this->atomic_types_ = torch::tensor(std::move(atomic_types));
348 : this->requestAtoms(all_atoms);
349 :
350 : this->check_consistency_ = false;
351 : this->parseFlag("CHECK_CONSISTENCY", this->check_consistency_);
352 : if (this->check_consistency_) {
353 : log.printf(" checking for internal consistency of the model\n");
354 : }
355 :
356 : // create evaluation options for the model. These won't change during the
357 : // simulation, so we initialize them once here.
358 : evaluations_options_ = torch::make_intrusive<metatensor_torch::ModelEvaluationOptionsHolder>();
359 : evaluations_options_->set_length_unit(getUnits().getLengthString());
360 :
361 : auto outputs = this->capabilities_->outputs();
362 : if (outputs.contains("features")) {
363 : this->model_output_ = "features";
364 : }
365 :
366 : if (outputs.contains("plumed::cv")) {
367 : if (outputs.contains("features")) {
368 : this->warning(
369 : "this model exposes both 'features' and 'plumed::cv' outputs, "
370 : "we will use 'features'. 'plumed::cv' is deprecated, please "
371 : "remove it from your models"
372 : );
373 : } else {
374 : this->warning(
375 : "this model is using 'plumed::cv' output, which is deprecated. "
376 : "Please replace it with a 'features' output"
377 : );
378 : this->model_output_ = "plumed::cv";
379 : }
380 : }
381 :
382 :
383 : if (this->model_output_.empty()) {
384 : auto existing_outputs = std::vector<std::string>();
385 : for (const auto& it: this->capabilities_->outputs()) {
386 : existing_outputs.push_back(it.key());
387 : }
388 :
389 : this->error(
390 : "expected 'features' or 'plumed::cv' in the capabilities of the model, "
391 : "could not find it. the following outputs exist: " + torch::str(existing_outputs)
392 : );
393 : }
394 :
395 : auto output = torch::make_intrusive<metatensor_torch::ModelOutputHolder>();
396 : // this output has no quantity or unit to set
397 :
398 : output->per_atom = this->capabilities_->outputs().at(this->model_output_)->per_atom;
399 : // we are using torch autograd system to compute gradients,
400 : // so we don't need any explicit gradients.
401 : output->explicit_gradients = {};
402 : evaluations_options_->outputs.insert(this->model_output_, output);
403 :
404 : // Determine which device we should use based on user input, what the model
405 : // supports and what's available
406 : auto available_devices = std::vector<torch::Device>();
407 : for (const auto& device: this->capabilities_->supported_devices) {
408 : if (device == "cpu") {
409 : available_devices.push_back(torch::kCPU);
410 : } else if (device == "cuda") {
411 : if (torch::cuda::is_available()) {
412 : available_devices.push_back(torch::Device("cuda"));
413 : }
414 : } else if (device == "mps") {
415 : #if TORCH_VERSION_MAJOR >= 2
416 : if (torch::mps::is_available()) {
417 : available_devices.push_back(torch::Device("mps"));
418 : }
419 : #endif
420 : } else {
421 : this->warning(
422 : "the model declared support for unknown device '" + device +
423 : "', it will be ignored"
424 : );
425 : }
426 : }
427 :
428 : if (available_devices.empty()) {
429 : this->error(
430 : "failed to find a valid device for the model at '" + model_path + "': "
431 : "the model supports " + torch::str(this->capabilities_->supported_devices) +
432 : ", none of these where available"
433 : );
434 : }
435 :
436 : std::string requested_device;
437 : this->parse("DEVICE", requested_device);
438 : if (requested_device.empty()) {
439 : // no user request, pick the device the model prefers
440 : this->device_ = available_devices[0];
441 : } else {
442 : bool found_requested_device = false;
443 : for (const auto& device: available_devices) {
444 : if (device.is_cpu() && requested_device == "cpu") {
445 : this->device_ = device;
446 : found_requested_device = true;
447 : break;
448 : } else if (device.is_cuda() && requested_device == "cuda") {
449 : this->device_ = device;
450 : found_requested_device = true;
451 : break;
452 : } else if (device.is_mps() && requested_device == "mps") {
453 : this->device_ = device;
454 : found_requested_device = true;
455 : break;
456 : }
457 : }
458 :
459 : if (!found_requested_device) {
460 : this->error(
461 : "failed to find requested device (" + requested_device + "): it is either "
462 : "not supported by this model or not available on this machine"
463 : );
464 : }
465 : }
466 :
467 : this->model_.to(this->device_);
468 : this->atomic_types_ = this->atomic_types_.to(this->device_);
469 :
470 : log.printf(
471 : " running model on %s device with %s data\n",
472 : this->device_.str().c_str(),
473 : this->capabilities_->dtype().c_str()
474 : );
475 :
476 : if (this->capabilities_->dtype() == "float64") {
477 : this->dtype_ = torch::kFloat64;
478 : } else if (this->capabilities_->dtype() == "float32") {
479 : this->dtype_ = torch::kFloat32;
480 : } else {
481 : this->error(
482 : "the model requested an unsupported dtype '" + this->capabilities_->dtype() + "'"
483 : );
484 : }
485 :
486 : auto tensor_options = torch::TensorOptions().dtype(this->dtype_).device(this->device_);
487 : this->strain_ = torch::eye(3, tensor_options.requires_grad(true));
488 :
489 : // determine how many properties there will be in the output by running the
490 : // model once on a dummy system
491 : auto dummy_system = torch::make_intrusive<metatensor_torch::SystemHolder>(
492 : /*types = */ torch::zeros({0}, tensor_options.dtype(torch::kInt32)),
493 : /*positions = */ torch::zeros({0, 3}, tensor_options),
494 : /*cell = */ torch::zeros({3, 3}, tensor_options),
495 : /*pbc = */ torch::zeros({3}, tensor_options.dtype(torch::kBool))
496 : );
497 :
498 : log.printf(" the following neighbor lists have been requested:\n");
499 : auto length_unit = this->getUnits().getLengthString();
500 : auto model_length_unit = this->capabilities_->length_unit();
501 : for (auto request: this->nl_requests_) {
502 : log.printf(" - %s list, %g %s cutoff (requested %g %s)\n",
503 : request->full_list() ? "full" : "half",
504 : request->engine_cutoff(length_unit),
505 : length_unit.c_str(),
506 : request->cutoff(),
507 : model_length_unit.c_str()
508 : );
509 :
510 : auto neighbors = this->computeNeighbors(
511 : request,
512 : {PLMD::Vector(0, 0, 0)},
513 : PLMD::Tensor(0, 0, 0, 0, 0, 0, 0, 0, 0),
514 : false
515 : );
516 : metatensor_torch::register_autograd_neighbors(dummy_system, neighbors, this->check_consistency_);
517 : dummy_system->add_neighbor_list(request, neighbors);
518 : }
519 :
520 : this->n_properties_ = static_cast<unsigned>(
521 : this->executeModel(dummy_system)->properties()->count()
522 : );
523 :
524 : // parse and handle atom sub-selection. This is done AFTER determining the
525 : // output size, since the selection might not be valid for the dummy system
526 : std::vector<int32_t> selected_atoms;
527 : this->parseVector("SELECTED_ATOMS", selected_atoms);
528 : if (!selected_atoms.empty()) {
529 : auto selection_value = torch::zeros(
530 : {static_cast<int64_t>(selected_atoms.size()), 2},
531 : torch::TensorOptions().dtype(torch::kInt32).device(this->device_)
532 : );
533 :
534 : for (unsigned i=0; i<selected_atoms.size(); i++) {
535 : auto n_atoms = static_cast<int32_t>(this->atomic_types_.size(0));
536 : if (selected_atoms[i] <= 0 || selected_atoms[i] > n_atoms) {
537 : this->error(
538 : "Values in metatensor's SELECTED_ATOMS should be between 1 "
539 : "and the number of atoms (" + std::to_string(n_atoms) + "), "
540 : "got " + std::to_string(selected_atoms[i]));
541 : }
542 : // PLUMED input uses 1-based indexes, but metatensor wants 0-based
543 : selection_value[i][1] = selected_atoms[i] - 1;
544 : }
545 :
546 : evaluations_options_->set_selected_atoms(
547 : torch::make_intrusive<metatensor_torch::LabelsHolder>(
548 : std::vector<std::string>{"system", "atom"}, selection_value
549 : )
550 : );
551 : }
552 :
553 : // Now that we now both n_samples and n_properties, we can setup the
554 : // PLUMED-side storage for the computed CV
555 : if (output->per_atom) {
556 : if (selected_atoms.empty()) {
557 : this->n_samples_ = static_cast<unsigned>(this->atomic_types_.size(0));
558 : } else {
559 : this->n_samples_ = static_cast<unsigned>(selected_atoms.size());
560 : }
561 : } else {
562 : this->n_samples_ = 1;
563 : }
564 :
565 : if (n_samples_ == 1 && n_properties_ == 1) {
566 : log.printf(" the output of this model is a scalar\n");
567 :
568 : this->addValue();
569 : } else if (n_samples_ == 1) {
570 : log.printf(" the output of this model is 1x%d vector\n", n_properties_);
571 :
572 : this->addValue({this->n_properties_});
573 : this->getPntrToComponent(0)->buildDataStore();
574 : } else if (n_properties_ == 1) {
575 : log.printf(" the output of this model is %dx1 vector\n", n_samples_);
576 :
577 : this->addValue({this->n_samples_});
578 : this->getPntrToComponent(0)->buildDataStore();
579 : } else {
580 : log.printf(" the output of this model is a %dx%d matrix\n", n_samples_, n_properties_);
581 :
582 : this->addValue({this->n_samples_, this->n_properties_});
583 : this->getPntrToComponent(0)->buildDataStore();
584 : this->getPntrToComponent(0)->reshapeMatrixStore(n_properties_);
585 : }
586 :
587 : this->setNotPeriodic();
588 : }
589 :
590 : unsigned MetatensorPlumedAction::getNumberOfDerivatives() {
591 : // gradients w.r.t. positions (3 x N values) + gradients w.r.t. strain (9 values)
592 : return 3 * this->getNumberOfAtoms() + 9;
593 : }
594 :
595 :
// Fill `this->system_` with the current PLUMED data (atomic types, positions,
// cell, periodicity flags) and attach all neighbor lists requested by the
// model. When derivatives are needed, positions and cell are multiplied by
// `strain_` so that autograd can later provide forces and virial.
void MetatensorPlumedAction::createSystem() {
    // the model needs the type of every atom, so SPECIES must cover the
    // whole system
    if (this->getTotAtoms() != static_cast<unsigned>(this->atomic_types_.size(0))) {
        std::ostringstream oss;
        oss << "METATENSOR action needs to know about all atoms in the system. ";
        oss << "There are " << this->getTotAtoms() << " atoms overall, ";
        oss << "but we only have atomic types for " << this->atomic_types_.size(0) << " of them.";
        plumed_merror(oss.str());
    }

    const auto& cell = this->getPbc().getBox();

    auto cpu_f64_tensor = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU);
    auto torch_cell = torch::zeros({3, 3}, cpu_f64_tensor);

    // copy the 3x3 box matrix into a torch tensor, element by element
    torch_cell[0][0] = cell(0, 0);
    torch_cell[0][1] = cell(0, 1);
    torch_cell[0][2] = cell(0, 2);

    torch_cell[1][0] = cell(1, 0);
    torch_cell[1][1] = cell(1, 1);
    torch_cell[1][2] = cell(1, 2);

    torch_cell[2][0] = cell(2, 0);
    torch_cell[2][1] = cell(2, 1);
    torch_cell[2][2] = cell(2, 2);

    using torch::indexing::Slice;

    // a direction is treated as periodic when the corresponding cell vector
    // has a non-zero (above numerical noise) norm
    auto pbc_a = torch_cell.index({0, Slice()}).norm().abs().item<double>() > 1e-9;
    auto pbc_b = torch_cell.index({1, Slice()}).norm().abs().item<double>() > 1e-9;
    auto pbc_c = torch_cell.index({2, Slice()}).norm().abs().item<double>() > 1e-9;
    auto torch_pbc = torch::tensor({pbc_a, pbc_b, pbc_c});

    const auto& positions = this->getPositions();

    // zero-copy view over PLUMED's positions; legal because PLMD::Vector has
    // the layout of std::array<double, 3> (see static_asserts above)
    auto torch_positions = torch::from_blob(
        const_cast<PLMD::Vector*>(positions.data()),
        {static_cast<int64_t>(positions.size()), 3},
        cpu_f64_tensor
    );

    // the .to() calls copy, so the from_blob view does not escape this scope
    torch_positions = torch_positions.to(this->dtype_).to(this->device_);
    torch_cell = torch_cell.to(this->dtype_).to(this->device_);
    torch_pbc = torch_pbc.to(this->device_);

    // setup torch's automatic gradient tracking
    if (!this->doNotCalculateDerivatives()) {
        torch_positions.requires_grad_(true);

        // pretend to scale positions/cell by the strain so that it enters the
        // computational graph.
        torch_positions = torch_positions.matmul(this->strain_);
        torch_positions.retain_grad();

        torch_cell = torch_cell.matmul(this->strain_);
    }

    this->system_ = torch::make_intrusive<metatensor_torch::SystemHolder>(
        this->atomic_types_,
        torch_positions,
        torch_cell,
        torch_pbc
    );

    // only fully periodic or fully non-periodic systems are supported
    auto periodic = torch::all(torch_pbc).item<bool>();
    if (!periodic && torch::any(torch_pbc).item<bool>()) {
        std::string periodic_directions;
        std::string non_periodic_directions;
        if (pbc_a) {
            periodic_directions += "A";
        } else {
            non_periodic_directions += "A";
        }

        if (pbc_b) {
            periodic_directions += "B";
        } else {
            non_periodic_directions += "B";
        }

        if (pbc_c) {
            periodic_directions += "C";
        } else {
            non_periodic_directions += "C";
        }

        plumed_merror(
            "mixed periodic boundary conditions are not supported, this system "
            "is periodic along the " + periodic_directions + " cell vector(s), "
            "but not along the " + non_periodic_directions + " cell vector(s)."
        );
    }

    // compute the neighbors list requested by the model, and register them with
    // the system
    for (auto request: this->nl_requests_) {
        auto neighbors = this->computeNeighbors(request, positions, cell, periodic);
        metatensor_torch::register_autograd_neighbors(this->system_, neighbors, this->check_consistency_);
        this->system_->add_neighbor_list(request, neighbors);
    }
}
699 :
700 :
701 : metatensor_torch::TensorBlock MetatensorPlumedAction::computeNeighbors(
702 : metatensor_torch::NeighborListOptions request,
703 : const std::vector<PLMD::Vector>& positions,
704 : const PLMD::Tensor& cell,
705 : bool periodic
706 : ) {
707 : auto labels_options = torch::TensorOptions().dtype(torch::kInt32).device(this->device_);
708 : auto neighbor_component = torch::make_intrusive<metatensor_torch::LabelsHolder>(
709 : "xyz",
710 : torch::tensor({0, 1, 2}, labels_options).reshape({3, 1})
711 : );
712 : auto neighbor_properties = torch::make_intrusive<metatensor_torch::LabelsHolder>(
713 : "distance", torch::zeros({1, 1}, labels_options)
714 : );
715 :
716 : auto cutoff = request->engine_cutoff(this->getUnits().getLengthString());
717 :
718 : // use https://github.com/Luthaf/vesin to compute the requested neighbor
719 : // lists since we can not get these from PLUMED
720 : vesin::VesinOptions options;
721 : options.cutoff = cutoff;
722 : options.full = request->full_list();
723 : options.return_shifts = true;
724 : options.return_distances = false;
725 : options.return_vectors = true;
726 :
727 : vesin::VesinNeighborList* vesin_neighbor_list = new vesin::VesinNeighborList();
728 : memset(vesin_neighbor_list, 0, sizeof(vesin::VesinNeighborList));
729 :
730 : const char* error_message = NULL;
731 : int status = vesin_neighbors(
732 : reinterpret_cast<const double (*)[3]>(positions.data()),
733 : positions.size(),
734 : reinterpret_cast<const double (*)[3]>(&cell(0, 0)),
735 : periodic,
736 : vesin::VesinCPU,
737 : options,
738 : vesin_neighbor_list,
739 : &error_message
740 : );
741 :
742 : if (status != EXIT_SUCCESS) {
743 : plumed_merror(
744 : "failed to compute neighbor list (cutoff=" + std::to_string(cutoff) +
745 : ", full=" + (request->full_list() ? "true" : "false") + "): " + error_message
746 : );
747 : }
748 :
749 : // transform from vesin to metatensor format
750 : auto n_pairs = static_cast<int64_t>(vesin_neighbor_list->length);
751 :
752 : auto pair_vectors = torch::from_blob(
753 : vesin_neighbor_list->vectors,
754 : {n_pairs, 3, 1},
755 : /*deleter*/ [=](void*) {
756 : vesin_free(vesin_neighbor_list);
757 : delete vesin_neighbor_list;
758 : },
759 : torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU)
760 : );
761 :
762 : auto pair_samples_values = torch::zeros({n_pairs, 5}, labels_options.device(torch::kCPU));
763 : for (unsigned i=0; i<n_pairs; i++) {
764 : pair_samples_values[i][0] = static_cast<int32_t>(vesin_neighbor_list->pairs[i][0]);
765 : pair_samples_values[i][1] = static_cast<int32_t>(vesin_neighbor_list->pairs[i][1]);
766 : pair_samples_values[i][2] = vesin_neighbor_list->shifts[i][0];
767 : pair_samples_values[i][3] = vesin_neighbor_list->shifts[i][1];
768 : pair_samples_values[i][4] = vesin_neighbor_list->shifts[i][2];
769 : }
770 :
771 : auto neighbor_samples = torch::make_intrusive<metatensor_torch::LabelsHolder>(
772 : std::vector<std::string>{"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"},
773 : pair_samples_values.to(this->device_)
774 : );
775 :
776 : auto neighbors = torch::make_intrusive<metatensor_torch::TensorBlockHolder>(
777 : pair_vectors.to(this->dtype_).to(this->device_),
778 : neighbor_samples,
779 : std::vector<metatensor_torch::Labels>{neighbor_component},
780 : neighbor_properties
781 : );
782 :
783 : return neighbors;
784 : }
785 :
786 : metatensor_torch::TensorBlock MetatensorPlumedAction::executeModel(metatensor_torch::System system) {
787 : try {
788 : auto ivalue_output = this->model_.forward({
789 : std::vector<metatensor_torch::System>{system},
790 : evaluations_options_,
791 : this->check_consistency_,
792 : });
793 :
794 : auto dict_output = ivalue_output.toGenericDict();
795 : auto cv = dict_output.at(this->model_output_);
796 : this->output_ = cv.toCustomClass<metatensor_torch::TensorMapHolder>();
797 : } catch (const std::exception& e) {
798 : plumed_merror("failed to evaluate the model: " + std::string(e.what()));
799 : }
800 :
801 : plumed_massert(this->output_->keys()->count() == 1, "output should have a single block");
802 : auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0);
803 : plumed_massert(block->components().empty(), "components are not yet supported in the output");
804 :
805 : return block;
806 : }
807 :
808 :
// Evaluate the model on the current system and copy the resulting values
// into the PLUMED `Value`, after checking the output has the expected
// (n_samples x n_properties) shape.
void MetatensorPlumedAction::calculate() {
    this->createSystem();

    auto block = this->executeModel(this->system_);
    // bring the values back to CPU/double for PLUMED, whatever dtype/device
    // the model ran with
    auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64);

    if (static_cast<unsigned>(torch_values.size(0)) != this->n_samples_) {
        plumed_merror(
            "expected the model to return a TensorBlock with " +
            std::to_string(this->n_samples_) + " samples, got " +
            std::to_string(torch_values.size(0)) + " instead"
        );
    } else if (static_cast<unsigned>(torch_values.size(1)) != this->n_properties_) {
        plumed_merror(
            "expected the model to return a TensorBlock with " +
            std::to_string(this->n_properties_) + " properties, got " +
            std::to_string(torch_values.size(1)) + " instead"
        );
    }

    Value* value = this->getPntrToComponent(0);
    // reshape the plumed `Value` to hold the data returned by the model
    if (n_samples_ == 1) {
        if (n_properties_ == 1) {
            // scalar output
            value->set(torch_values.item<double>());
        } else {
            // we have multiple CV describing a single thing (atom or full system)
            for (unsigned i=0; i<n_properties_; i++) {
                value->set(i, torch_values[0][i].item<double>());
            }
        }
    } else {
        auto samples = block->samples();
        plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));

        auto samples_values = samples->values().to(torch::kCPU);
        auto selected_atoms = this->evaluations_options_->get_selected_atoms();

        // handle the possibility that samples are returned in
        // a non-sorted order.
        auto get_output_location = [&](unsigned i) {
            if (selected_atoms.has_value()) {
                // If the users picked some selected atoms, then we store the
                // output in the same order as the selection was given
                auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
                auto position = selected_atoms.value()->position(sample);
                plumed_assert(position.has_value());
                return static_cast<unsigned>(position.value());
            } else {
                // no selection: use the atom index from the sample directly
                return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
            }
        };

        if (n_properties_ == 1) {
            // we have a single CV describing multiple things (i.e. atoms)
            for (unsigned i=0; i<n_samples_; i++) {
                auto output_i = get_output_location(i);
                value->set(output_i, torch_values[i][0].item<double>());
            }
        } else {
            // the CV is a matrix, stored row-major in the Value
            for (unsigned i=0; i<n_samples_; i++) {
                auto output_i = get_output_location(i);
                for (unsigned j=0; j<n_properties_; j++) {
                    value->set(output_i * n_properties_ + j, torch_values[i][j].item<double>());
                }
            }
        }
    }
}
879 :
880 :
// Back-propagate the forces PLUMED applied to the CV through the model
// (via torch autograd) and hand the resulting per-atom derivatives and
// virial contribution back to PLUMED.
void MetatensorPlumedAction::apply() {
    const auto* value = this->getPntrToComponent(0);
    if (!value->forcesWereAdded()) {
        // no bias acted on this CV during this step, nothing to do
        return;
    }

    auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0);
    auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64);

    if (!torch_values.requires_grad()) {
        this->warning(
            "the output of the model does not requires gradients, this might "
            "indicate a problem"
        );
        return;
    }

    // build the "upstream" gradient for backward(): the force on each entry
    // of the CV, laid out with the same shape as the model output
    auto output_grad = torch::zeros_like(torch_values);
    if (n_samples_ == 1) {
        if (n_properties_ == 1) {
            output_grad[0][0] = value->getForce();
        } else {
            for (unsigned i=0; i<n_properties_; i++) {
                output_grad[0][i] = value->getForce(i);
            }
        }
    } else {
        auto samples = block->samples();
        plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));

        auto samples_values = samples->values().to(torch::kCPU);
        auto selected_atoms = this->evaluations_options_->get_selected_atoms();

        // see above for an explanation of why we use this function
        auto get_output_location = [&](unsigned i) {
            if (selected_atoms.has_value()) {
                auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
                auto position = selected_atoms.value()->position(sample);
                plumed_assert(position.has_value());
                return static_cast<unsigned>(position.value());
            } else {
                return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
            }
        };

        if (n_properties_ == 1) {
            for (unsigned i=0; i<n_samples_; i++) {
                auto output_i = get_output_location(i);
                output_grad[i][0] = value->getForce(output_i);
            }
        } else {
            for (unsigned i=0; i<n_samples_; i++) {
                auto output_i = get_output_location(i);
                for (unsigned j=0; j<n_properties_; j++) {
                    output_grad[i][j] = value->getForce(output_i * n_properties_ + j);
                }
            }
        }
    }

    // clear any gradient left over from a previous call, so backward()
    // does not accumulate into stale values
    this->system_->positions().mutable_grad() = torch::Tensor();
    this->strain_.mutable_grad() = torch::Tensor();

    torch_values.backward(output_grad);
    auto positions_grad = this->system_->positions().grad();
    auto strain_grad = this->strain_.grad();

    positions_grad = positions_grad.to(torch::kCPU).to(torch::kFloat64);
    strain_grad = strain_grad.to(torch::kCPU).to(torch::kFloat64);

    // `data_ptr` access below assumes dense 2-D row-major storage
    plumed_assert(positions_grad.sizes().size() == 2);
    plumed_assert(positions_grad.is_contiguous());

    plumed_assert(strain_grad.sizes().size() == 2);
    plumed_assert(strain_grad.is_contiguous());

    // one derivative per Cartesian component of every atom in the system
    auto derivatives = std::vector<double>(
        positions_grad.data_ptr<double>(),
        positions_grad.data_ptr<double>() + 3 * this->system_->size()
    );

    // add virials to the derivatives
    // (negated: the virial is minus the gradient with respect to strain)
    derivatives.push_back(-strain_grad[0][0].item<double>());
    derivatives.push_back(-strain_grad[0][1].item<double>());
    derivatives.push_back(-strain_grad[0][2].item<double>());

    derivatives.push_back(-strain_grad[1][0].item<double>());
    derivatives.push_back(-strain_grad[1][1].item<double>());
    derivatives.push_back(-strain_grad[1][2].item<double>());

    derivatives.push_back(-strain_grad[2][0].item<double>());
    derivatives.push_back(-strain_grad[2][1].item<double>());
    derivatives.push_back(-strain_grad[2][2].item<double>());

    unsigned index = 0;
    this->setForcesOnAtoms(derivatives, index);
}
978 :
979 : } // namespace metatensor
980 : } // namespace PLMD
981 :
982 :
983 : #endif
984 :
985 :
986 : namespace PLMD {
987 : namespace metatensor {
988 :
// use the same implementation for both the actual action and the dummy one
// (when libtorch and libmetatensor could not be found).
//
// Declare all keywords this action accepts in a PLUMED input file.
void MetatensorPlumedAction::registerKeywords(Keywords& keys) {
    // inherit the standard keywords of the base classes
    Action::registerKeywords(keys);
    ActionAtomistic::registerKeywords(keys);
    ActionWithValue::registerKeywords(keys);

    keys.add("compulsory", "MODEL", "path to the exported metatensor model");
    keys.add("optional", "EXTENSIONS_DIRECTORY", "path to the directory containing TorchScript extensions to load");
    keys.add("optional", "DEVICE", "Torch device to use for the calculation");

    keys.addFlag("CHECK_CONSISTENCY", false, "Should we enable internal consistency of the model");

    // SPECIES and SELECTED_ATOMS take atom lists, hence the "atoms" style
    keys.add("numbered", "SPECIES", "the atoms in each PLUMED species");
    keys.reset_style("SPECIES", "atoms");

    keys.add("optional", "SELECTED_ATOMS", "subset of atoms that should be used for the calculation");
    keys.reset_style("SELECTED_ATOMS", "atoms");

    keys.add("optional", "SPECIES_TO_TYPES", "mapping from PLUMED SPECIES to metatensor's atomic types");

    keys.addOutputComponent("outputs", "default", "scalar", "collective variable created by the metatensor model");
    keys.setValueDescription("scalar/vector/matrix","collective variable created by the metatensor model");
}
1013 :
// register this action so it is available as `METATENSOR` in PLUMED inputs
PLUMED_REGISTER_ACTION(MetatensorPlumedAction, "METATENSOR")
1015 :
1016 : } // namespace metatensor
1017 : } // namespace PLMD
|