00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037 #ifndef __RVLALG_UMIN_LSQR__
00038 #define __RVLALG_UMIN_LSQR__
00039
#include <cmath>

#include "alg.hh"
#include "terminator.hh"
#include "linop.hh"
#include "table.hh"
00044
00045 using namespace RVLAlg;
00046
00047 namespace RVLUmin {
00048
00049 using namespace RVL;
00050 using namespace RVLAlg;
00051
00076 template<typename Scalar>
00077 class LSQRStep : public Algorithm {
00078
00079 typedef typename ScalarFieldTraits<Scalar>::AbsType atype;
00080
00081 public:
00082
00083 LSQRStep(LinearOp<Scalar> const & _A,
00084 Vector<Scalar> & _x,
00085 Vector<Scalar> const & _b,
00086 atype & _rnorm,
00087 atype & _nrnorm)
00088 : A(_A), x(_x), b(_b), rnorm(_rnorm), nrnorm(_nrnorm), v(A.getDomain()), alphav(A.getDomain()),
00089 u(A.getRange()), betau(A.getRange()), w(A.getDomain()) {
00090
00091
00092 beta=b.norm();
00093 rnorm=beta;
00094 atype tmp;
00095 if (ProtectedDivision<atype>(ScalarFieldTraits<atype>::One(),beta,tmp)) {
00096 RVLException e;
00097 e<<"Error: LSQRStep constructor\n";
00098 e<<" RHS has vanishing norm\n";
00099 throw e;
00100 }
00101 u.scale(tmp,b);
00102 A.applyAdjOp(u,v);
00103 alpha = v.norm();
00104 nrnorm = alpha*rnorm;
00105 if (ProtectedDivision<atype>(ScalarFieldTraits<atype>::One(),alpha,tmp)) {
00106 RVLException e;
00107 e<<"Error: LSQRStep constructor\n";
00108 e<<" Normal residual has vanishing norm\n";
00109 throw e;
00110 }
00111 v.scale(tmp);
00112 w.copy(v);
00113 phibar = beta;
00114 rhobar = alpha;
00115 }
00116
00120 void run() {
00121 try {
00122
00123 A.applyOp(v,betau);
00124
00125 Scalar stmp = -alpha;
00126 betau.linComb(stmp,u);
00127
00128 beta=betau.norm();
00129 atype tmp;
00130 if (ProtectedDivision<atype>(ScalarFieldTraits<atype>::One(),beta,tmp)) {
00131 RVLException e;
00132 e<<"Error: LSQRStep::run\n";
00133 e<<" beta vanishes\n";
00134 throw e;
00135 }
00136
00137 stmp = tmp;
00138 u.scale(stmp,betau);
00139
00140
00141 A.applyAdjOp(u,alphav);
00142
00143 stmp = -beta;
00144 alphav.linComb(stmp,v);
00145
00146 alpha = alphav.norm();
00147 if (ProtectedDivision<atype>(ScalarFieldTraits<atype>::One(),alpha,tmp)) {
00148 RVLException e;
00149 e<<"Error: LSQRStep::run\n";
00150 e<<" beta vanishes\n";
00151 throw e;
00152 }
00153
00154 stmp=tmp;
00155 v.scale(tmp,alphav);
00156
00157
00158 atype rho = sqrt(rhobar*rhobar + beta*beta);
00159
00160 atype c = rhobar/rho;
00161
00162 atype s = beta/rho;
00163
00164 atype theta = s*alpha;
00165
00166 rhobar = - c*alpha;
00167
00168 atype phi = c*phibar;
00169
00170 phibar = s*phibar;
00171
00172
00173 x.linComb(phi/rho,w);
00174
00175 w.scale(-theta/rho);
00176 w.linComb(ScalarFieldTraits<Scalar>::One(),v);
00177
00178
00179 rnorm = phibar;
00180 nrnorm = phibar*alpha*abs(c);
00181
00182 }
00183 catch (RVLException & e) {
00184 e<<"\ncalled from CGNEStep::run()\n";
00185 throw e;
00186 }
00187
00188 }
00189
00190 ~LSQRStep() {}
00191
00192 private:
00193
00194
00195 LinearOp<Scalar> const & A;
00196 Vector<Scalar> & x;
00197 Vector<Scalar> const & b;
00198 atype & rnorm;
00199 atype & nrnorm;
00200
00201
00202 Vector<Scalar> u;
00203 Vector<Scalar> v;
00204 Vector<Scalar> betau;
00205 Vector<Scalar> alphav;
00206 Vector<Scalar> w;
00207 atype alpha;
00208 atype beta;
00209 atype rhobar;
00210 atype phibar;
00211 };
00212
00271 template<typename Scalar>
00272 class LSQRAlg: public Algorithm, public Terminator {
00273
00274 typedef typename ScalarFieldTraits<Scalar>::AbsType atype;
00275
00276 public:
00277
00317 LSQRAlg(RVL::Vector<Scalar> & _x,
00318 LinearOp<Scalar> const & _inA,
00319 Vector<Scalar> const & _rhs,
00320 atype & _rnorm,
00321 atype & _nrnorm,
00322 atype _rtol = 100.0*numeric_limits<atype>::epsilon(),
00323 atype _nrtol = 100.0*numeric_limits<atype>::epsilon(),
00324 int _maxcount = 10,
00325 atype _maxstep = numeric_limits<atype>::max(),
00326 ostream & _str = cout)
00327 : inA(_inA),
00328 x(_x),
00329 rhs(_rhs),
00330 rnorm(_rnorm),
00331 nrnorm(_nrnorm),
00332 rtol(_rtol),
00333 nrtol(_nrtol),
00334 maxstep(_maxstep),
00335 maxcount(_maxcount),
00336 count(0),
00337 proj(false),
00338 str(_str),
00339 step(inA,x,rhs,rnorm,nrnorm)
00340 { x.zero(); }
00341
00342 ~LSQRAlg() {}
00343
00344 bool query() { return proj; }
00345
00346 void run() {
00347
00348 vector<string> names(2);
00349 vector<atype *> nums(2);
00350 vector<atype> tols(2);
00351 names[0]="Residual Norm"; nums[0]=&rnorm; tols[0]=rtol;
00352 names[1]="Gradient Norm"; nums[1]=&nrnorm; tols[1]=nrtol;
00353 str<<"========================== BEGIN LSQR =========================\n";
00354 VectorCountingThresholdIterationTable<atype> stop1(maxcount,names,nums,tols,str);
00355 stop1.init();
00356
00357
00358 BallProjTerminator<Scalar> stop2(x,maxstep,str);
00359
00360 OrTerminator stop(stop1,stop2);
00361
00362 LoopAlg doit(step,stop);
00363 doit.run();
00364
00365 proj = stop2.query();
00366 if (proj) {
00367 Vector<Scalar> temp(inA.getRange());
00368 inA.applyOp(x,temp);
00369 temp.linComb(-1.0,rhs);
00370 rnorm=temp.norm();
00371 Vector<Scalar> temp1(inA.getDomain());
00372 inA.applyAdjOp(temp,temp1);
00373 nrnorm=temp1.norm();
00374 }
00375 count = stop1.getCount();
00376 str<<"=========================== END LSQR ==========================\n";
00377 }
00378
00379 int getCount() const { return count; }
00380
00381 private:
00382
00383 LinearOp<Scalar> const & inA;
00384 Vector<Scalar> & x;
00385 Vector<Scalar> const & rhs;
00386 atype & rnorm;
00387 atype & nrnorm;
00388 atype rtol;
00389 atype nrtol;
00390 atype maxstep;
00391 int maxcount;
00392 int count;
00393 mutable bool proj;
00394 ostream & str;
00395 LSQRStep<Scalar> step;
00396
00397
00398 LSQRAlg();
00399 LSQRAlg(LSQRAlg<Scalar> const &);
00400
00401 };
00402
00405 template<typename Scalar>
00406 class LSQRPolicyData {
00407
00408 typedef typename ScalarFieldTraits<Scalar>::AbsType atype;
00409
00410 public:
00411
00412 atype rtol;
00413 atype nrtol;
00414 atype Delta;
00415 int maxcount;
00416 bool verbose;
00417
00418 LSQRPolicyData(atype _rtol = numeric_limits<atype>::max(),
00419 atype _nrtol = numeric_limits<atype>::max(),
00420 atype _Delta = numeric_limits<atype>::max(),
00421 int _maxcount = 0,
00422 bool _verbose = false)
00423 : rtol(_rtol), nrtol(_nrtol), Delta(_Delta), maxcount(_maxcount), verbose(_verbose) {}
00424
00425 LSQRPolicyData(LSQRPolicyData<Scalar> const & a)
00426 : rtol(a.rtol), nrtol(a.nrtol), Delta(a.Delta), maxcount(a.maxcount), verbose(a.verbose) {}
00427
00428 ostream & write(ostream & str) const {
00429 str<<"\n";
00430 str<<"==============================================\n";
00431 str<<"LSQRPolicyData: \n";
00432 str<<"rtol = "<<rtol<<"\n";
00433 str<<"nrtol = "<<nrtol<<"\n";
00434 str<<"Delta = "<<Delta<<"\n";
00435 str<<"maxcount = "<<maxcount<<"\n";
00436 str<<"verbose = "<<verbose<<"\n";
00437 str<<"==============================================\n";
00438 return str;
00439 }
00440 };
00441
00464 template<typename Scalar>
00465 class LSQRPolicy {
00466
00467 typedef typename ScalarFieldTraits<Scalar>::AbsType atype;
00468
00469 public:
00489 LSQRAlg<Scalar> * build(Vector<Scalar> & x,
00490 LinearOp<Scalar> const & A,
00491 Vector<Scalar> const & d,
00492 atype & rnorm,
00493 atype & nrnorm,
00494 ostream & str) const {
00495 if (verbose)
00496 return new LSQRAlg<Scalar>(x,A,d,rnorm,nrnorm,rtol,nrtol,maxcount,Delta,str);
00497 else
00498 return new LSQRAlg<Scalar>(x,A,d,rnorm,nrnorm,rtol,nrtol,maxcount,Delta,nullstr);
00499 }
00500
00506 void assign(atype _rtol, atype _nrtol, atype _Delta, int _maxcount, bool _verbose) {
00507 rtol=_rtol; nrtol=_nrtol; Delta=_Delta; maxcount=_maxcount; verbose=_verbose;
00508 }
00509
00511 void assign(Table const & t) {
00512 rtol=getValueFromTable<atype>(t,"LSQR_ResTol");
00513 nrtol=getValueFromTable<atype>(t,"LSQR_GradTol");
00514 Delta=getValueFromTable<atype>(t,"TR_Delta");
00515 maxcount=getValueFromTable<int>(t,"LSQR_MaxItn");
00516 verbose=getValueFromTable<bool>(t,"LSQR_Verbose");
00517 }
00518
00520 void assign(LSQRPolicyData<Scalar> const & s) {
00521 rtol=s.rtol;
00522 nrtol=s.nrtol;
00523 Delta=s.Delta;
00524 maxcount=s.maxcount;
00525 verbose=s.verbose;
00526 }
00527
00532 mutable atype Delta;
00533
00542 LSQRPolicy(atype _rtol = numeric_limits<atype>::max(),
00543 atype _nrtol = numeric_limits<atype>::max(),
00544 atype _Delta = numeric_limits<atype>::max(),
00545 int _maxcount = 0,
00546 bool _verbose = true)
00547 : Delta(_Delta), rtol(_rtol), nrtol(_nrtol), maxcount(_maxcount), verbose(_verbose), nullstr(0) {}
00548
00549 LSQRPolicy(LSQRPolicy<Scalar> const & p)
00550 : Delta(p.Delta),
00551 rtol(p.rtol),
00552 nrtol(p.nrtol),
00553 maxcount(p.maxcount),
00554 verbose(p.verbose),
00555 nullstr(0) {}
00556
00557 private:
00558 mutable atype rtol;
00559 mutable atype nrtol;
00560 mutable int maxcount;
00561 mutable bool verbose;
00562 mutable std::ostream nullstr;
00563 };
00564 }
00565
00566 #endif