Skip to content
Snippets Groups Projects
Commit f624445f authored by Cédric Traizet's avatar Cédric Traizet
Browse files

pca factory done

parent ae3191a2
No related branches found
No related tags found
1 merge request!4Dimensionality reduction algorithms
......@@ -28,6 +28,8 @@
#include "otbImageToVectorImageCastFilter.h"
#include "DimensionalityReductionModelFactory.h"
#include "PCAModel.h"
namespace otb
{
namespace Functor
......
......@@ -24,6 +24,7 @@
#ifdef OTB_USE_SHARK
#include "AutoencoderModelFactory.h"
#include "PCAModelFactory.h"
#endif
#include "itkMutexLockHolder.h"
......@@ -87,6 +88,7 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue>
#ifdef OTB_USE_SHARK
RegisterFactory(PCAModelFactory<TInputValue,TOutputValue>::New());
RegisterFactory(AutoencoderModelFactory<TInputValue,TOutputValue>::New());
RegisterFactory(TiedAutoencoderModelFactory<TInputValue,TOutputValue>::New());
#endif
......@@ -137,6 +139,14 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue>
itk::ObjectFactoryBase::UnRegisterFactory(taeFactory);
continue;
}
PCAModelFactory<TInputValue,TOutputValue> *pcaFactory =
dynamic_cast<PCAModelFactory<TInputValue,TOutputValue> *>(*itFac);
if (pcaFactory)
{
itk::ObjectFactoryBase::UnRegisterFactory(pcaFactory);
continue;
}
#endif
}
......
......@@ -31,7 +31,7 @@ public:
itkTypeMacro(AutoencoderModel, DimensionalityReductionModel);
unsigned int GetDimension() {return m_Dimension;};
itkGetMacro(Dimension,unsigned int);
//itkGetMacro(Dimension,unsigned int);
bool CanReadFile(const std::string & filename);
bool CanWriteFile(const std::string & filename);
......@@ -51,9 +51,9 @@ protected:
virtual void DoPredictBatch(const InputListSampleType *, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType *, ConfidenceListSampleType * = ITK_NULLPTR) const ITK_OVERRIDE;
private:
LinearModel<> m_encoder
LinearModel<> m_decoder
PCA m_pca;
shark::LinearModel<> m_encoder;
shark::LinearModel<> m_decoder;
shark::PCA m_pca;
unsigned int m_Dimension;
};
......
......@@ -52,7 +52,7 @@ bool PCAModel<TInputValue>::CanReadFile(const std::string & filename)
try
{
this->Load(filename);
m_net.name();
m_encoder.name();
}
catch(...)
{
......@@ -72,9 +72,9 @@ template <class TInputValue>
void PCAModel<TInputValue>::Save(const std::string & filename, const std::string & name)
{
std::ofstream ofs(filename);
ofs << m_net.name() << std::endl; //first line
ofs << m_encoder.name() << std::endl; //first line
boost::archive::polymorphic_text_oarchive oa(ofs);
m_net.write(oa);
m_encoder.write(oa);
ofs.close();
}
......@@ -82,17 +82,17 @@ template <class TInputValue>
void PCAModel<TInputValue>::Load(const std::string & filename, const std::string & name)
{
std::ifstream ifs(filename);
char autoencoder[256];
ifs.getline(autoencoder,256);
std::string autoencoderstr(autoencoder);
char encoder[256];
ifs.getline(encoder,256);
std::string encoderstr(encoder);
if (autoencoderstr != m_net.name()){
if (autoencoderstr != m_encoder.name()){
itkExceptionMacro(<< "Error opening " << filename.c_str() );
}
boost::archive::polymorphic_text_iarchive ia(ifs);
m_net.read(ia);
ifs.close();
m_NumberOfHiddenNeurons = m_net.numberOfHiddenNeurons();
m_Dimension = m_encoder.outputSize();
//this->m_Size = m_NumberOfHiddenNeurons;
}
......@@ -108,14 +108,18 @@ PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueT
}
shark::Data<shark::RealVector> data;
data.element(0)=samples;
data = m_net.encode(data);
data = m_encoder(data);
TargetSampleType target;
//target.SetSize(m_NumberOfHiddenNeurons);
for(unsigned int a = 0; a < m_NumberOfHiddenNeurons; ++a){
//target[a]=data.element(0)[a];
target=data.element(0)[a];
target[a]=p[a];
//target.SetElement(a,p[a]);
}
return target;
}
......@@ -130,7 +134,7 @@ void PCAModel<TInputValue>
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
shark::Data<shark::RealVector> data = shark::createDataFromRange(features);
TargetSampleType target;
data = m_net.encode(data);
data = m_encoder(data);
unsigned int id = startIndex;
target.SetSize(m_NumberOfHiddenNeurons);
for(const auto& p : data.elements()){
......
#ifndef PCAModelFactory_h
#define PCAModelFactory_h
#include "itkObjectFactoryBase.h"
#include "itkImageIOBase.h"
namespace otb
{
template <class TInputValue, class TTargetValue>
class ITK_EXPORT PCAModelFactory : public itk::ObjectFactoryBase
{
public:
/** Standard class typedefs. */
typedef PCAModelFactory Self;
typedef itk::ObjectFactoryBase Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Class methods used to interface with the registered factories. */
const char* GetITKSourceVersion(void) const ITK_OVERRIDE;
const char* GetDescription(void) const ITK_OVERRIDE;
/** Method for class instantiation. */
itkFactorylessNewMacro(Self);
/** Run-time type information (and related methods). */
itkTypeMacro(PCAModelFactory, itk::ObjectFactoryBase);
/** Register one factory of this type */
static void RegisterOneFactory(void)
{
Pointer PCAFactory = PCAModelFactory::New();
itk::ObjectFactoryBase::RegisterFactory(PCAFactory);
}
protected:
PCAModelFactory();
~PCAModelFactory() ITK_OVERRIDE;
private:
PCAModelFactory(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
};
} //namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "PCAModelFactory.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef AutoencoderModelFactory_txx
#define AutoencoderModelFactory_txx
#include "AutoencoderModelFactory.h"
#include "itkCreateObjectFunction.h"
#include "AutoencoderModel.h"
#include <shark/Algorithms/Trainers/PCA.h>
#include "itkVersion.h"
namespace otb
{
template <class TInputValue, class TOutputValue>
PCAModelFactory<TInputValue,TOutputValue>::PCAModelFactory()
{
std::string classOverride = std::string("DimensionalityReductionModel");
std::string subclass = std::string("PCAModel");
this->RegisterOverride(classOverride.c_str(),
subclass.c_str(),
"Shark PCA ML Model",
1,
// itk::CreateObjectFunction<AutoencoderModel<TInputValue,TOutputValue> >::New());
itk::CreateObjectFunction<PCAModel<TInputValue,shark::LinearModel<> > >::New());
}
template <class TInputValue, class TOutputValue>
PCAModelFactory<TInputValue,TOutputValue>::~PCAModelFactory()
{
}
template <class TInputValue, class TOutputValue>
const char* PCAModelFactory<TInputValue,TOutputValue>::GetITKSourceVersion(void) const
{
return ITK_SOURCE_VERSION;
}
template <class TInputValue, class TOutputValue>
const char* PCAModelFactory<TInputValue,TOutputValue>::GetDescription() const
{
return "PCA model factory";
}
} // end namespace otb
#endif
......@@ -16,6 +16,7 @@
#ifdef OTB_USE_SHARK
#include "AutoencoderModel.h"
#include "PCAModel.h"
#endif
namespace otb
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment