Commit da152f58 authored by Guillaume Pasero's avatar Guillaume Pasero

REFAC: ObjectDetection now requires OTBSupervised and OTBLibSVM for tests

parent 77506785
......@@ -32,7 +32,7 @@
#include "itkFunctionBase.h"
#include "otbVectorData.h"
#include "otbSVMModel.h"
#include "otbMachineLearningModel.h"
#include "otbPersistentImageFilter.h"
#include "otbPersistentFilterStreamingDecorator.h"
......@@ -54,7 +54,7 @@ public:
* plus the ThreadedGenerateData function implementing the image function evaluation
*
*
* \ingroup OTBSVMLearning
* \ingroup OTBObjectDetection
*/
template <class TInputImage, class TOutputVectorData, class TLabel, class TFunctionType>
class ITK_EXPORT PersistentObjectDetectionClassifier :
......@@ -117,9 +117,9 @@ public:
/** TLabel output */
typedef TLabel LabelType;
typedef SVMModel<DescriptorPrecision, LabelType> SVMModelType;
typedef typename SVMModelType::Pointer SVMModelPointerType;
typedef typename SVMModelType::MeasurementType SVMModelMeasurementType;
typedef MachineLearningModel<DescriptorPrecision, LabelType> ModelType;
typedef typename ModelType::Pointer ModelPointerType;
typedef typename ModelType::InputSampleType ModelMeasurementType;
typedef itk::Statistics::ListSample<DescriptorType> ListSampleType;
......@@ -128,8 +128,10 @@ public:
this->Superclass::AddInput(dataObject);
}
/** SVM model used for classification */
void SetSVMModel(SVMModelType * model);
/** learning model used for classification */
void SetModel(ModelType * model);
const ModelType* GetModel(void) const;
VectorDataType* GetOutputVectorData(void);
......@@ -239,9 +241,9 @@ private:
/** \class ObjectDetectionClassifier
* \brief This class detects object in an image, given a SVM model and a local descriptors function
* \brief This class detects object in an image, given a ML model and a local descriptors function
*
* Given an image (by SetInputImage()), a SVM model (by SetSVMModel) and an local descriptors ImageFunction
* Given an image (by SetInputImage()), a ML model (by SetModel) and an local descriptors ImageFunction
* (set by SetDescriptorsFunction()), this class computes the local descriptors on a regular grid
* over the image, and evaluates the class label of the corresponding sample.
* It outputs a vector data with the points for which the descriptors are not classified as "negative",
......@@ -249,7 +251,7 @@ private:
*
* This class is streaming capable and multithreaded
*
* \ingroup OTBSVMLearning
* \ingroup OTBObjectDetection
*/
template <class TInputImage, class TOutputVectorData, class TLabel, class TFunctionPrecision = double, class TCoordRep = double>
class ITK_EXPORT ObjectDetectionClassifier :
......@@ -299,8 +301,8 @@ public:
typedef typename Superclass::FilterType PersistentFilterType;
typedef typename PersistentFilterType::SVMModelType SVMModelType;
typedef typename PersistentFilterType::SVMModelPointerType SVMModelPointerType;
typedef typename PersistentFilterType::ModelType ModelType;
typedef typename PersistentFilterType::ModelPointerType ModelPointerType;
/** Input image to extract feature */
void SetInputImage(InputImageType* input)
......@@ -332,15 +334,15 @@ public:
}
/** The function to evaluate */
void SetSVMModel(SVMModelType* model)
void SetModel(ModelType* model)
{
this->GetFilter()->SetSVMModel(model);
this->GetFilter()->SetModel(model);
}
/** The function to evaluate */
SVMModelType* GetSVMModel()
/** The classification model */
const ModelType* GetModel()
{
return this->GetFilter()->GetSVMModel();
return this->GetFilter()->GetModel();
}
otbSetObjectMemberMacro(Filter, NeighborhoodRadius, unsigned int);
......
......@@ -84,11 +84,23 @@ PersistentObjectDetectionClassifier<TInputImage, TOutputVectorData, TLabel, TFun
template <class TInputImage, class TOutputVectorData, class TLabel, class TFunctionType>
void
PersistentObjectDetectionClassifier<TInputImage, TOutputVectorData, TLabel, TFunctionType>
::SetSVMModel(SVMModelType* model)
::SetModel(ModelType* model)
{
this->SetNthInput(1, model);
}
template <class TInputImage, class TOutputVectorData, class TLabel, class TFunctionType>
const typename PersistentObjectDetectionClassifier<TInputImage, TOutputVectorData, TLabel, TFunctionType>::ModelType*
PersistentObjectDetectionClassifier<TInputImage, TOutputVectorData, TLabel, TFunctionType>
::GetModel(void) const
{
if(this->GetNumberOfInputs()<2)
{
return ITK_NULLPTR;
}
return static_cast<const ModelType*>(this->itk::ProcessObject::GetInput(1));
}
template <class TInputImage, class TOutputVectorData, class TLabel, class TFunctionType>
typename PersistentObjectDetectionClassifier<TInputImage, TOutputVectorData, TLabel, TFunctionType>::VectorDataType*
PersistentObjectDetectionClassifier<TInputImage, TOutputVectorData, TLabel, TFunctionType>
......@@ -245,7 +257,7 @@ PersistentObjectDetectionClassifier<TInputImage, TOutputVectorData, TLabel, TFun
itk::ThreadIdType threadId)
{
InputImageType* input = static_cast<InputImageType*>(this->itk::ProcessObject::GetInput(0));
SVMModelType* model = static_cast<SVMModelType*>(this->itk::ProcessObject::GetInput(1));
const ModelType* model = this->GetModel();
typedef typename RegionType::IndexType IndexType;
IndexType begin = outputRegionForThread.GetIndex();
......@@ -266,12 +278,12 @@ PersistentObjectDetectionClassifier<TInputImage, TOutputVectorData, TLabel, TFun
input->TransformIndexToPhysicalPoint(current, point);
DescriptorType descriptor = m_DescriptorsFunction->Evaluate(point);
SVMModelMeasurementType modelMeasurement(descriptor.GetSize());
ModelMeasurementType modelMeasurement(descriptor.GetSize());
for (unsigned int i = 0; i < descriptor.GetSize(); ++i)
{
modelMeasurement[i] = (descriptor[i] - m_Shifts[i]) * m_InvertedScales[i];
}
LabelType label = model->EvaluateLabel(modelMeasurement);
LabelType label = (model->Predict(modelMeasurement))[0];
if (label != m_NoClassLabel)
{
......
......@@ -30,13 +30,14 @@ otb_module(OTBObjectDetection
OTBObjectList
OTBStatistics
OTBStreaming
OTBSupervised
OTBTextures
OTBVectorDataBase
TEST_DEPENDS
OTBIOXML
OTBImageIO
OTBSVMLearning
OTBLibSVM
OTBTestKernel
OTBVectorDataIO
......
......@@ -17,9 +17,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iterator>
#include "otbImage.h"
......@@ -31,7 +28,7 @@
#include "otbStatisticsXMLFileReader.h"
#include "otbShiftScaleSampleListFilter.h"
#include "otbSVMSampleListModelEstimator.h"
#include "otbLibSVMMachineLearningModel.h"
const unsigned int Dimension = 2;
typedef int LabelType;
......@@ -60,17 +57,10 @@ typedef otb::DescriptorsListSampleGenerator
typedef otb::ImageFileReader<ImageType> ImageReaderType;
typedef otb::VectorDataFileReader<VectorDataType> VectorDataReaderType;
typedef otb::Functor::VariableLengthVectorToMeasurementVectorFunctor<SampleType>
MeasurementVectorFunctorType;
typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader;
typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType> ShiftScaleListSampleFilterType;
typedef otb::SVMSampleListModelEstimator<
ListSampleType,
LabelListSampleType,
MeasurementVectorFunctorType> SVMEstimatorType;
typedef otb::LibSVMMachineLearningModel<FunctionPrecisionType, LabelType> SVMType;
typedef FunctionType::PointType PointType;
typedef DescriptorsListSampleGeneratorType::SamplesPositionType SamplesPositionType;
......@@ -243,11 +233,11 @@ int otbDescriptorsSVMModelCreation(int argc, char* argv[])
shiftscaleFilter->SetScales(varianceMeasurentVector);
shiftscaleFilter->Update();
SVMEstimatorType::Pointer svmEstimator = SVMEstimatorType::New();
svmEstimator->SetInputSampleList(shiftscaleFilter->GetOutput());
svmEstimator->SetTrainingSampleList(descriptorsGenerator->GetLabelListSample());
svmEstimator->Update();
svmEstimator->GetModel()->SaveModel(outputFileName);
SVMType::Pointer svmEstimator = SVMType::New();
svmEstimator->SetInputListSample(shiftscaleFilter->GetOutput());
svmEstimator->SetTargetListSample(descriptorsGenerator->GetLabelListSample());
svmEstimator->Train();
svmEstimator->Save(outputFileName);
return EXIT_SUCCESS;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment