00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036 #ifndef __ALG_UMINSTEP_H_
00037 #define __ALG_UMINSTEP_H_
00038
00039 #include "lnsrch.hh"
00040 #include "alg.hh"
00041 #include "functional.hh"
00042
00043 namespace RVLUmin {
00044
00045 using namespace RVLAlg;
00046 using RVL::Functional;
00047 using RVL::Vector;
00048 using RVL::Vector;
00049 using RVL::FunctionalEvaluation;
00050
00056 template<typename Scalar>
00057 class UMinDir: public Terminator {
00058 public:
00059 UMinDir() {}
00060 UMinDir(UMinDir<Scalar> const &) {}
00061 virtual ~UMinDir() {}
00062
00063
00064
00066 virtual void calcDir(Vector<Scalar> & dir,
00067 FunctionalEvaluation<Scalar> & fx) = 0;
00068
00070 virtual void updateDir(LineSearchAlg<Scalar> const & ls) = 0;
00071
00073 virtual void resetDir() = 0;
00074
00076 virtual ostream & write(ostream & str) const = 0;
00077 };
00078
00098 template<class Scalar>
00099 class UMinStepLS: public Algorithm, public Terminator {
00100 private:
00101
00102 FunctionalEvaluation<Scalar> & fx;
00103 UMinDir<Scalar> & dc;
00104 LineSearchAlg<Scalar> & ls;
00105 bool ans;
00106 ostream & str;
00107
00117 virtual void calcStep(Vector<Scalar> & dir) {
00118 try {
00119
00120 bool tried_steepest_descent = false;
00121
00122 ls.initialize(fx,dir);
00123 ls.run();
00124 bool res = ls.query();
00125 if (!res) {
00126 dc.updateDir(ls);
00127 ans = ans || dc.query();
00128 }
00129
00130 else {
00131 if (!tried_steepest_descent) {
00132 str<<"UMinStep: attempting steepest descent restart\n";
00133
00134 Vector<Scalar> & x = fx.getPoint();
00135 x.copy(ls.getBasePoint());
00136 dir.copy(ls.getBaseGradient());
00137 dir.negate();
00138 dc.resetDir();
00139
00140 ls.initialize(fx,dir);
00141 ls.run();
00142 if (!ls.query()) {
00143 dc.updateDir(ls);
00144 ans= ans || dc.query();
00145 }
00146 else {
00147 ans = true;
00148 }
00149 tried_steepest_descent=true;
00150 }
00151 else {
00152 ans = true;
00153 }
00154 }
00155
00156 } catch(RVLException & e) {
00157 e << "\ncalled from UMinStepLS::calcStep()\n";
00158 throw e;
00159 }
00160 }
00161
00162 public:
00163
00167 UMinStepLS(FunctionalEvaluation<Scalar> & _fx,
00168 UMinDir<Scalar> & _dc,
00169 LineSearchAlg<Scalar> & _ls,
00170 ostream & _str = cout)
00171 : fx(_fx), dc(_dc), ls(_ls), ans(false), str(_str) {}
00172
00173 UMinStepLS(const UMinStepLS<Scalar> & cos)
00174 : fx(cos.fx), dc(cos.dc), ls(cos.ls), ans(cos.ans), str(cos.str) {}
00175
00176 virtual ~UMinStepLS() {}
00177
00179 Vector<Scalar> const & getBasePoint() { return ls.getBasePoint(); }
00180
00182 Vector<Scalar> const & getBaseGradient() { return ls.getBaseGradient(); }
00183
00185 FunctionalEvaluation<Scalar> & getFunctionalEvaluation() { return fx; }
00186
00187 void run() {
00188 try {
00189 Vector<Scalar> dir(fx.getDomain(), true);
00190 dc.calcDir(dir,fx);
00191 calcStep(dir);
00192 }
00193 catch(RVLException & e) {
00194 e << "called from UMinStepLS::run()\n";
00195 throw e;
00196 }
00197 catch( std::exception & e) {
00198 RVLException es;
00199 es << "Exception caught in UMinStepLS::run() with error message";
00200 es << e.what();
00201 throw e;
00202 }
00203 }
00204
00205 bool query() { return ans; }
00206
00207 };
00208
00209 }
00210
00211
00212 #endif