00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037 #ifndef __RVL_CGALG
00038 #define __RVL_CGALG
00039
#include <complex>

#include "alg.hh"
#include "terminator.hh"
#include "linop.hh"
00043
00044 namespace RVLUmin {
00045
00046 using namespace RVL;
00047 using namespace RVLAlg;
00048
/** Strict "greater than" test for scalar types that provide operator>.

    @param left   left-hand operand
    @param right  right-hand operand
    @return true exactly when left > right, false otherwise
*/
template<typename Scalar>
bool realgt(Scalar left, Scalar right) {
  return left > right;
}
00056
/** Strict "greater than" test for complex scalars, comparing real parts
    only; imaginary parts are ignored.

    The standard-library names are spelled with explicit std::
    qualification so this overload does not depend on a using-directive
    leaking in from another header.

    @param left   left-hand operand
    @param right  right-hand operand
    @return true exactly when real(left) > real(right), false otherwise
*/
template<typename Scalar>
bool realgt(std::complex<Scalar> left, std::complex<Scalar> right) {
  return std::real(left) > std::real(right);
}
00062
00064 class CGException: public RVLException {
00065 public:
00066 CGException(): RVLException() {
00067 (*this)<<"Error: RVLUmin::CGAlg::run\n";
00068 }
00069 CGException(CGException const & s): RVLException(s) {}
00070 ~CGException() throw() {}
00071 };
00072
00081 template<typename Scalar>
00082 class CGStep: public Algorithm {
00083
00084 typedef typename ScalarFieldTraits<Scalar>::AbsType atype;
00085
00086 public:
00087
00092 void run() {
00093
00094 Scalar alpha, beta;
00095
00096 A.applyOp(p,w);
00097 curv = p.inner(w);
00098
00099 if (!realgt(curv,ScalarFieldTraits<Scalar>::Zero()) ) {
00100 CGException e;
00101 e<<"RVLUmin::CGStep::run: termination\n";
00102 e<<" negative curvature (p^TAp) in search direction p = "<<curv<<"\n";
00103 throw e;
00104 }
00105
00106 if (ProtectedDivision<Scalar>(rnormsq,curv,alpha)) {
00107 CGException e;
00108 e<<"RVLUmin::CGStep::run: termination\n";
00109 e<<" curvature much smaller than mean square residual, yields zerodivide\n";
00110 e<<" curvature = "<<curv<<"\n";
00111 e<<" mean square residual = "<<rnormsq<<"\n";
00112 throw e;
00113 }
00114
00115 x.linComb(alpha, p);
00116 r.linComb(-alpha, w);
00117 beta = rnormsq;
00118 rnormsq = r.normsq();
00119
00120 if (ProtectedDivision<Scalar>(rnormsq,beta,beta)) {
00121 CGException e;
00122 e<<"RVLUmin::CGStep::run: termination\n";
00123 e<<" previous square residual much smaller than current, yields zerodivide\n";
00124 e<<" previous square residual = "<<beta<<"\n";
00125 e<<" current square residual = "<<rnormsq<<"\n";
00126 throw e;
00127 }
00128
00129 p.linComb(ScalarFieldTraits<Scalar>::One(), r, beta);
00130 }
00131
00147 CGStep( Vector<Scalar> & x0, LinearOp<Scalar> const & inA,
00148 Vector<Scalar> const & rhs, atype & _rnormsq)
00149 : x(x0), A(inA), b(rhs), r(A.getRange()),
00150 curv(numeric_limits<Scalar>::max()),
00151 rnormsq(_rnormsq),
00152 p(A.getRange()),
00153 w(A.getRange()) {
00154 CGStep<Scalar>::restart();
00155 }
00156
00157 protected:
00158
00163 void restart() {
00164 A.applyOp(x, w);
00165 r.copy(b);
00166 r.linComb(-1.0, w);
00167 rnormsq = r.normsq();
00168 p.copy(r);
00169 }
00170
00171 Vector<Scalar> & x;
00172 const LinearOp<Scalar> & A;
00173 Vector<Scalar> const & b;
00174 Vector<Scalar> r;
00175 Scalar curv;
00176 atype & rnormsq;
00177 Vector<Scalar> p;
00178 Vector<Scalar> w;
00179 };
00180
00206 template<typename Scalar>
00207 class CGAlg: public Algorithm {
00208
00209 typedef typename ScalarFieldTraits<Scalar>::AbsType atype;
00210
00211 public:
00212
00241 CGAlg(RVL::Vector<Scalar> & _x,
00242 LinearOp<Scalar> const & _inA,
00243 Vector<Scalar> const & _rhs,
00244 atype & _rnormsq,
00245 atype _tol = 100.0*numeric_limits<atype>::epsilon(),
00246 int _maxcount = 10,
00247 atype _maxstep = numeric_limits<atype>::max(),
00248 ostream & _str = cout)
00249 : x(_x),
00250 resname("MS residual"),
00251 inA(_inA),
00252 rhs(_rhs),
00253 tol(_tol),
00254 maxcount(_maxcount),
00255 maxstep(_maxstep),
00256 str(_str),
00257 rnormsq(_rnormsq),
00258 step(x,inA,rhs,rnormsq),
00259 it(0) {}
00260
00261 int getCount() { return it; }
00262 void run() {
00263 try {
00264
00265 CountingThresholdIterationTable<Scalar> stop1(maxcount,rnormsq,tol,resname,str);
00266
00267 BallProjTerminator<Scalar> stop2(x,maxstep,str);
00268
00269 OrTerminator stop(stop1,stop2);
00270
00271 LoopAlg doit(step,stop);
00272 doit.run();
00273
00274
00275 if (stop2.query()) {
00276 str<<"CGAlg::run: scale step to trust region boundary\n";
00277 Vector<Scalar> temp(inA.getDomain());
00278 inA.applyOp(x,temp);
00279 temp.linComb(-1.0,rhs);
00280 rnormsq=temp.normsq();
00281 }
00282
00283 it=stop1.getCount();
00284 }
00285 catch (CGException & e) {
00286 throw e;
00287 }
00288 catch (RVLException & e) {
00289 e<<"Error: CGAlg::run\n";
00290 throw e;
00291 }
00292 }
00293
00294 private:
00295
00296 LinearOp<Scalar> const & inA;
00297 Vector<Scalar> & x;
00298 Vector<Scalar> const & rhs;
00299 atype tol;
00300 atype maxstep;
00301 int maxcount;
00302 string resname;
00303 ostream & str;
00304 int it;
00305 atype & rnormsq;
00306 CGStep<Scalar> step;
00307
00308 };
00309
00310 }
00311
00312 #endif