427c5a82fc9acd9f65319432467d85089ed4edec
[ctsim.git] / tools / pjrec.cpp
1 /*****************************************************************************
2 ** FILE IDENTIFICATION
3 **
4 **   Name:          pjrec.cpp
5 **   Purpose:       Reconstruct an image from projections
6 **   Programmer:    Kevin Rosenberg
7 **   Date Started:  Aug 1984
8 **
9 **  This is part of the CTSim program
10 **  Copyright (C) 1983-2000 Kevin Rosenberg
11 **
12 **  $Id$
13 **
14 **  This program is free software; you can redistribute it and/or modify
15 **  it under the terms of the GNU General Public License (version 2) as
16 **  published by the Free Software Foundation.
17 **
18 **  This program is distributed in the hope that it will be useful,
19 **  but WITHOUT ANY WARRANTY; without even the implied warranty of
20 **  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
21 **  GNU General Public License for more details.
22 **
23 **  You should have received a copy of the GNU General Public License
24 **  along with this program; if not, write to the Free Software
25 **  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
26 ******************************************************************************/
27
28 #include "ct.h"
29 #include "timer.h"
30
31 enum {O_INTERP, O_FILTER, O_FILTER_METHOD, O_ZEROPAD, O_FILTER_PARAM, O_FILTER_GENERATION, O_BACKPROJ, O_PREINTERPOLATION_FACTOR, O_VERBOSE, O_TRACE, O_HELP, O_DEBUG, O_VERSION};
32
33 static struct option my_options[] =
34 {
35   {"interp", 1, 0, O_INTERP},
36   {"preinterpolation-factor", 1, 0, O_PREINTERPOLATION_FACTOR},
37   {"filter", 1, 0, O_FILTER},
38   {"filter-method", 1, 0, O_FILTER_METHOD},
39   {"zeropad", 1, 0, O_ZEROPAD},
40   {"filter-generation", 1, 0, O_FILTER_GENERATION},
41   {"filter-param", 1, 0, O_FILTER_PARAM},
42   {"backproj", 1, 0, O_BACKPROJ},
43   {"trace", 1, 0, O_TRACE},
44   {"debug", 0, 0, O_DEBUG},
45   {"verbose", 0, 0, O_VERBOSE},
46   {"help", 0, 0, O_HELP},
47   {"version", 0, 0, O_VERSION},
48   {0, 0, 0, 0}
49 };
50
51 static const char* g_szIdStr = "$Id$";
52
53 void
54 pjrec_usage (const char *program)
55 {
56   std::cout << "usage: " << fileBasename(program) << " raysum-file image-file nx-image ny-image [OPTIONS]" << std::endl;
57   std::cout << "Image reconstruction from raysum projections" << std::endl;
58   std::cout << std::endl;
59   std::cout << "  raysum-file     Input raysum file" << std::endl;
60   std::cout << "  image-file      Output image file in SDF2D format" << std::endl;
61   std::cout << "  nx-image        Number of columns in output image" << std::endl;
62   std::cout << "  ny-image        Number of rows in output image" << std::endl;
63   std::cout << "  --interp        Interpolation method during backprojection" << std::endl;
64   std::cout << "    nearest         Nearest neighbor interpolation" << std::endl;
65   std::cout << "    linear          Linear interpolation (default)" << std::endl;
66   std::cout << "    cubic           Cubic interpolation\n";
67 #if HAVE_BSPLINE_INTERP
68   std::cout << "    bspline         B-spline interpolation" << std::endl;
69 #endif
70   std::cout << "  --preinterpolate  Preinterpolation factor (default = 1)\n";
71   std::cout << "                    Used only with frequency-based filtering\n";
72   std::cout << "  --filter       Filter name" << std::endl;
73   std::cout << "    abs_bandlimit  Abs * Bandlimiting (default)" << std::endl;
74   std::cout << "    abs_sinc       Abs * Sinc" << std::endl;
75   std::cout << "    abs_cosine     Abs * Cosine" << std::endl;
76   std::cout << "    abs_hamming    Abs * Hamming" << std::endl;
77   std::cout << "    abs_hanning    Abs * Hanning" << std::endl;
78   std::cout << "    shepp          Shepp-Logan" << std::endl;
79   std::cout << "    bandlimit      Bandlimiting" << std::endl;
80   std::cout << "    sinc           Sinc" << std::endl;
81   std::cout << "    cosine         Cosine" << std::endl;
82   std::cout << "    triangle       Triangle" << std::endl;
83   std::cout << "    hamming        Hamming" << std::endl;
84   std::cout << "    hanning        Hanning" << std::endl;
85   std::cout << "  --filter-method  Filter method before backprojections\n";;
86   std::cout << "    convolution      Spatial filtering (default)\n";
87   std::cout << "    fourier          Frequency filtering with discete fourier\n";
88   std::cout << "    fourier_table    Frequency filtering with table lookup fourier\n";
89   std::cout << "    fft              Fast Fourier Transform\n";
90 #if HAVE_FFTW
91   std::cout << "    fftw             Fast Fourier Transform West library\n";
92   std::cout << "    rfftw            Fast Fourier Transform West (real-mode) library\n";
93 #endif
94   std::cout << "  --zeropad n   Set zeropad level (default = 0)\n";
95   std::cout << "                set n to number of powers to two to pad\n";
96   std::cout << "  --filter-generation  Filter Generation mode\n";
97   std::cout << "    direct       Use direct filter in spatial or frequency domain (default)\n";
98   std::cout << "    inverse_fourier  Use inverse fourier transform of inverse filter\n";
99   std::cout << "  --backproj    Backprojection Method" << std::endl;
100   std::cout << "    trig        Trigometric functions at every point" << std::endl;
101   std::cout << "    table       Trigometric functions with precalculated table" << std::endl;
102   std::cout << "    diff        Difference method" << std::endl;
103   std::cout << "    idiff       Difference method with integer math [default]" << std::endl;
104   std::cout << "  --filter-param Alpha level for Hamming filter" << std::endl;
105   std::cout << "  --trace        Set tracing to level" << std::endl;
106   std::cout << "     none        No tracing (default)" << std::endl;
107   std::cout << "     console     Text level tracing" << std::endl;
108   std::cout << "  --verbose      Turn on verbose mode" << std::endl;
109   std::cout << "  --debug        Turn on debug mode" << std::endl;
110   std::cout << "  --version      Print version" << std::endl;
111   std::cout << "  --help         Print this help message" << std::endl;
112 }
113
114
115 #ifdef HAVE_MPI
116 static void ScatterProjectionsMPI (MPIWorld& mpiWorld, Projections& projGlobal, Projections& projLocal, const bool bDebug);
117 static void ReduceImageMPI (MPIWorld& mpiWorld, ImageFile* imLocal, ImageFile* imGlobal);
118 #endif
119
120
121 int
122 pjrec_main (int argc, char * const argv[])
123 {
124   Projections projGlobal;
125   ImageFile* imGlobal = NULL;
126   char* pszFilenameProj = NULL;
127   char* pszFilenameImage = NULL;
128   std::string sRemark;
129   bool bOptVerbose = false;
130   bool bOptDebug = 1;
131   int iOptZeropad = 1;
132   int optTrace = Trace::TRACE_NONE;
133   double dOptFilterParam = -1;
134   std::string sOptFilterName (SignalFilter::convertFilterIDToName (SignalFilter::FILTER_ABS_BANDLIMIT));
135   std::string sOptFilterMethodName (ProcessSignal::convertFilterMethodIDToName (ProcessSignal::FILTER_METHOD_CONVOLUTION));
136   std::string sOptFilterGenerationName (ProcessSignal::convertFilterGenerationIDToName (ProcessSignal::FILTER_GENERATION_DIRECT));
137   std::string sOptInterpName (Backprojector::convertInterpIDToName (Backprojector::INTERP_LINEAR));
138   std::string sOptBackprojectName (Backprojector::convertBackprojectIDToName (Backprojector::BPROJ_IDIFF));
139   int iOptPreinterpolationFactor = 1;
140   int nx, ny;
141   char *endptr;
142 #ifdef HAVE_MPI
143   ImageFile* imLocal;
144   int mpi_nview, mpi_ndet;
145   double mpi_detinc, mpi_rotinc, mpi_phmlen;
146   MPIWorld mpiWorld (argc, argv);
147 #endif
148
149   Timer timerProgram;
150
151 #ifdef HAVE_MPI
152   if (mpiWorld.getRank() == 0) {
153 #endif
154     while (1) {
155       int c = getopt_long(argc, argv, "", my_options, NULL);
156       char *endptr = NULL;
157
158       if (c == -1)
159         break;
160
161       switch (c)
162         {
163         case O_INTERP:
164           sOptInterpName = optarg;
165           break;
166         case O_PREINTERPOLATION_FACTOR:
167           iOptPreinterpolationFactor = strtol(optarg, &endptr, 10);
168           if (endptr != optarg + strlen(optarg)) {
169             pjrec_usage(argv[0]);
170             return(1);
171           }
172           break;
173         case O_FILTER:
174           sOptFilterName = optarg;
175           break;
176         case O_FILTER_METHOD:
177           sOptFilterMethodName = optarg;
178           break;
179         case O_FILTER_GENERATION:
180           sOptFilterGenerationName = optarg;
181           break;
182         case O_FILTER_PARAM:
183           dOptFilterParam = strtod(optarg, &endptr);
184           if (endptr != optarg + strlen(optarg)) {
185             pjrec_usage(argv[0]);
186             return(1);
187           }
188           break;
189         case O_ZEROPAD:
190           iOptZeropad = strtol(optarg, &endptr, 10);
191           if (endptr != optarg + strlen(optarg)) {
192             pjrec_usage(argv[0]);
193             return(1);
194           }
195           break;
196         case O_BACKPROJ:
197           sOptBackprojectName = optarg;
198           break;
199         case O_VERBOSE:
200           bOptVerbose = true;
201           break;
202         case O_DEBUG:
203           bOptDebug = true;
204           break;
205         case O_TRACE:
206           if ((optTrace = Trace::convertTraceNameToID(optarg)) == Trace::TRACE_INVALID) {
207             pjrec_usage(argv[0]);
208             return (1);
209           }
210           break;
211         case O_VERSION:
212 #ifdef VERSION
213           std::cout <<  "Version " <<  VERSION << std::endl << g_szIdStr << std::endl;
214 #else
215           std::cout << "Unknown version number" << std::endl;
216 #endif
217           return (0);
218         case O_HELP:
219         case '?':
220           pjrec_usage(argv[0]);
221           return (0);
222         default:
223           pjrec_usage(argv[0]);
224           return (1);
225         }
226     }
227
228     if (optind + 4 != argc) {
229       pjrec_usage(argv[0]);
230       return (1);
231     }
232
233     pszFilenameProj = argv[optind];
234
235     pszFilenameImage = argv[optind + 1];
236
237     nx = strtol(argv[optind + 2], &endptr, 10);
238     ny = strtol(argv[optind + 3], &endptr, 10);
239
240     std::ostringstream filterDesc;
241     if (dOptFilterParam >= 0)
242       filterDesc << sOptFilterName << ": alpha=" << dOptFilterParam;
243     else
244       filterDesc << sOptFilterName;
245
246     std::ostringstream label;
247     label << "pjrec: " << nx << "x" << ny << ", " << filterDesc.str() << ", " << sOptInterpName << ", preinterpolationFactor=" << iOptPreinterpolationFactor << ", " << sOptBackprojectName;
248     sRemark = label.str();
249
250     if (bOptVerbose)
251       std::cout << "SRemark: " << sRemark << std::endl;
252 #ifdef HAVE_MPI
253   }
254 #endif
255
256 #ifdef HAVE_MPI
257   if (mpiWorld.getRank() == 0) {
258     projGlobal.read (pszFilenameProj);
259     if (bOptVerbose) {
260       ostringstream os;
261       projGlobal.printScanInfo (os);
262       std::cout << os.str();
263     }
264
265     mpi_ndet = projGlobal.nDet();
266     mpi_nview = projGlobal.nView();
267     mpi_detinc = projGlobal.detInc();
268     mpi_phmlen = projGlobal.phmLen();
269     mpi_rotinc = projGlobal.rotInc();
270   }
271
272   TimerCollectiveMPI timerBcast (mpiWorld.getComm());
273   mpiWorld.BcastString (sOptBackprojectName);
274   mpiWorld.BcastString (sOptFilterName);
275   mpiWorld.BcastString (sOptFilterMethodName);
276   mpiWorld.BcastString (sOptInterpName);
277   mpiWorld.getComm().Bcast (&bOptVerbose, 1, MPI::INT, 0);
278   mpiWorld.getComm().Bcast (&bOptDebug, 1, MPI::INT, 0);
279   mpiWorld.getComm().Bcast (&optTrace, 1, MPI::INT, 0);
280   mpiWorld.getComm().Bcast (&dOptFilterParam, 1, MPI::DOUBLE, 0);
281   mpiWorld.getComm().Bcast (&iOptZeropad, 1, MPI::INT, 0);
282   mpiWorld.getComm().Bcast (&iOptPreinterpolationFactor, 1, MPI::INT, 0);
283   mpiWorld.getComm().Bcast (&mpi_ndet, 1, MPI::INT, 0);
284   mpiWorld.getComm().Bcast (&mpi_nview, 1, MPI::INT, 0);
285   mpiWorld.getComm().Bcast (&mpi_detinc, 1, MPI::DOUBLE, 0);
286   mpiWorld.getComm().Bcast (&mpi_phmlen, 1, MPI::DOUBLE, 0);
287   mpiWorld.getComm().Bcast (&mpi_rotinc, 1, MPI::DOUBLE, 0);
288   mpiWorld.getComm().Bcast (&nx, 1, MPI::INT, 0);
289   mpiWorld.getComm().Bcast (&ny, 1, MPI::INT, 0);
290   if (bOptVerbose)
291       timerBcast.timerEndAndReport ("Time to broadcast variables");
292
293   mpiWorld.setTotalWorkUnits (mpi_nview);
294
295   Projections projLocal (mpiWorld.getMyLocalWorkUnits(), mpi_ndet);
296   projLocal.setDetInc (mpi_detinc);
297   projLocal.setPhmLen (mpi_phmlen);
298   projLocal.setRotInc (mpi_rotinc);
299
300   TimerCollectiveMPI timerScatter (mpiWorld.getComm());
301   ScatterProjectionsMPI (mpiWorld, projGlobal, projLocal, bOptDebug);
302   if (bOptVerbose)
303       timerScatter.timerEndAndReport ("Time to scatter projections");
304
305   if (mpiWorld.getRank() == 0) {
306     imGlobal = new ImageFile (nx, ny);
307   }
308
309   imLocal = new ImageFile (nx, ny);
310 #else
311
312   if (! projGlobal.read (pszFilenameProj)) {
313     fprintf(stderr, "Unable to read projectfile file %s\n", pszFilenameProj);
314     exit(1);
315   }
316
317   if (bOptVerbose) {
318     std::ostringstream os;
319     projGlobal.printScanInfo(os);
320     std::cout << os.str();
321   }
322
323   imGlobal = new ImageFile (nx, ny);
324 #endif
325
326 #ifdef HAVE_MPI
327   TimerCollectiveMPI timerReconstruct (mpiWorld.getComm());
328
329   Reconstructor reconstruct (projLocal, *imLocal, sOptFilterName.c_str(), dOptFilterParam, sOptFilterMethodName.c_str(), iOptZeropad, sOptFilterGenerationName.c_str(), sOptInterpName.c_str(), iOptPreinterpolationFactor, sOptBackprojectName.c_str(), optTrace);
330   if (reconstruct.fail()) {
331     std::cout << reconstruct.failMessage();
332     return (1);
333   }
334   reconstruct.reconstructAllViews();
335
336   if (bOptVerbose)
337       timerReconstruct.timerEndAndReport ("Time to reconstruct");
338
339   TimerCollectiveMPI timerReduce (mpiWorld.getComm());
340   ReduceImageMPI (mpiWorld, imLocal, imGlobal);
341   if (bOptVerbose)
342       timerReduce.timerEndAndReport ("Time to reduce image");
343 #else
344   Reconstructor reconstruct (projGlobal, *imGlobal, sOptFilterName.c_str(), dOptFilterParam, sOptFilterMethodName.c_str(), iOptZeropad, sOptFilterGenerationName.c_str(), sOptInterpName.c_str(), iOptPreinterpolationFactor, sOptBackprojectName.c_str(), optTrace);
345   if (reconstruct.fail()) {
346     std::cout << reconstruct.failMessage();
347     return (1);
348   }
349   reconstruct.reconstructAllViews();
350 #endif
351
352 #ifdef HAVE_MPI
353   if (mpiWorld.getRank() == 0)
354 #endif
355     {
356       double dCalcTime = timerProgram.timerEnd();
357       imGlobal->labelAdd (projGlobal.getLabel());
358       imGlobal->labelAdd (Array2dFileLabel::L_HISTORY, sRemark.c_str(), dCalcTime);
359       imGlobal->fileWrite (pszFilenameImage);
360       if (bOptVerbose)
361         std::cout << "Run time: " << dCalcTime << " seconds" << std::endl;
362     }
363 #ifdef HAVE_MPI
364   MPI::Finalize();
365 #endif
366
367   return (0);
368 }
369
370
371 //////////////////////////////////////////////////////////////////////////////////////
372 // MPI Support Routines
373 //
374 //////////////////////////////////////////////////////////////////////////////////////
375
376 #ifdef HAVE_MPI
377 static void ScatterProjectionsMPI (MPIWorld& mpiWorld, Projections& projGlobal, Projections& projLocal, const bool bOptDebug)
378 {
379   if (mpiWorld.getRank() == 0) {
380     for (int iProc = 0; iProc < mpiWorld.getNumProcessors(); iProc++) {
381       for (int iw = mpiWorld.getStartWorkUnit(iProc); iw <= mpiWorld.getEndWorkUnit(iProc); iw++) {
382         DetectorArray& detarray = projGlobal.getDetectorArray( iw );
383         int nDet = detarray.nDet();
384         DetectorValue* detval = detarray.detValues();
385
386         double viewAngle = detarray.viewAngle();
387         mpiWorld.getComm().Send(&nDet, 1, MPI::INT, iProc, 0);
388         mpiWorld.getComm().Send(&viewAngle, 1, MPI::DOUBLE, iProc, 0);
389         mpiWorld.getComm().Send(detval, nDet, MPI::FLOAT, iProc, 0);
390       }
391     }
392   }
393
394   for (int iw = 0; iw < mpiWorld.getMyLocalWorkUnits(); iw++) {
395     MPI::Status status;
396     int nDet;
397     double viewAngle;
398     DetectorValue* detval = projLocal.getDetectorArray(iw).detValues();
399
400     mpiWorld.getComm().Recv(&nDet, 1, MPI::INT, 0, 0, status);
401     mpiWorld.getComm().Recv(&viewAngle, 1, MPI::DOUBLE, 0, 0, status);
402     mpiWorld.getComm().Recv(detval, nDet, MPI::FLOAT, 0, 0, status);
403     projLocal.getDetectorArray(iw).setViewAngle( viewAngle );
404   }
405 }
406
407 static void
408 ReduceImageMPI (MPIWorld& mpiWorld, ImageFile* imLocal, ImageFile* imGlobal)
409 {
410   ImageFileArray vLocal = imLocal->getArray();
411
412   for (unsigned int ix = 0; ix < imLocal->nx(); ix++) {
413     void *recvbuf = NULL;
414     if (mpiWorld.getRank() == 0) {
415       ImageFileArray vGlobal = imGlobal->getArray();
416       recvbuf = vGlobal[ix];
417     }
418     mpiWorld.getComm().Reduce (vLocal[ix], recvbuf, imLocal->ny(), imLocal->getMPIDataType(), MPI::SUM, 0);
419   }
420 }
421
422 #endif
423
424
425 #ifndef NO_MAIN
426 int
427 main (int argc, char* argv[])
428 {
429   int retval = 1;
430
431   try {
432     retval = pjrec_main(argc, argv);
433   } catch (exception e) {
434           std::cerr << "Exception: " << e.what() << std::endl;
435   } catch (...) {
436           std::cerr << "Unknown exception" << std::endl;
437   }
438
439   return (retval);
440 }
441 #endif
442