00001 #ifndef __RVL_MPIFO
00002 #define __RVL_MPIFO
00003
00004 #ifdef IWAVE_USE_MPI
00005 #include "mpidatatransfer.hh"
00006 #endif
00007 #include "local.hh"
00008
00009 namespace RVL {
00010
/** Abstract two-phase interface used by the MPI wrappers in this header.

    Implementations supply both phases:
    - set():   (re)initialize the object's local state;
    - synch(): bring that state into agreement across processes.

    MPISerialFunctionObjectRedn is one implementation. Its set() resets
    the wrapped reduction, and its synch() distributes the root's value.
*/
class MPISynchRoot {
public:

  /// Virtual destructor, so implementations can be deleted through a base pointer.
  virtual ~MPISynchRoot() {}

  /// Phase one: prepare / reset local state (meaning defined by the subclass).
  virtual void set() = 0;

  /// Phase two: make state consistent across processes (meaning defined by the subclass).
  virtual void synch() = 0;
};
00025
00033 template<class DataType>
00034 class MPISerialFunctionObject: public LocalFunctionObject<DataType> {
00035
00036 private:
00037
00038 LocalEvaluation<DataType> * fptr;
00039 string fname;
00040 MPISerialFunctionObject();
00041
00042 public:
00043
00044 MPISerialFunctionObject(FunctionObject & f): fptr(NULL) {
00045 if (!(fptr=dynamic_cast<LocalEvaluation<DataType> *>(&f))) {
00046 RVLException e;
00047 e<<"Error: MPISerialFunctionObject constructor\n";
00048 e<<"input function, named "<<f.getName()<<" not LocalEvaluation\n";
00049 throw e;
00050 }
00051 fname=f.getName();
00052 }
00053 MPISerialFunctionObject(MPISerialFunctionObject<DataType> const & f)
00054 : fptr(f.fptr), fname(f.fname) {}
00055 ~MPISerialFunctionObject() {}
00056
00057 void operator()(LocalDataContainer<DataType> & target,
00058 vector<LocalDataContainer<DataType> const *> & sources) {
00059 try {
00060
00061 #ifdef IWAVE_USE_MPI
00062 int rk=0;
00063 MPI_Comm_rank(MPI_COMM_WORLD,&rk);
00064 if (rk==0)
00065 #endif
00066 fptr->operator()(target,sources);
00067 }
00068 catch (RVLException & e) {
00069 e<<"\ncalled from MPISerialFunctionObject::operator()\n";
00070 throw e;
00071 }
00072 }
00073
00074 string getName() const {
00075 string str="MPISerialFO wrapper around ";
00076 str+=fname;
00077 return str;
00078 }
00079 };
00080
00092 template<typename DataType, typename ValType>
00093 class MPISerialFunctionObjectRedn:
00094 public FunctionObjectScalarRedn<ValType>,
00095 public LocalConstEval<DataType>,
00096 public MPISynchRoot {
00097
00098 private:
00099
00100 LocalConstEval<DataType> * lrptr;
00101 FunctionObjectScalarRedn<ValType> & f;
00102 string fname;
00103 int root;
00104 int rk;
00105
00106 #ifdef IWAVE_USE_MPI
00107 MPI_Comm comm;
00108 #endif
00109
00110 MPISerialFunctionObjectRedn();
00111
00112 public:
00113
00114 MPISerialFunctionObjectRedn(FunctionObjectScalarRedn<ValType> & _f,
00115 int _root=0
00116 #ifdef IWAVE_USE_MPI
00117 , MPI_Comm _comm=MPI_COMM_WORLD
00118 #endif
00119 )
00120 : FunctionObjectScalarRedn<ValType>(_f.getValue()),
00121 lrptr(NULL), f(_f), root(_root), rk(_root)
00122 #ifdef IWAVE_USE_MPI
00123 , comm(_comm)
00124 #endif
00125 {
00126 #ifdef IWAVE_USE_MPI
00127 MPI_Comm_rank(comm,&rk);
00128 #endif
00129 if (!(lrptr=dynamic_cast<LocalConstEval<DataType> *>(&f))) {
00130 RVLException e;
00131 e<<"Error: MPISerialFunctionObjectRedn constructor\n";
00132 e<<"input function, named "<<f.getName()<<" not LocalReduction\n";
00133 throw e;
00134 }
00135 fname=f.getName();
00136 }
00137
00138 MPISerialFunctionObjectRedn(MPISerialFunctionObjectRedn<DataType,ValType> const & f)
00139 : lrptr(f.lrptr), f(f.f), root(f.root), rk(f.root),
00140 #ifdef IWAVE_USE_MPI
00141 comm(f.comm),
00142 #endif
00143 fname(f.fname) {
00144 #ifdef IWAVE_USE_MPI
00145 MPI_Comm_rank(comm,&rk);
00146 #endif
00147 }
00148
00149 ~MPISerialFunctionObjectRedn() {}
00150
00151 void operator()(vector<LocalDataContainer<DataType> const *> & sources) {
00152 try {
00153
00154 if (rk==root) {
00155
00156 lrptr->operator()(sources);
00157
00158 }
00159
00160
00161
00162
00163 ScalarRedn<ValType>::setValue(f.getValue());
00164 }
00165 catch (RVLException & e) {
00166 e<<"\ncalled from MPISerialFunctionObjectRedn::operator()\n";
00167 throw e;
00168 }
00169 }
00170
00171 void synch() {
00172 try {
00173
00174
00175 ValType a = ScalarRedn<ValType>::getValue();
00176 #ifdef IWAVE_USE_MPI
00177
00178 MPI_Broadcaster<ValType> bc(root,comm);
00179 bc(a);
00180 #endif
00181
00182 ScalarRedn<ValType>::setValue(a);
00183
00184 f.setValue(a);
00185
00186 }
00187 catch (RVLException & e) {
00188 e<<"\ncalled from MPISerialFunctionObjectRedn::synch\n";
00189 throw e;
00190 }
00191 }
00192
00195 void setValue() {
00196 f.setValue();
00197 ScalarRedn<ValType>::setValue(f.getValue());
00198 }
00199
00201 void set() {
00202 this->setValue();
00203 }
00204
00205 string getName() const {
00206 string str="MPISerialFOR wrapper around ";
00207 str+=fname;
00208 return str;
00209 }
00210 };
00211
00212 }
00213 #endif