Skip to content
Snippets Groups Projects
Commit 794a2377 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

application working (monoband...), commit before aefactory template creation

parent 64af665b
No related branches found
No related tags found
No related merge requests found
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include <iostream> /*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
#include "otbImage.h" This software is distributed WITHOUT ANY WARRANTY; without even
#include "otbVectorImage.h" the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
#include <shark/Models/Autoencoder.h>//normal autoencoder model =========================================================================*/
#include <shark/Models/TiedAutoencoder.h>//autoencoder with tied weights #include "otbWrapperApplication.h"
#include <shark/Models/Normalizer.h> #include "otbWrapperApplicationFactory.h"
#include "encode_filter.h"
#include "otbMultiChannelExtractROI.h" #include "itkUnaryFunctorImageFilter.h"
#include "otbChangeLabelImageFilter.h"
#include "otbStandardWriterWatcher.h"
#include "otbStatisticsXMLFileReader.h"
#include "otbShiftScaleVectorImageFilter.h"
#include "otbImageClassificationFilter.h"
#include "otbMultiToMonoChannelExtractROI.h"
#include "otbImageToVectorImageCastFilter.h"
#include "otbMachineLearningModelFactory.h"
namespace otb namespace otb
{ {
namespace Functor
{
/**
* simple affine function : y = ax+b
*/
template<class TInput, class TOutput>
class AffineFunctor
{
public:
typedef double InternalType;
// constructor
AffineFunctor() : m_A(1.0),m_B(0.0) {}
// destructor
virtual ~AffineFunctor() {}
void SetA(InternalType a)
{
m_A = a;
}
void SetB(InternalType b)
{
m_B = b;
}
inline TOutput operator()(const TInput & x) const
{
return static_cast<TOutput>( static_cast<InternalType>(x)*m_A + m_B);
}
private:
InternalType m_A;
InternalType m_B;
};
}
namespace Wrapper namespace Wrapper
{ {
class CbDimensionalityReduction : public otb::Wrapper::Application class CbDimensionalityReduction : public Application
{ {
public: public:
/** Standard class typedefs. */ /** Standard class typedefs. */
typedef CbDimensionalityReduction Self; typedef CbDimensionalityReduction Self;
typedef Application Superclass; typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer; typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
using image_type = FloatVectorImageType; itkNewMacro(Self);
typedef shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron> AutoencoderType; itkTypeMacro(CbDimensionalityReduction, otb::Application);
using FilterType = EncodeFilter<image_type, AutoencoderType, shark::Normalizer<shark::RealVector>> ;
typedef otb::MultiChannelExtractROI<FloatVectorImageType::InternalPixelType, FloatVectorImageType::InternalPixelType> ExtractROIFilterType; /** Filters typedef */
/** Standard macro */ typedef UInt8ImageType MaskImageType;
itkNewMacro(Self); typedef itk::VariableLengthVector<FloatVectorImageType::InternalPixelType> MeasurementType;
itkTypeMacro(CbDimensionalityReduction, otb::Application); typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader;
typedef otb::ShiftScaleVectorImageFilter<FloatVectorImageType, FloatVectorImageType> RescalerType;
typedef itk::UnaryFunctorImageFilter<
FloatImageType,
FloatImageType,
otb::Functor::AffineFunctor<float,float> > OutputRescalerType;
typedef otb::ImageClassificationFilter<FloatVectorImageType, FloatImageType, MaskImageType> ClassificationFilterType;
typedef ClassificationFilterType::Pointer ClassificationFilterPointerType;
typedef ClassificationFilterType::ModelType ModelType;
typedef ModelType::Pointer ModelPointerType;
typedef ClassificationFilterType::ValueType ValueType;
typedef ClassificationFilterType::LabelType LabelType;
typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType;
protected:
~CbDimensionalityReduction() ITK_OVERRIDE
{
MachineLearningModelFactoryType::CleanFactories();
}
private: private:
void DoInit() ITK_OVERRIDE
{
SetName("PredictRegression");
SetDescription("Performs a prediction of the input image according to a regression model file.");
// Documentation
SetDocName("Predict Regression");
SetDocLongDescription("This application predict output values from an input"
" image, based on a regression model file produced by"
" the TrainRegression application. Pixels of the "
"output image will contain the predicted values from"
"the regression model (single band). The input pixels"
" can be optionally centered and reduced according "
"to the statistics file produced by the "
"ComputeImagesStatistics application. An optional "
"input mask can be provided, in which case only "
"input image pixels whose corresponding mask value "
"is greater than 0 will be processed. The remaining"
" of pixels will be given the value 0 in the output"
" image.");
SetDocLimitations("The input image must contain the feature bands used for"
" the model training (without the predicted value). "
"If a statistics file was used during training by the "
"TrainRegression, it is mandatory to use the same "
"statistics file for prediction. If an input mask is "
"used, its size must match the input image size.");
SetDocAuthors("OTB-Team");
SetDocSeeAlso("TrainRegression, ComputeImagesStatistics");
AddDocTag(Tags::Learning);
AddParameter(ParameterType_InputImage, "in", "Input Image");
SetParameterDescription( "in", "The input image to predict.");
// TODO : use CSV input/output ?
AddParameter(ParameterType_InputImage, "mask", "Input Mask");
SetParameterDescription( "mask", "The mask allow restricting "
"classification of the input image to the area where mask pixel values "
"are greater than 0.");
MandatoryOff("mask");
AddParameter(ParameterType_InputFilename, "model", "Model file");
SetParameterDescription("model", "A regression model file (produced by "
"TrainRegression application).");
AddParameter(ParameterType_InputFilename, "imstat", "Statistics file");
SetParameterDescription("imstat", "A XML file containing mean and standard"
" deviation to center and reduce samples before prediction "
"(produced by ComputeImagesStatistics application). If this file contains"
"one more band than the sample size, the last stat of last band will be"
"applied to expand the output predicted value");
MandatoryOff("imstat");
AddParameter(ParameterType_OutputImage, "out", "Output Image");
SetParameterDescription( "out", "Output image containing predicted values");
AddRAMParameter();
// Doc example parameter settings
SetDocExampleParameterValue("in", "QB_1_ortho.tif");
SetDocExampleParameterValue("imstat", "EstimateImageStatisticsQB1.xml");
SetDocExampleParameterValue("model", "clsvmModelQB1.svm");
SetDocExampleParameterValue("out", "clLabeledImageQB1.tif");
}
void DoUpdateParameters() ITK_OVERRIDE
{
// Nothing to do here : all parameters are independent
}
void DoExecute() ITK_OVERRIDE
{
// Load input image
FloatVectorImageType::Pointer inImage = GetParameterImage("in");
inImage->UpdateOutputInformation();
unsigned int nbFeatures = inImage->GetNumberOfComponentsPerPixel();
// Load svm model
otbAppLogINFO("Loading model");
m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"),
MachineLearningModelFactoryType::ReadMode);
otbAppLogINFO("yo");
if (m_Model.IsNull())
{
otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
}
m_Model->Load(GetParameterString("model"));
m_Model->SetRegressionMode(true);
otbAppLogINFO("Model loaded");
// Classify
m_ClassificationFilter = ClassificationFilterType::New();
m_ClassificationFilter->SetModel(m_Model);
FloatImageType::Pointer outputImage = m_ClassificationFilter->GetOutput();
// Normalize input image if asked
if(IsParameterEnabled("imstat") )
{
otbAppLogINFO("Input image normalization activated.");
// Normalize input image (optional)
StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
MeasurementType meanMeasurementVector;
MeasurementType stddevMeasurementVector;
m_Rescaler = RescalerType::New();
// Load input image statistics
statisticsReader->SetFileName(GetParameterString("imstat"));
meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
otbAppLogINFO( "mean used: " << meanMeasurementVector );
otbAppLogINFO( "standard deviation used: " << stddevMeasurementVector );
if (meanMeasurementVector.Size() == nbFeatures + 1)
{
double outMean = meanMeasurementVector[nbFeatures];
double outStdDev = stddevMeasurementVector[nbFeatures];
meanMeasurementVector.SetSize(nbFeatures,false);
stddevMeasurementVector.SetSize(nbFeatures,false);
m_OutRescaler = OutputRescalerType::New();
m_OutRescaler->SetInput(m_ClassificationFilter->GetOutput());
m_OutRescaler->GetFunctor().SetA(outStdDev);
m_OutRescaler->GetFunctor().SetB(outMean);
outputImage = m_OutRescaler->GetOutput();
}
else if (meanMeasurementVector.Size() != nbFeatures)
{
otbAppLogFATAL("Wrong number of components in statistics file : "<<meanMeasurementVector.Size());
}
// Rescale vector image
m_Rescaler->SetScale(stddevMeasurementVector);
m_Rescaler->SetShift(meanMeasurementVector);
m_Rescaler->SetInput(inImage);
m_ClassificationFilter->SetInput(m_Rescaler->GetOutput());
}
else
{
otbAppLogINFO("Input image normalization deactivated.");
m_ClassificationFilter->SetInput(inImage);
}
if(IsParameterEnabled("mask"))
{
otbAppLogINFO("Using input mask");
// Load mask image and cast into LabeledImageType
MaskImageType::Pointer inMask = GetParameterUInt8Image("mask");
m_ClassificationFilter->SetInputMask(inMask);
}
SetParameterOutputImage<FloatImageType>("out", outputImage);
}
ClassificationFilterType::Pointer m_ClassificationFilter;
ModelPointerType m_Model;
RescalerType::Pointer m_Rescaler;
OutputRescalerType::Pointer m_OutRescaler;
void DoInit()
{
SetName("CbDimensionalityReduction");
SetDescription("Perform dimensionality reduction on the input image");
AddParameter(ParameterType_InputImage, "in", "Input Image");
SetParameterDescription( "in", "The input image to perform dimensionality reduction on.");
AddParameter(ParameterType_InputFilename, "model", "Model file");
SetParameterDescription("model", "A model file (produced by the cbDimensionalityReductionTrainer application).");
AddParameter(ParameterType_InputFilename, "normalizer", "Normalizer model file");
SetParameterDescription("normalizer", "A normalizer model file (produced by the cbDimensionalityReductionTrainer application).");
AddParameter(ParameterType_OutputImage, "out", "Output Image");
SetParameterDescription("out", "Output image");
AddRAMParameter();
}
void DoUpdateParameters()
{
}
void DoExecute()
{
std::cout << "Appli" << std::endl;
image_type::Pointer inImage = GetParameterImage("in");
std::string encoderPath = GetParameterString("model");
std::string normalizerPath = GetParameterString("normalizer");
filter_dim_reduc = FilterType::New();
filter_dim_reduc->SetAutoencoderModel(encoderPath);
filter_dim_reduc->SetNormalizerModel(normalizerPath);
filter_dim_reduc->SetInput(inImage);
SetParameterOutputImage("out", filter_dim_reduc->GetOutput());
/*
m_ExtractROIFilter = ExtractROIFilterType::New();
m_ExtractROIFilter->SetInput(filter_dim_reduc->GetOutput());
for (unsigned int idx = 1; idx <= filter_dim_reduc->GetDimension(); ++idx)
{
m_ExtractROIFilter->SetChannel(idx );
}
SetParameterOutputImage("out", m_ExtractROIFilter->GetOutput());
*/
//SetParameterOutputImage("out", inImage); // copy input image
}
FilterType::Pointer filter_dim_reduc;
ExtractROIFilterType::Pointer m_ExtractROIFilter;
//d
}; };
} }
} }
OTB_APPLICATION_EXPORT(otb::Wrapper::CbDimensionalityReduction) OTB_APPLICATION_EXPORT(otb::Wrapper::CbDimensionalityReduction)
...@@ -79,10 +79,6 @@ public: ...@@ -79,10 +79,6 @@ public:
typedef otb::MachineLearningModelFactory<ValueType, ValueType> ModelFactoryType; typedef otb::MachineLearningModelFactory<ValueType, ValueType> ModelFactoryType;
typedef shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron> AutoencoderType;
typedef AutoencoderModel<ValueType,AutoencoderType> AutoencoderModelType;
typedef RandomForestsMachineLearningModel<ValueType,int> rfModelType;
private: private:
void DoInit() void DoInit()
{ {
...@@ -182,7 +178,6 @@ private: ...@@ -182,7 +178,6 @@ private:
this->Train(trainingListSample,GetParameterString("io.out")); this->Train(trainingListSample,GetParameterString("io.out"));
// d
} }
......
...@@ -52,7 +52,7 @@ public: ...@@ -52,7 +52,7 @@ public:
protected: protected:
AutoencoderModel(); AutoencoderModel();
//~AutoencoderModel() ITK_OVERRIDE; ~AutoencoderModel() ITK_OVERRIDE;
virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=ITK_NULLPTR) const ITK_OVERRIDE; virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=ITK_NULLPTR) const ITK_OVERRIDE;
virtual void DoPredictBatch(const InputListSampleType *, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType *, ConfidenceListSampleType * = ITK_NULLPTR) const ITK_OVERRIDE; virtual void DoPredictBatch(const InputListSampleType *, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType *, ConfidenceListSampleType * = ITK_NULLPTR) const ITK_OVERRIDE;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <fstream> #include <fstream>
#include <shark/Data/Dataset.h> #include <shark/Data/Dataset.h>
#include "itkMacro.h"
#include "otbSharkUtils.h" #include "otbSharkUtils.h"
//include train function //include train function
#include <shark/ObjectiveFunctions/ErrorFunction.h> #include <shark/ObjectiveFunctions/ErrorFunction.h>
...@@ -18,10 +19,15 @@ namespace otb ...@@ -18,10 +19,15 @@ namespace otb
template <class TInputValue, class AutoencoderType> template <class TInputValue, class AutoencoderType>
AutoencoderModel<TInputValue,AutoencoderType>::AutoencoderModel() AutoencoderModel<TInputValue,AutoencoderType>::AutoencoderModel()
{ {
//this->m_IsRegressionSupported = true; this->m_IsRegressionSupported = true;
} }
template <class TInputValue, class AutoencoderType>
AutoencoderModel<TInputValue,AutoencoderType>::~AutoencoderModel()
{
}
template <class TInputValue, class AutoencoderType> template <class TInputValue, class AutoencoderType>
void AutoencoderModel<TInputValue,AutoencoderType>::Train() void AutoencoderModel<TInputValue,AutoencoderType>::Train()
...@@ -53,6 +59,8 @@ void AutoencoderModel<TInputValue,AutoencoderType>::Train() ...@@ -53,6 +59,8 @@ void AutoencoderModel<TInputValue,AutoencoderType>::Train()
} }
//std::cout<<optimizer.solution().value<<std::endl; //std::cout<<optimizer.solution().value<<std::endl;
m_net.setParameterVector(optimizer.solution().point); m_net.setParameterVector(optimizer.solution().point);
} }
...@@ -82,6 +90,7 @@ template <class TInputValue, class AutoencoderType> ...@@ -82,6 +90,7 @@ template <class TInputValue, class AutoencoderType>
void AutoencoderModel<TInputValue,AutoencoderType>::Save(const std::string & filename, const std::string & name) void AutoencoderModel<TInputValue,AutoencoderType>::Save(const std::string & filename, const std::string & name)
{ {
std::ofstream ofs(filename); std::ofstream ofs(filename);
ofs << m_net.name() << std::endl; //first line
boost::archive::polymorphic_text_oarchive oa(ofs); boost::archive::polymorphic_text_oarchive oa(ofs);
m_net.write(oa); m_net.write(oa);
ofs.close(); ofs.close();
...@@ -91,6 +100,13 @@ template <class TInputValue, class AutoencoderType> ...@@ -91,6 +100,13 @@ template <class TInputValue, class AutoencoderType>
void AutoencoderModel<TInputValue,AutoencoderType>::Load(const std::string & filename, const std::string & name) void AutoencoderModel<TInputValue,AutoencoderType>::Load(const std::string & filename, const std::string & name)
{ {
std::ifstream ifs(filename); std::ifstream ifs(filename);
char autoencoder[256];
ifs.getline(autoencoder,256);
std::string autoencoderstr(autoencoder);
if (autoencoderstr != m_net.name()){
itkExceptionMacro(<< "Error opening " << filename.c_str() );
}
boost::archive::polymorphic_text_iarchive ia(ifs); boost::archive::polymorphic_text_iarchive ia(ifs);
m_net.read(ia); m_net.read(ia);
ifs.close(); ifs.close();
...@@ -101,10 +117,49 @@ void AutoencoderModel<TInputValue,AutoencoderType>::Load(const std::string & fil ...@@ -101,10 +117,49 @@ void AutoencoderModel<TInputValue,AutoencoderType>::Load(const std::string & fil
template <class TInputValue, class AutoencoderType> template <class TInputValue, class AutoencoderType>
typename AutoencoderModel<TInputValue,AutoencoderType>::TargetSampleType typename AutoencoderModel<TInputValue,AutoencoderType>::TargetSampleType
AutoencoderModel<TInputValue,AutoencoderType>::DoPredict(const InputSampleType & value, ConfidenceValueType *quality) const AutoencoderModel<TInputValue,AutoencoderType>::DoPredict(const InputSampleType & value, ConfidenceValueType *quality) const
{
shark::RealVector samples(value.Size());
for(size_t i = 0; i < value.Size();i++)
{
samples.push_back(value[i]);
}
shark::Data<shark::RealVector> data;
data.element(0)=samples;
data = m_net.encode(data);
TargetSampleType target;
//target.SetSize(m_NumberOfHiddenNeurons);
for(unsigned int a = 0; a < m_NumberOfHiddenNeurons; ++a){
//target[a]=data.element(0)[a];
target=data.element(0)[a];
}
return target;
}
template <class TInputValue, class AutoencoderType>
void AutoencoderModel<TInputValue,AutoencoderType>
::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * quality) const
{ {
std::vector<shark::RealVector> features;
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
shark::Data<shark::RealVector> data = shark::createDataFromRange(features);
TargetSampleType target; TargetSampleType target;
return target; data = m_net.encode(data);
unsigned int id = startIndex;
for(const auto& p : data.elements()){
for(unsigned int a = 0; a < m_NumberOfHiddenNeurons; ++a){
//target[a]=p[a];
target=p[a];
}
//std::cout << p << std::endl;
targets->SetMeasurementVector(id,target);
++id;
}
} }
} // namespace otb } // namespace otb
#endif #endif
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "itkObjectFactoryBase.h" #include "itkObjectFactoryBase.h"
#include "itkImageIOBase.h"
namespace otb namespace otb
{ {
...@@ -30,8 +31,8 @@ public: ...@@ -30,8 +31,8 @@ public:
/** Register one factory of this type */ /** Register one factory of this type */
static void RegisterOneFactory(void) static void RegisterOneFactory(void)
{ {
Pointer RFFactory = AutoencoderModelFactory::New(); Pointer AEFactory = AutoencoderModelFactory::New();
itk::ObjectFactoryBase::RegisterFactory(RFFactory); itk::ObjectFactoryBase::RegisterFactory(AEFactory);
} }
protected: protected:
...@@ -45,6 +46,10 @@ private: ...@@ -45,6 +46,10 @@ private:
}; };
} //namespace otb } //namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "AutoencoderModelFactory.txx"
#endif
#endif #endif
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "AutoencoderModel.h" #include "AutoencoderModel.h"
#include "itkVersion.h" #include "itkVersion.h"
#include <shark/Models/Autoencoder.h>//normal autoencoder model
namespace otb namespace otb
{ {
template <class TInputValue, class TOutputValue> template <class TInputValue, class TOutputValue>
......
...@@ -117,16 +117,7 @@ private: ...@@ -117,16 +117,7 @@ private:
#ifdef OTB_USE_SHARK #ifdef OTB_USE_SHARK
void InitAutoencoderParams(); void InitAutoencoderParams();
template <class autoencoderchoice> template <class autoencoderchoice>
void TrainAutoencoder(typename ListSampleType::Pointer trainingListSample, std::string modelPath);/*{ void TrainAutoencoder(typename ListSampleType::Pointer trainingListSample, std::string modelPath);
// typename AutoencoderModelType::Pointer dimredTrainer = AutoencoderModelType::New();
typename autoencoderchoice::Pointer dimredTrainer = autoencoderchoice::New();
dimredTrainer->SetNumberOfHiddenNeurons(GetParameterInt("model.autoencoder.nbneuron"));
dimredTrainer->SetNumberOfIterations(GetParameterInt("model.autoencoder.nbiter"));
dimredTrainer->SetRegularization(GetParameterFloat("model.autoencoder.normalizer"));
dimredTrainer->SetInputListSample(trainingListSample);
dimredTrainer->Train();
dimredTrainer->Save(modelPath);
}; // !!!!!!!!!!!!!!!!! How to declare this method body in the .txx ? (double template...) */
#endif #endif
//@} //@}
}; };
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment