diff --git a/Applications/Classification/otbSOMClassification.cxx b/Applications/Classification/otbSOMClassification.cxx index aa2307378f0974edfdc8b1e0855c7973a0a5474b..d9ccca181d3e42f579c71cf5485391fc8bfa30d4 100644 --- a/Applications/Classification/otbSOMClassification.cxx +++ b/Applications/Classification/otbSOMClassification.cxx @@ -100,43 +100,55 @@ private: AddParameter(ParameterType_InputImage, "vm", "ValidityMask"); SetParameterDescription("vm", "Validity mask"); + MandatoryOff("vm"); AddParameter(ParameterType_Float, "tp", "TrainingProbability"); SetParameterDescription("tp", "Probability for a sample to be selected in the training set"); + MandatoryOff("tp"); - AddParameter(ParameterType_Int, "ts", "TrainingSet"); + AddParameter(ParameterType_Int, "ts", "TrainingSetSize"); SetParameterDescription("ts", "Maximum training set size"); + MandatoryOff("ts"); AddParameter(ParameterType_Int, "sl", "StreamingLines"); SetParameterDescription("sl", "Number of lines in each streaming block (used during data sampling)"); + MandatoryOff("sl"); AddParameter(ParameterType_OutputImage, "som", "SOM Map"); SetParameterDescription("som","Self-Organizing Map map"); + MandatoryOff("som"); AddParameter(ParameterType_Int, "sx", "SizeX"); SetParameterDescription("sx", "X size of the SOM map"); + MandatoryOff("sx"); AddParameter(ParameterType_Int, "sy", "SizeY"); SetParameterDescription("sy", "Y size of the SOM map"); + MandatoryOff("sy"); AddParameter(ParameterType_Int, "nx", "NeighborhoodX"); SetParameterDescription("nx", "X initial neighborhood of the SOM map"); + MandatoryOff("nx"); AddParameter(ParameterType_Int, "ny", "NeighborhoodY"); SetParameterDescription("ny", "Y initial neighborhood of the SOM map"); + MandatoryOff("nx"); AddParameter(ParameterType_Int, "ni", "NumberIteration"); SetParameterDescription("ni", "Number of iterations of the SOM learning"); + MandatoryOff("ni"); AddParameter(ParameterType_Float, "bi", "BetaInit"); SetParameterDescription("bi", "Initial beta value"); + MandatoryOff("bi"); AddParameter(ParameterType_Float, "bf", "BetaFinal"); SetParameterDescription("bf", "Final beta value"); + MandatoryOff("bf"); AddParameter(ParameterType_Float, "iv", "InitialValue"); SetParameterDescription("iv", "Initial value (max weight)"); - + MandatoryOff("iv"); AddParameter(ParameterType_RAM, "ram", "Available RAM"); SetDefaultParameterInt("ram", 256); @@ -144,6 +156,7 @@ private: // TODO : replace StreamingLines by RAM param ? // Default parameters + SetDefaultParameterFloat("tp",1.0); SetDefaultParameterInt("sx", 32); SetDefaultParameterInt("sy", 32); SetDefaultParameterInt("nx", 10); @@ -177,41 +190,66 @@ private: } void DoExecute() - { + { // initiating random number generation itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer randomGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New(); FloatVectorImageType::Pointer input = GetParameterImage("in"); - LabeledImageType::Pointer mask = GetParameterImage<LabeledImageType>("vm"); + LabeledImageType::Pointer mask; + m_UseMask = false; + if (HasValue("vm")) + { + mask = GetParameterImage<LabeledImageType>("vm"); + if (input->GetLargestPossibleRegion() + != mask->GetLargestPossibleRegion()) + { + otbAppLogFATAL("Mask image and input image have different sizes."); + } + m_UseMask = true; + } /*******************************************/ /* Sampling data */ /*******************************************/ otbAppLogINFO("-- SAMPLING DATA --"); - - if (input->GetLargestPossibleRegion() - != mask->GetLargestPossibleRegion()) - { - otbAppLogFATAL("Mask image and input image have different sizes."); - } RegionType largestRegion = input->GetLargestPossibleRegion(); // Setting up local streaming capabilities SplitterType::Pointer splitter = SplitterType::New(); - unsigned int numberOfStreamDivisions = StreamingTraitsType::CalculateNumberOfStreamDivisions(input, + unsigned int numberOfStreamDivisions; + if (HasValue("sl")) + { + numberOfStreamDivisions = StreamingTraitsType::CalculateNumberOfStreamDivisions(input, largestRegion, splitter, otb::SET_BUFFER_NUMBER_OF_LINES, 0, 0, GetParameterInt("sl")); + } + else + { + numberOfStreamDivisions = StreamingTraitsType::CalculateNumberOfStreamDivisions(input, + largestRegion, + splitter, + otb::SET_BUFFER_MEMORY_SIZE, + 0, 1048576*GetParameterInt("ram"),0); + } otbAppLogINFO("The images will be streamed into "<<numberOfStreamDivisions<<" parts."); // Training sample lists ListSampleType::Pointer sampleList = ListSampleType::New(); const double trainingProb = static_cast<double>(GetParameterFloat("tp")); - const unsigned int nbsamples = GetParameterInt("ts"); + unsigned int nbsamples; + if (HasValue("ts")) + { + nbsamples = GetParameterInt("ts"); + } + else + { + nbsamples = largestRegion.GetNumberOfPixels(); + } // Sample dimension and max dimension unsigned int sampleSize = input->GetNumberOfComponentsPerPixel(); @@ -229,35 +267,64 @@ private: // TODO : maybe change the approach: at the moment, the sampling process is able to pick a sample twice or more while (totalSamples < nbsamples) { + unsigned int localNbSamples=0; + piece = randPerm[index]; - streamingRegion = splitter->GetSplit(piece, numberOfStreamDivisions, largestRegion); - //otbAppLogINFO("Processing region: "<<streamingRegion); - + input->SetRequestedRegion(streamingRegion); input->PropagateRequestedRegion(); input->UpdateOutputData(); - - mask->SetRequestedRegion(streamingRegion); - mask->PropagateRequestedRegion(); - mask->UpdateOutputData(); IteratorType it(input, streamingRegion); - LabeledIteratorType maskIt(mask, streamingRegion); - it.GoToBegin(); - maskIt.GoToBegin(); - - unsigned int localNbSamples=0; - - // Loop on the image - while ( !it.IsAtEnd() - && !maskIt.IsAtEnd() - && (totalSamples<nbsamples)) + + if (m_UseMask) { - // If the current pixel is labeled - if (maskIt.Get()>0) + mask->SetRequestedRegion(streamingRegion); + mask->PropagateRequestedRegion(); + mask->UpdateOutputData(); + + LabeledIteratorType maskIt(mask, streamingRegion); + maskIt.GoToBegin(); + + // Loop on the image and the mask + while ( !it.IsAtEnd() + && !maskIt.IsAtEnd() + && (totalSamples<nbsamples)) + { + // If the current pixel is labeled + if (maskIt.Get()>0) + { + if (randomGen->GetVariateWithClosedRange() < trainingProb) + { + SampleType newSample; + newSample.SetSize(sampleSize); + // build the sample + newSample.Fill(0); + for (unsigned int i = 0; i<sampleSize; ++i) + { + newSample[i]=it.Get()[i]; + } + // Update the sample lists + sampleList->PushBack(newSample); + ++totalSamples; + ++localNbSamples; + } + } + ++it; + ++maskIt; + } + index++; + // we could break out of the while loop here, once the entire image has been streamed once + if (index == numberOfStreamDivisions) index = 0; + } + else + { + // Loop on the image + while ( !it.IsAtEnd() + && (totalSamples<nbsamples)) { if (randomGen->GetVariateWithClosedRange() < trainingProb) { @@ -274,16 +341,12 @@ private: ++totalSamples; ++localNbSamples; } + ++it; } - ++it; - ++maskIt; + index++; + // we could break out of the while loop here, once the entire image has been streamed once + if (index == numberOfStreamDivisions) index = 0; } - - index++; - // we could break out of the while loop here, once the entire image has been streamed once - if (index == numberOfStreamDivisions) index = 0; - - //otbAppLogINFO(""<<localNbSamples<<" samples added to the training set."); } otbAppLogINFO("The final training set contains "<<totalSamples<<" samples."); @@ -313,8 +376,11 @@ private: estimator->Update(); m_SOMMap = estimator->GetOutput(); - SetParameterOutputImage<DoubleVectorImageType>("som",m_SOMMap); - + if (HasValue("som")) + { + SetParameterOutputImage<DoubleVectorImageType>("som",m_SOMMap); + } + /*******************************************/ /* Classification */ /*******************************************/ @@ -322,14 +388,15 @@ private: m_Classifier = ClassificationFilterType::New(); m_Classifier->SetInput(input); - m_Classifier->SetInputMask(mask); m_Classifier->SetMap(m_SOMMap); + if (m_UseMask) m_Classifier->SetInputMask(mask); AddProcess(m_Classifier,"Classification"); SetParameterOutputImage<LabeledImageType>("out",m_Classifier->GetOutput()); - } + } + bool m_UseMask; SOMMapType::Pointer m_SOMMap; ClassificationFilterType::Pointer m_Classifier; }; @@ -338,8 +405,3 @@ private: } OTB_APPLICATION_EXPORT(otb::Wrapper::SOMClassification) - -int main(int argc, char * argv[]) -{ - -}