DimensionalityReductionModelFactory.txx 7.41 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*=========================================================================

  Program:   ORFEO Toolbox
  Language:  C++
  Date:      $Date$
  Version:   $Revision$


  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
  See OTBCopyright.txt for details.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef DimensionalityReductionModelFactory_txx
#define DimensionalityReductionModelFactory_txx

#include "DimensionalityReductionModelFactory.h"
#include "otbConfigure.h"

24
#include "SOMModelFactory.h"
25

26
27
#ifdef OTB_USE_SHARK
#include "AutoencoderModelFactory.h"
Cédric Traizet's avatar
Cédric Traizet committed
28
#include "PCAModelFactory.h"
29
#endif
30

31
32
33
34
35
#include "itkMutexLockHolder.h"


namespace otb
{
/*
 * Disabled alternatives, kept for reference: they built the autoencoder
 * factories on shark::Autoencoder / shark::TiedAutoencoder with explicit
 * neuron types.
 *
 * template <class TInputValue, class TTargetValue>
 * // using AutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder<shark::TanhNeuron, shark::LinearNeuron>>  ;
 * using AutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder<shark::TanhNeuron, shark::TanhNeuron>>  ;
 *
 * template <class TInputValue, class TTargetValue>
 * // using TiedAutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>>  ;
 * using TiedAutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::TanhNeuron>>  ;
 */

#ifdef OTB_USE_SHARK
/** Autoencoder model factory, parameterized with shark::LogisticNeuron.
 *
 * Guarded by OTB_USE_SHARK because AutoencoderModelFactoryBase and the
 * shark:: types come from AutoencoderModelFactory.h, which is only included
 * when Shark support is enabled. Every use of this alias in this file is
 * inside the same guard.
 */
template <class TInputValue, class TTargetValue>
using AutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::LogisticNeuron>;
#endif

/** Self-organizing map model factories, one per supported map dimension
 * (the third SOMModelFactory template argument: 2, 3, 4 or 5). */
template <class TInputValue, class TTargetValue>
using SOM2DModelFactory = SOMModelFactory<TInputValue, TTargetValue, 2>;

template <class TInputValue, class TTargetValue>
using SOM3DModelFactory = SOMModelFactory<TInputValue, TTargetValue, 3>;

template <class TInputValue, class TTargetValue>
using SOM4DModelFactory = SOMModelFactory<TInputValue, TTargetValue, 4>;

template <class TInputValue, class TTargetValue>
using SOM5DModelFactory = SOMModelFactory<TInputValue, TTargetValue, 5>;

template <class TInputValue, class TOutputValue>
66
typename MachineLearningModel<itk::VariableLengthVector< TInputValue> , itk::VariableLengthVector< TOutputValue>>::Pointer
67
68
69
70
71
72
73
74
DimensionalityReductionModelFactory<TInputValue,TOutputValue>
::CreateDimensionalityReductionModel(const std::string& path, FileModeType mode)
{
  RegisterBuiltInFactories();

  std::list<DimensionalityReductionModelTypePointer> possibleDimensionalityReductionModel;
  std::list<LightObject::Pointer> allobjects =
    itk::ObjectFactoryBase::CreateAllInstance("DimensionalityReductionModel");
75
 
76
77
78
  for(std::list<LightObject::Pointer>::iterator i = allobjects.begin();
      i != allobjects.end(); ++i)
    {
79
    MachineLearningModel<itk::VariableLengthVector< TInputValue> , itk::VariableLengthVector< TOutputValue>> * io = dynamic_cast<MachineLearningModel<itk::VariableLengthVector< TInputValue> , itk::VariableLengthVector< TOutputValue>>*>(i->GetPointer());
80
81
82
83
84
85
    if(io)
      {
      possibleDimensionalityReductionModel.push_back(io);
      }
    else
      {
86
	
87
88
89
90
91
      std::cerr << "Error DimensionalityReductionModel Factory did not return an DimensionalityReductionModel: "
                << (*i)->GetNameOfClass()
                << std::endl;
      }
    }
92
  
93
94
95
96
97
for(typename std::list<DimensionalityReductionModelTypePointer>::iterator k = possibleDimensionalityReductionModel.begin();
      k != possibleDimensionalityReductionModel.end(); ++k)
    {
      if( mode == ReadMode )
      {
98
		
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
      if((*k)->CanReadFile(path))
        {
        return *k;
        }
      }
    else if( mode == WriteMode )
      {
      if((*k)->CanWriteFile(path))
        {
        return *k;
        }

      }
    }
  return ITK_NULLPTR;
}

template <class TInputValue, class TOutputValue>
void
DimensionalityReductionModelFactory<TInputValue,TOutputValue>
::RegisterBuiltInFactories()
{
  // Registers every built-in model factory with ITK. The lock is the same
  // `mutex` taken by CleanFactories(), so the two do not interleave.
  itk::MutexLockHolder<itk::SimpleMutexLock> guard(mutex);

  // Self-organizing maps, 2D through 5D. The registration order is kept
  // unchanged because it may influence the order in which
  // CreateDimensionalityReductionModel() tries candidates.
  RegisterFactory(SOM2DModelFactory<TInputValue,TOutputValue>::New());
  RegisterFactory(SOM3DModelFactory<TInputValue,TOutputValue>::New());
  RegisterFactory(SOM4DModelFactory<TInputValue,TOutputValue>::New());
  RegisterFactory(SOM5DModelFactory<TInputValue,TOutputValue>::New());

#ifdef OTB_USE_SHARK
  // Shark-backed models.
  RegisterFactory(PCAModelFactory<TInputValue,TOutputValue>::New());
  RegisterFactory(AutoencoderModelFactory<TInputValue,TOutputValue>::New());
  // Tied autoencoder factory is currently disabled:
  // RegisterFactory(TiedAutoencoderModelFactory<TInputValue,TOutputValue>::New());
#endif
}

template <class TInputValue, class TOutputValue>
void
DimensionalityReductionModelFactory<TInputValue,TOutputValue>
::RegisterFactory(itk::ObjectFactoryBase * factory)
{
  typedef itk::ObjectFactoryBase FactoryBaseType;

  // Unregister, then register `factory` with ITK. The stated intent is to drop
  // any previously registered factory of the same class first; a static
  // "already registered" flag was avoided because of a linker (ld) error.
  // NOTE(review): UnRegisterFactory() receives this exact pointer, so it can
  // only undo an earlier registration of the same instance. Confirm whether
  // ITK's RegisterFactory() prevents duplicates of the same factory class.
  FactoryBaseType::UnRegisterFactory(factory);
  FactoryBaseType::RegisterFactory(factory);
}

template <class TInputValue, class TOutputValue>
void
DimensionalityReductionModelFactory<TInputValue,TOutputValue>
::CleanFactories()
{
  // Unregisters from ITK every factory registered by RegisterBuiltInFactories().
  // Takes the same `mutex` as RegisterBuiltInFactories().
  itk::MutexLockHolder<itk::SimpleMutexLock> guard(mutex);

  // Walk a local copy of the registry so unregistering does not disturb
  // the sequence being iterated.
  const std::list<itk::ObjectFactoryBase*> registered =
    itk::ObjectFactoryBase::GetRegisteredFactories();

  for (itk::ObjectFactoryBase * candidate : registered)
    {
    // Self-organizing map factories.
    bool isBuiltIn =
         dynamic_cast<SOM5DModelFactory<TInputValue,TOutputValue> *>(candidate) != ITK_NULLPTR
      || dynamic_cast<SOM4DModelFactory<TInputValue,TOutputValue> *>(candidate) != ITK_NULLPTR
      || dynamic_cast<SOM3DModelFactory<TInputValue,TOutputValue> *>(candidate) != ITK_NULLPTR
      || dynamic_cast<SOM2DModelFactory<TInputValue,TOutputValue> *>(candidate) != ITK_NULLPTR;

#ifdef OTB_USE_SHARK
    // Shark-backed factories (the tied autoencoder factory is disabled).
    isBuiltIn = isBuiltIn
      || dynamic_cast<AutoencoderModelFactory<TInputValue,TOutputValue> *>(candidate) != ITK_NULLPTR
      || dynamic_cast<PCAModelFactory<TInputValue,TOutputValue> *>(candidate) != ITK_NULLPTR;
#endif

    if (isBuiltIn)
      {
      itk::ObjectFactoryBase::UnRegisterFactory(candidate);
      }
    }
}

} // end namespace otb

#endif