00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #ifndef __RVL_RNMAT
00034 #define __RVL_RNMAT
00035
00036 #include "op.hh"
00037 #include "local.hh"
00038 #include "rnspace.hh"
00039
00040 namespace RVL {
00041
00045 template<typename T>
00046 class matvec: public BinaryLocalFunctionObject<T> {
00047
00048 private:
00049
00050 int rows;
00051 int cols;
00052 mutable bool adj;
00053 RnArray<T> mat;
00054 matvec();
00055
00056 public:
00057
00058 matvec(int _rows, int _cols): rows(_rows), cols(_cols), adj(false), mat(rows*cols) {
00059 for (int i=0;i<rows*cols; i++) mat.getData()[i]=ScalarFieldTraits<T>::Zero();
00060 }
00061 matvec(matvec<T> const * m): rows(m->rows), cols(m->cols), adj(m->adj), mat(m->mat) {}
00062 ~matvec() {}
00063
00065 T * getData() { return mat.getData(); }
00066 T const * getData() const { return mat.getData(); }
00067
00069 void eval(UnaryLocalFunctionObject<T> & f) { f(mat); }
00070
00072 T & getElement(int i, int j) {
00073 if (i<0 || i>rows-1) {
00074 RVLException e;
00075 e<<"Error: matvec::getElement\n";
00076 e<<"row index "<<i<<" out of range [0,"<<rows-1<<"]\n";
00077 throw e;
00078 }
00079 if (j<0 || j>cols-1) {
00080 RVLException e;
00081 e<<"Error: matvec::getElement\n";
00082 e<<"col index "<<j<<" out of range [0,"<<cols-1<<"]\n";
00083 throw e;
00084 }
00085 return mat.getData()[i+rows*j];
00086 }
00087
00089 T const & getElement(int i, int j) const {
00090 if (i<0 || i>rows-1) {
00091 RVLException e;
00092 e<<"Error: matvec::getElement\n";
00093 e<<"row index "<<i<<" out of range [0,"<<rows-1<<"]\n";
00094 throw e;
00095 }
00096 if (j<0 || j>cols-1) {
00097 RVLException e;
00098 e<<"Error: matvec::getElement\n";
00099 e<<"col index "<<j<<" out of range [0,"<<cols-1<<"]\n";
00100 throw e;
00101 }
00102 return mat.getData()[i+rows*j];
00103 }
00104
00106 void setAdj(bool flag) const { adj=flag; }
00108 bool getAdj() const { return adj; }
00109
00115 using RVL::BinaryLocalEvaluation<T>::operator();
00116 void operator()(LocalDataContainer<T> & y, LocalDataContainer<T> const & x) {
00117
00118 if (adj) {
00119
00120 if (y.getSize() < (size_t(cols)) || x.getSize() < size_t(rows)) {
00121 RVLException e;
00122 e<<"Error: matvec::operator(), adjoint\n";
00123 e<<"either input or output too short\n";
00124 throw e;
00125 }
00126 for (int j=0;j<cols;j++) {
00127 y.getData()[j]=ScalarFieldTraits<T>::Zero();
00128 for (int i=0;i<rows;i++)
00129 y.getData()[j]+= RVL::conj(mat.getData()[i+j*rows])*x.getData()[i];
00130 }
00131
00132 }
00133 else {
00134
00135 if (y.getSize() < size_t(rows) || x.getSize() < size_t(cols)) {
00136 RVLException e;
00137 e<<"Error: matvec::operator()\n";
00138 e<<"either input or output too short\n";
00139 e<<"input size = "<<x.getSize()<<" should be at least "<<cols<<"\n";
00140 e<<"output size = "<<y.getSize()<<" should be at least "<<rows<<"\n";
00141 throw e;
00142 }
00143 for (int i=0;i<rows;i++) {
00144 y.getData()[i]=ScalarFieldTraits<T>::Zero();
00145 for (int j=0;j<cols;j++)
00146 y.getData()[i]+=mat.getData()[i+j*rows]*x.getData()[j];
00147 }
00148 }
00149 }
00150
00151 string getName() const { string tmp = "matvec"; return tmp; }
00152
00153 };
00154
00156 template<typename T>
00157 class fmatvec: public BinaryLocalFunctionObject<T> {
00158
00159 private:
00160 matvec<T> & m;
00161 fmatvec();
00162
00163 public:
00164 fmatvec( matvec<T> & _m ) : m(_m) {}
00165 fmatvec( fmatvec<T> const & f) : m(f.m) {}
00166 ~fmatvec() {}
00167
00168 using RVL::BinaryLocalEvaluation<T>::operator();
00169 void operator()(LocalDataContainer<T> & y,
00170 LocalDataContainer<T> const & x) {
00171 m.setAdj(false);
00172 m(y,x);
00173 }
00174 string getName() const { string tmp = "fmatvec"; return tmp; }
00175 };
00176
00178 template<typename T>
00179 class amatvec: public BinaryLocalFunctionObject<T> {
00180
00181 private:
00182 matvec<T> & m;
00183 amatvec();
00184
00185 public:
00186 amatvec( matvec<T> & _m ) : m(_m) {}
00187 amatvec( amatvec<T> const & f) : m(f.m) {}
00188 ~amatvec() {}
00189
00190 using RVL::BinaryLocalEvaluation<T>::operator();
00191 void operator()(LocalDataContainer<T> & y,
00192 LocalDataContainer<T> const & x) {
00193 m.setAdj(true);
00194 m(y,x);
00195 }
00196 string getName() const { string tmp = "amatvec"; return tmp; }
00197 };
00198
00205 template<typename T>
00206 class GenMat: public LinearOp<T> {
00207
00208 private:
00209
00210 RnSpace<T> dom;
00211 RnSpace<T> rng;
00212 mutable matvec<T> a;
00213 GenMat();
00214
00215 protected:
00216
00217 virtual LinearOp<T> * clone() const { return new GenMat(*this); }
00218
00219 void apply(Vector<T> const & x,
00220 Vector<T> & y) const {
00221 try {
00222 a.setAdj(false);
00223 y.eval(a,x);
00224 }
00225 catch (RVLException & e) {
00226 e<<"\ncalled from GenMat::apply\n";
00227 throw e;
00228 }
00229 }
00230
00231 void applyAdj(Vector<T> const & x,
00232 Vector<T> & y) const {
00233 try {
00234 a.setAdj(true);
00235 y.eval(a,x);
00236 }
00237 catch (RVLException & e) {
00238 e<<"\ncalled from GenMat::applyAdj\n";
00239 throw e;
00240 }
00241 }
00242
00243 public:
00244
00246 GenMat(RnSpace<T> const & _dom, RnSpace<T> const & _rng)
00247 : dom(_dom), rng(_rng), a(rng.getSize(),dom.getSize()) {}
00248
00250 GenMat(GenMat<T> const & m)
00251 : dom(m.dom), rng(m.rng), a(m.a) {}
00252
00253 ~GenMat() {}
00254
00255 Space<T> const & getDomain() const { return dom; }
00256 Space<T> const & getRange() const { return rng; }
00257
00259 int getNRows() const { return int(rng.getSize()); }
00260 int getNCols() const { return int(dom.getSize()); }
00261
00263 T * getData() { return a.getData(); }
00264 T const * getData() const { return a.getData(); }
00265
00267
00268 T const & getElement(int i, int j) const {
00269 try { return a.getElement(i,j); }
00270 catch (RVLException e) {
00271 e<<"\ncalled from GenMat::getElement\n"; throw e;
00272 }
00273 }
00274
00275
00276 T & getElement(int i, int j) {
00277 try { return a.getElement(i,j); }
00278 catch (RVLException e) {e<<"\ncalled from GenMat::getElement\n"; throw e; }
00279 }
00280
00281
00282 virtual void setElement(int i, int j, T e) {
00283 try { a.getElement(i,j)=e; }
00284 catch (RVLException e) {e<<"\ncalled from GenMat::setElement\n"; throw e; }
00285 }
00286
00287
00288 void eval(UnaryLocalFunctionObject<T> & f) { a.eval(f); }
00289
00290 matvec<T> const & getMatVec() const { return a; }
00291 matvec<T> & getMatVec() { return a; }
00292
00293 virtual ostream & write(ostream & str) const {
00294 str<<"GenMat: simple general matrix class\n";
00295 str<<"based on matvec FunctionObject\n";
00296 str<<"rows = "<<rng.getSize()<<" cols = "<<dom.getSize()<<endl;
00297 return str;
00298 }
00299
00300 };
00301
00304 template<typename T>
00305 class SymMat: public GenMat<T> {
00306
00307 protected:
00308
00309 LinearOp<T> * clone() const { return new SymMat<T>(*this); }
00310
00311 public:
00312
00313 SymMat(RnSpace<T> const & dom)
00314 : GenMat<T>(dom,dom) {}
00315
00316 SymMat(SymMat<T> const & m)
00317 : GenMat<T>(m) {}
00318
00319 ~SymMat() {}
00320
00321 void setElement(int i, int j, T e) {
00322 try { GenMat<T>::setElement(i,j,e); GenMat<T>::getElement(j,i,e); }
00323 catch (RVLException e) {e<<"\ncalled from GenMat::setElement\n"; throw e; }
00324 }
00325
00326 ostream & write(ostream & str) const {
00327 str<<"SymMat: simple symmetric matrix class\n";
00328 str<<"based on matvec FunctionObject\n";
00329 RnSpace<T> const & dom = dynamic_cast<RnSpace<T > const &>(this->getDomain());
00330 str<<"rows = cols = "<<dom.getSize()<<endl;
00331 return str;
00332 }
00333 };
00334
00335 }
00336
00337 #endif