Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "core/ActionWithMatrix.h"
23 : #include "core/ActionRegister.h"
24 : #include "tools/Torsion.h"
25 :
26 :
27 : #include <iostream>
28 :
29 : namespace PLMD {
30 : namespace crystdistrib {
31 :
32 : //+PLUMEDOC MCOLVAR QUATERNION_BOND_PRODUCT_MATRIX
33 : /*
34 : Calculate the product between a vector of quaternions and the bond vectors stored in a set of matrices
35 :
36 : \par Examples
37 :
38 : */
39 : //+ENDPLUMEDOC
40 :
41 : class QuaternionBondProductMatrix : public ActionWithMatrix {
42 : private:
43 : unsigned nderivatives;
44 : std::vector<bool> stored;
45 : // const Vector4d& rightMultiply(Tensor4d&, Vector4d&);
46 : public:
47 : static void registerKeywords( Keywords& keys );
48 : explicit QuaternionBondProductMatrix(const ActionOptions&);
49 : unsigned getNumberOfDerivatives();
50 : unsigned getNumberOfColumns() const override ;
51 : void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
52 : void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override;
53 : void runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
54 : };
55 :
// Associates the class with the string QUATERNION_BOND_PRODUCT_MATRIX; presumably this is the
// directive name used to create the action from a PLUMED input file (confirm against ActionRegister.h).
56 : PLUMED_REGISTER_ACTION(QuaternionBondProductMatrix,"QUATERNION_BOND_PRODUCT_MATRIX")
57 :
58 :
59 : //const Vector4d& QuaternionBondMatrix::rightMultiply(Tensor4d& pref, Vector4d& quat) {
60 : // Vector4d temp;
61 : // int sumTemp;
62 : // for (int i=0; i<4; i++){ //rows
63 : // sumTemp=0;
64 : // for (int j=0; j<4; j++){ //cols
65 : // sumTemp+=pref(i,j)*quat[j];
66 : // }
67 : // temp[i]=sumTemp;
68 : // }
69 : // return temp;
70 : //}
71 :
72 :
73 :
74 :
75 7 : void QuaternionBondProductMatrix::registerKeywords( Keywords& keys ) {
76 7 : ActionWithMatrix::registerKeywords(keys);
77 14 : keys.addInputKeyword("compulsory","ARG","vector/matrix","this action takes 8 arguments. The first four should be the w,i,j and k components of a quaternion vector. The second four should be contact matrix and the matrices should be the x, y and z components of the bond vectors");
78 14 : keys.addOutputComponent("w","default","matrix","the real component of quaternion");
79 14 : keys.addOutputComponent("i","default","matrix","the i component of the quaternion");
80 14 : keys.addOutputComponent("j","default","matrix","the j component of the quaternion");
81 14 : keys.addOutputComponent("k","default","matrix","the k component of the quaternion");
82 7 : }
83 :
84 4 : QuaternionBondProductMatrix::QuaternionBondProductMatrix(const ActionOptions&ao):
85 : Action(ao),
86 4 : ActionWithMatrix(ao)
87 : {
88 4 : if( getNumberOfArguments()!=8 ) error("should be eight arguments to this action, 4 quaternion components and 4 matrices");
89 4 : unsigned nquat = getPntrToArgument(0)->getNumberOfValues();
90 20 : for(unsigned i=0; i<4; ++i) {
91 16 : Value* myarg=getPntrToArgument(i); myarg->buildDataStore();
92 16 : if( myarg->getRank()!=1 ) error("first four arguments to this action should be vectors");
93 16 : if( (myarg->getPntrToAction())->getName()!="QUATERNION_VECTOR" ) error("first four arguments to this action should be quaternions");
94 16 : std::string mylab=getPntrToArgument(i)->getName(); std::size_t dot=mylab.find_first_of(".");
95 24 : if( i==0 && mylab.substr(dot+1)!="w" ) error("quaternion arguments are in wrong order");
96 24 : if( i==1 && mylab.substr(dot+1)!="i" ) error("quaternion arguments are in wrong order");
97 24 : if( i==2 && mylab.substr(dot+1)!="j" ) error("quaternion arguments are in wrong order");
98 24 : if( i==3 && mylab.substr(dot+1)!="k" ) error("quaternion arguments are in wrong order");
99 : }
100 4 : std::vector<unsigned> shape( getPntrToArgument(4)->getShape() );
101 20 : for(unsigned i=4; i<8; ++i) {
102 : Value* myarg=getPntrToArgument(i);
103 16 : if( myarg->getRank()!=2 ) error("second four arguments to this action should be matrices");
104 16 : if( myarg->getShape()[0]!=shape[0] || myarg->getShape()[1]!=shape[1] ) error("matrices should all have the same shape");
105 16 : if( myarg->getShape()[0]!=nquat ) error("number of rows in matrix should equal number of input quaternions");
106 16 : std::string mylab=getPntrToArgument(i)->getName(); std::size_t dot=mylab.find_first_of(".");
107 24 : if( i==5 && mylab.substr(dot+1)!="x" ) error("quaternion arguments are in wrong order");
108 24 : if( i==6 && mylab.substr(dot+1)!="y" ) error("quaternion arguments are in wrong order");
109 24 : if( i==7 && mylab.substr(dot+1)!="z" ) error("quaternion arguments are in wrong order");
110 : }
111 8 : addComponent( "w", shape ); componentIsNotPeriodic("w");
112 8 : addComponent( "i", shape ); componentIsNotPeriodic("i");
113 8 : addComponent( "j", shape ); componentIsNotPeriodic("j");
114 8 : addComponent( "k", shape ); componentIsNotPeriodic("k");
115 4 : done_in_chain=true; nderivatives = buildArgumentStore(0);
116 :
117 4 : std::string headstr=getFirstActionInChain()->getLabel(); stored.resize( getNumberOfArguments() );
118 36 : for(unsigned i=0; i<getNumberOfArguments(); ++i) stored[i] = getPntrToArgument(i)->ignoreStoredValue( headstr );
119 4 : }
120 :
121 32 : unsigned QuaternionBondProductMatrix::getNumberOfDerivatives() {
122 32 : return nderivatives;
123 : }
124 :
125 96 : unsigned QuaternionBondProductMatrix::getNumberOfColumns() const {
126 96 : const ActionWithMatrix* am=dynamic_cast<const ActionWithMatrix*>( getPntrToArgument(4)->getPntrToAction() );
127 96 : plumed_assert( am ); return am->getNumberOfColumns();
128 : }
129 :
130 0 : void QuaternionBondProductMatrix::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
131 0 : unsigned start_n = getPntrToArgument(4)->getShape()[0], size_v = getPntrToArgument(4)->getShape()[1];
132 0 : if( indices.size()!=size_v+1 ) indices.resize( size_v+1 );
133 0 : for(unsigned i=0; i<size_v; ++i) indices[i+1] = start_n + i;
134 : myvals.setSplitIndex( size_v + 1 );
135 0 : }
136 :
137 381366 : void QuaternionBondProductMatrix::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
138 381366 : unsigned ind2=index2;
139 381366 : if( index2>=getPntrToArgument(0)->getShape()[0] ) ind2 = index2 - getPntrToArgument(0)->getShape()[0];
140 :
141 381366 : std::vector<double> quat(4), bond(4), quatTemp(4);
142 381366 : std::vector<Tensor4d> dqt(2); //dqt[0] -> derivs w.r.t quat [dwt/dw1 dwt/di1 dwt/dj1 dwt/dk1]
143 : //[dit/dw1 dit/di1 dit/dj1 dit/dk1] etc, and dqt[1] is w.r.t the vector-turned-quaternion called bond
144 :
145 : // Retrieve the quaternion
146 1906830 : for(unsigned i=0; i<4; ++i) quat[i] = getArgumentElement( i, index1, myvals );
147 :
148 : // Retrieve the components of the matrix
149 381366 : double weight = getElementOfMatrixArgument( 4, index1, ind2, myvals );
150 1525464 : for(unsigned i=1; i<4; ++i) bond[i] = getElementOfMatrixArgument( 4+i, index1, ind2, myvals );
151 :
152 : // calculate normalization factor
153 381366 : bond[0]=0.0;
154 381366 : double normFac = 1/sqrt(bond[1]*bond[1] + bond[2]*bond[2] + bond[3]*bond[3]);
155 381366 : if (bond[1] == 0.0 && bond[2]==0.0 && bond[3]==0) normFac=1; //just for the case where im comparing a quat to itself, itll be 0 at the end anyway
156 : double normFac3 = normFac*normFac*normFac;
157 : //I hold off on normalizing because this can be done at the very end, and it makes the derivatives with respect to 'bond' more simple
158 :
159 :
160 :
161 381366 : std::vector<double> quat_conj(4);
162 381366 : quat_conj[0] = quat[0]; quat_conj[1] = -1*quat[1]; quat_conj[2] = -1*quat[2]; quat_conj[3] = -1*quat[3];
163 : //make a conjugate of q1 my own sanity
164 :
165 :
166 :
167 :
168 : //q1_conj * r first, while keep track of derivs
169 : double pref=1;
170 : double conj=1;
171 : double pref2=1;
172 : //real part of q1*q2
173 :
174 1906830 : for(unsigned i=0; i<4; ++i) {
175 1525464 : if( i>0 ) {pref=-1; conj=-1; pref2=-1;}
176 1525464 : quatTemp[0]+=pref*quat_conj[i]*bond[i];
177 1525464 : dqt[0](0,i) = conj*pref*bond[i];
178 1525464 : dqt[1](0,i) = pref2*quat_conj[i];
179 : //addDerivativeOnVectorArgument( false, 0, i, index1, conj*pref*bond[i], myvals );
180 : //addDerivativeOnVectorArgument( false, 0, 4+i, ind2, conj*pref*quat[i], myvals );
181 : }
182 : //i component
183 : pref=1;
184 : conj=1;
185 : pref2=1;
186 :
187 1906830 : for (unsigned i=0; i<4; i++) {
188 1525464 : if(i==3) pref=-1;
189 : else pref=1;
190 1525464 : if(i==2) pref2=-1;
191 : else pref2=1;
192 1525464 : if (i>0) conj=-1;
193 :
194 1525464 : quatTemp[1]+=pref*quat_conj[i]*bond[(5-i)%4];
195 1525464 : dqt[0](1,i) =conj*pref*bond[(5-i)%4];
196 1525464 : dqt[1](1,i) = pref2*quat_conj[(5-i)%4];
197 : //addDerivativeOnVectorArgument( false, 1, i, index1, conj*pref*bond[(5-i)%4], myvals );
198 : //addDerivativeOnVectorArgument( false, 1, 4+i, ind2, conj*pref*quat[i], myvals );
199 : }
200 :
201 : //j component
202 : pref=1;
203 : pref2=1;
204 : conj=1;
205 :
206 1906830 : for (unsigned i=0; i<4; i++) {
207 1525464 : if(i==1) pref=-1;
208 : else pref=1;
209 1525464 : if (i==3) pref2=-1;
210 : else pref2=1;
211 1525464 : if (i>0) conj=-1;
212 :
213 1525464 : quatTemp[2]+=pref*quat_conj[i]*bond[(i+2)%4];
214 1525464 : dqt[0](2,i)=conj*pref*bond[(i+2)%4];
215 1525464 : dqt[1](2,i)=pref2*quat_conj[(i+2)%4];
216 : //addDerivativeOnVectorArgument( false, 2, i, index1, conj*pref*bond[(i+2)%4], myvals );
217 : //addDerivativeOnVectorArgument( false, 2, 4+i, ind2, conj*pref*quat[i], myvals );
218 : }
219 :
220 : //k component
221 : pref=1;
222 : pref2=1;
223 : conj=1;
224 :
225 1906830 : for (unsigned i=0; i<4; i++) {
226 1525464 : if(i==2) pref=-1;
227 : else pref=1;
228 1525464 : if(i==1) pref2=-1;
229 : else pref2=1;
230 1525464 : if(i>0) conj=-1;
231 1525464 : quatTemp[3]+=pref*quat_conj[i]*bond[(3-i)];
232 1525464 : dqt[0](3,i)=conj*pref*bond[3-i];
233 1525464 : dqt[1](3,i)= pref2*quat_conj[3-i];
234 : //addDerivativeOnVectorArgument( false, 3, i, index1, conj*pref*bond[3-i], myvals );
235 : //addDerivativeOnVectorArgument( false, 3, 4+i, ind2, conj*pref*quat[i], myvals );
236 :
237 : }
238 :
239 :
240 : //now previous ^ product times quat again, not conjugated
241 : //real part of q1*q2
242 381366 : double tempDot=0,wf=0,xf=0,yf=0,zf=0;
243 : pref=1;
244 : pref2=1;
245 1906830 : for(unsigned i=0; i<4; ++i) {
246 1525464 : if( i>0 ) {pref=-1; pref2=-1;}
247 1525464 : myvals.addValue( getConstPntrToComponent(0)->getPositionInStream(), normFac*pref*quatTemp[i]*quat[i] );
248 : wf+=normFac*pref*quatTemp[i]*quat[i];
249 1525464 : if( doNotCalculateDerivatives() ) continue ;
250 1525464 : tempDot=(dotProduct(Vector4d(quat[0],-quat[1],-quat[2],-quat[3]), dqt[0].getCol(i)) + pref2*quatTemp[i])*normFac;
251 1525464 : addDerivativeOnVectorArgument( stored[i], 0, i, index1, tempDot, myvals);
252 : }
253 : //had to split because bond's derivatives depend on the value of the overall quaternion component
254 : //addDerivativeOnMatrixArgument( false, 0, 4, index1, ind2, 0.0, myvals );
255 1906830 : for(unsigned i=0; i<4; ++i) {
256 1525464 : tempDot=dotProduct(Vector4d(quat[0],-quat[1],-quat[2],-quat[3]), dqt[1].getCol(i))*normFac;
257 1525464 : if (i!=0 )addDerivativeOnMatrixArgument( stored[4+i], 0, 4+i, index1, ind2, tempDot, myvals );
258 381366 : else addDerivativeOnMatrixArgument( stored[4+i], 0, 4+i, index1, ind2, 0.0, myvals );
259 : }
260 : // for (unsigned i=0; i<4; ++i) {
261 : //myvals.addValue( getConstPntrToComponent(0)->getPositionInStream(), 0.0 );
262 : //if( doNotCalculateDerivatives() ) continue ;
263 : //addDerivativeOnVectorArgument( false, 0, i, index1, 0.0, myvals);
264 : //addDerivativeOnVectorArgument( false, 0, 4+i, ind2, 0.0 , myvals);
265 : // }
266 : //the w component should always be zero, barring some catastrophe, but we calculate it out anyway
267 :
268 : //i component
269 : pref=1;
270 : pref2=1;
271 1906830 : for (unsigned i=0; i<4; i++) {
272 1525464 : if(i==3) pref=-1;
273 : else pref=1;
274 1525464 : myvals.addValue( getConstPntrToComponent(1)->getPositionInStream(), normFac*pref*quatTemp[i]*quat[(5-i)%4]);
275 1525464 : xf+=normFac*pref*quatTemp[i]*quat[(5-i)%4];
276 1525464 : if(i==2) pref2=-1;
277 : else pref2=1;
278 1525464 : if( doNotCalculateDerivatives() ) continue ;
279 1525464 : tempDot=(dotProduct(Vector4d(quat[1],quat[0],quat[3],-quat[2]), dqt[0].getCol(i)) + pref2*quatTemp[(5-i)%4])*normFac;
280 1525464 : addDerivativeOnVectorArgument( stored[i], 1, i, index1, tempDot, myvals);
281 : }
282 : //addDerivativeOnMatrixArgument( false, 1, 4, index1, ind2, 0.0, myvals );
283 :
284 1906830 : for(unsigned i=0; i<4; ++i) {
285 1525464 : tempDot=dotProduct(Vector4d(quat[1],quat[0],quat[3],-quat[2]), dqt[1].getCol(i))*normFac;
286 1525464 : if (i!=0) addDerivativeOnMatrixArgument( stored[4+i], 1, 4+i, index1, ind2, tempDot+(-bond[i]*normFac*normFac*xf), myvals );
287 381366 : else addDerivativeOnMatrixArgument( stored[4+i], 1, 4+i, index1, ind2, 0.0, myvals );
288 :
289 : }
290 :
291 :
292 : //j component
293 : pref=1;
294 : pref2=1;
295 1906830 : for (unsigned i=0; i<4; i++) {
296 1525464 : if(i==1) pref=-1;
297 : else pref=1;
298 1525464 : if (i==3) pref2=-1;
299 : else pref2=1;
300 :
301 1525464 : myvals.addValue( getConstPntrToComponent(2)->getPositionInStream(), normFac*pref*quatTemp[i]*quat[(i+2)%4]);
302 1525464 : yf+=normFac*pref*quatTemp[i]*quat[(i+2)%4];
303 1525464 : if( doNotCalculateDerivatives() ) continue ;
304 1525464 : tempDot=(dotProduct(Vector4d(quat[2],-quat[3],quat[0],quat[1]), dqt[0].getCol(i)) + pref2*quatTemp[(i+2)%4])*normFac;
305 1525464 : addDerivativeOnVectorArgument( stored[i], 2, i, index1, tempDot, myvals);
306 : }
307 : // addDerivativeOnMatrixArgument( false, 2, 4, index1, ind2,0.0 , myvals );
308 :
309 1906830 : for(unsigned i=0; i<4; ++i) {
310 1525464 : tempDot=dotProduct(Vector4d(quat[2],-quat[3],quat[0],quat[1]), dqt[1].getCol(i))*normFac;
311 1525464 : if (i!=0) addDerivativeOnMatrixArgument( stored[4+i], 2, 4+i, index1, ind2, tempDot+(-bond[i]*normFac*normFac*yf), myvals );
312 381366 : else addDerivativeOnMatrixArgument( stored[4+i], 2, 4+i, index1, ind2, 0.0, myvals );
313 :
314 :
315 : }
316 :
317 : //k component
318 : pref=1;
319 : pref2=1;
320 1906830 : for (unsigned i=0; i<4; i++) {
321 1525464 : if(i==2) pref=-1;
322 : else pref=1;
323 1525464 : if(i==1) pref2=-1;
324 : else pref2=1;
325 :
326 1525464 : myvals.addValue( getConstPntrToComponent(3)->getPositionInStream(), normFac*pref*quatTemp[i]*quat[(3-i)]);
327 1525464 : zf+=normFac*pref*quatTemp[i]*quat[(3-i)];
328 1525464 : if( doNotCalculateDerivatives() ) continue ;
329 1525464 : tempDot=(dotProduct(Vector4d(quat[3],quat[2],-quat[1],quat[0]), dqt[0].getCol(i)) + pref2*quatTemp[(3-i)])*normFac;
330 1525464 : addDerivativeOnVectorArgument( stored[i], 3, i, index1, tempDot, myvals);
331 : }
332 : //addDerivativeOnMatrixArgument( false, 3, 4, index1, ind2, 0.0 , myvals );
333 :
334 1906830 : for(unsigned i=0; i<4; ++i) {
335 1525464 : tempDot=dotProduct(Vector4d(quat[3],quat[2],-quat[1],quat[0]), dqt[1].getCol(i))*normFac;
336 1525464 : if (i!=0) addDerivativeOnMatrixArgument( stored[4+i], 3, 4+i, index1, ind2, tempDot+(-bond[i]*normFac*normFac*zf), myvals );
337 381366 : else addDerivativeOnMatrixArgument( stored[4+i], 3, 4+i, index1, ind2, 0.0, myvals );
338 :
339 :
340 : }
341 381366 : if( doNotCalculateDerivatives() ) return ;
342 :
343 1906830 : for(unsigned outcomp=0; outcomp<4; ++outcomp) {
344 1525464 : unsigned ostrn = getConstPntrToComponent(outcomp)->getPositionInStream();
345 7627320 : for(unsigned i=4; i<8; ++i) {
346 : bool found=false;
347 6101856 : for(unsigned j=4; j<i; ++j) {
348 4576392 : if( arg_deriv_starts[i]==arg_deriv_starts[j] ) { found=true; break; }
349 : }
350 6101856 : if( found || !stored[i] ) continue;
351 :
352 : unsigned istrn = getPntrToArgument(i)->getPositionInStream();
353 24407424 : for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
354 22881960 : unsigned kind=myvals.getActiveIndex(istrn,k); myvals.updateIndex( ostrn, kind );
355 : }
356 : }
357 : }
358 : }
359 :
360 1614 : void QuaternionBondProductMatrix::runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
361 1614 : if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ;
362 :
363 8070 : for(unsigned j=0; j<getNumberOfComponents(); ++j) {
364 6456 : unsigned nmat = getConstPntrToComponent(j)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
365 : std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) ); unsigned ntwo_atoms = myvals.getSplitIndex();
366 : // Quaternion
367 32280 : for(unsigned k=0; k<4; ++k) { matrix_indices[nmat_ind] = arg_deriv_starts[k] + ival; nmat_ind++; }
368 : // Loop over row of matrix
369 32280 : for(unsigned n=4; n<8; ++n) {
370 : bool found=false;
371 25824 : for(unsigned k=4; k<n; ++k) {
372 19368 : if( arg_deriv_starts[k]==arg_deriv_starts[n] ) { found=true; break; }
373 : }
374 25824 : if( found ) continue;
375 : unsigned istrn = getPntrToArgument(n)->getPositionInMatrixStash();
376 : std::vector<unsigned>& imat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
377 4660320 : for(unsigned k=0; k<myvals.getNumberOfMatrixRowDerivatives( istrn ); ++k) matrix_indices[nmat_ind + k] = arg_deriv_starts[n] + imat_indices[k];
378 6456 : nmat_ind += myvals.getNumberOfMatrixRowDerivatives( getPntrToArgument(4)->getPositionInMatrixStash() );
379 : }
380 : myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind );
381 : }
382 : }
383 :
384 : }
385 : }
|