From dbb582b930c824cad4c772cf2fd16b31e0602486 Mon Sep 17 00:00:00 2001 From: Jordi Inglada <jordi.inglada@orfeo-toolbox.org> Date: Tue, 6 Jun 2006 12:12:10 +0000 Subject: [PATCH] Corrections pour le SVM multiclass --- Code/Learning/otbSVMClassifier.txx | 22 +++++++++---- Code/Learning/otbSVMModel.h | 1 + Examples/Data/ROI_mask_multi.png | Bin 851 -> 845 bytes .../Learning/GenerateTrainingImageExample.cxx | 6 ++-- ...ageEstimatorClassificationMultiExample.cxx | 29 +++++++++--------- 5 files changed, 36 insertions(+), 22 deletions(-) diff --git a/Code/Learning/otbSVMClassifier.txx b/Code/Learning/otbSVMClassifier.txx index 2e534e74ed..fe8e2bf773 100644 --- a/Code/Learning/otbSVMClassifier.txx +++ b/Code/Learning/otbSVMClassifier.txx @@ -180,6 +180,7 @@ SVMClassifier< TSample, TLabel > typename OutputType::ConstIterator endO = m_Output->End() ; typename TSample::MeasurementVectorType measurements ; + int numberOfComponentsPerSample = iter.GetMeasurementVector().Size() ;//this->GetSample().GetMeasurementVectorSize();// int max_line_len = 1024; @@ -187,13 +188,16 @@ SVMClassifier< TSample, TLabel > int max_nr_attr = 64; bool predict_probability = 1; + const struct svm_model* model = m_Model->GetModel(); // char* line = (char *) malloc(max_line_len*sizeof(char)); // x = (struct svm_node *) malloc(max_nr_attr*sizeof(struct // svm_node)); - m_Model->AllocateProblem(1, numberOfComponentsPerSample); - x = m_Model->GetXSpace(); + +/* m_Model->AllocateProblem(1, numberOfComponentsPerSample);*/ + + x = new svm_node[numberOfComponentsPerSample+1];//m_Model->GetXSpace(); //std::cout << "XSpace Allocated" << std::endl; if(svm_check_probability_model(model)==0) @@ -208,11 +212,14 @@ SVMClassifier< TSample, TLabel > int total = 0; double error = 0; double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; - + + int svm_type=svm_get_svm_type(model); //std::cout << "SVM Type = " << svm_type << std::endl; + int nr_class=svm_get_nr_class(model); //std::cout << "SVM nr_class = " << nr_class << std::endl; + int *labels=(int *) malloc(nr_class*sizeof(int)); double *prob_estimates=NULL; int j; @@ -223,17 +230,19 @@ SVMClassifier< TSample, TLabel > printf("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model)); else { + svm_get_labels(model,labels); + prob_estimates = (double *) malloc(nr_class*sizeof(double)); + } /*fprintf(output,"labels"); for(j=0;j<nr_class;j++) fprintf(output," %d",labels[j]); fprintf(output,"\n");*/ - } } // while(1) - //std::cout << "Starting iterations " << std::endl; +// std::cout << "Starting iterations " << std::endl; while (iter != end && iterO != endO) { @@ -329,7 +338,8 @@ if(predict_probability) //std::cout << "End of iterations and free" << std::endl; // free(x); - + +delete [] x; } } // end of namespace itk diff --git a/Code/Learning/otbSVMModel.h b/Code/Learning/otbSVMModel.h index ea37fc27ae..d4fc25084b 100644 --- a/Code/Learning/otbSVMModel.h +++ b/Code/Learning/otbSVMModel.h @@ -145,6 +145,7 @@ public: /** Allocates the problem */ void AllocateProblem(int l, long int elements); + /** Sets the model */ void SetModel(struct svm_model* aModel); diff --git a/Examples/Data/ROI_mask_multi.png b/Examples/Data/ROI_mask_multi.png index 7322cfc2294637951d68ae78986acd066f814175..3a3de297566e508c8eb9162bd71cc70724c242e1 100644 GIT binary patch delta 211 zcmcc2c9v~InA#&x7srr_IdAV+^Bpo^V09GQ@xSw@Xa^t5#3u0*J}-Avv~E9}$3C&q zLg+*FvML51Q~lj_|IT#HZjElZI*CzvGB2YT7_Vg#7qH7+#jaCa@$I2S5$l>m^HnD= zX9}2{z_>;Dhgq;~gZ`&arZH}7ZuISxW0<{k;y+ceK8Swq$$^ZHstx~*uQE$y{QvR8 zq>Qz^*p|V}yEY(DKI96!=vPNqD05<?qWJqAOxq5#tStDrl$U{lfx*+&&t;ucLK6TH CwNu>y delta 221 zcmX@hcA0HLnAQtV7srr_IdAV+_8oFyaCLkmeeld*{{uy;OeQxsc<zvYuxi8Z)W1)b z-piRB$Y?2ScRqx_VPV;MTYJ6DkCNg(JwCuYiBWcP1C##5|J;+;Gx5ncEcN}`@cd8Q z?i>9vjQS7H3k3wqhg@M7{p#pi8xSZzIgs)F#0l)eduFecJ0N}YUUh-g#w6QkRSavU zO)g}zoA_UMass2V%D(d<{0R%@|L^Bd4*!t8j{QK^N`9?Z&0W71EV5g%0>Yfws3`tG bF#X(NmX&@?f?F9F7#KWV{an^LB{Ts5zVc)O diff --git a/Examples/Learning/GenerateTrainingImageExample.cxx b/Examples/Learning/GenerateTrainingImageExample.cxx index 525c497cf9..8117571855 100644 --- a/Examples/Learning/GenerateTrainingImageExample.cxx +++ b/Examples/Learning/GenerateTrainingImageExample.cxx @@ -116,14 +116,15 @@ int main( int argc, char ** argv ) --nbRois; OutputPixelType label = 0; - unsigned long xUL, yUL, xBR, yBR = 0; + unsigned long xUL, yUL, xBR, yBR, tmp_label = 0; - roisFile >> label; + roisFile >> tmp_label; roisFile >> xUL; roisFile >> yUL; roisFile >> xBR; roisFile >> yBR; + label = static_cast<OutputPixelType>(tmp_label); std::cout << "Label : " << int(label) << std::endl; std::cout << "( " << xUL << " , " << yUL << " )" << std::endl; @@ -151,6 +152,7 @@ int main( int argc, char ** argv ) while(!it.IsAtEnd()) { + it.Set(static_cast<OutputPixelType>(label)); //std::cout << static_cast<OutputPixelType>(label) << " -- "; diff --git a/Examples/Learning/SVMImageEstimatorClassificationMultiExample.cxx b/Examples/Learning/SVMImageEstimatorClassificationMultiExample.cxx index 9384a99701..268dce5913 100644 --- a/Examples/Learning/SVMImageEstimatorClassificationMultiExample.cxx +++ b/Examples/Learning/SVMImageEstimatorClassificationMultiExample.cxx @@ -152,6 +152,7 @@ int main( int argc, char* argv[] ) // Software Guide : BeginCodeSnippet svmEstimator->SaveModel(outputModelFileName); + // Software Guide : EndCodeSnippet @@ -176,11 +177,9 @@ int main( int argc, char* argv[] ) ClassifyReaderType::Pointer cReader = ClassifyReaderType::New(); - cReader->SetFileName( inputImageFileName ); cReader->Update(); - // Software Guide : BeginLatex // @@ -197,7 +196,6 @@ int main( int argc, char* argv[] ) typedef itk::Statistics::ImageToListAdaptor< ClassifyImageType > SampleType; SampleType::Pointer sample = SampleType::New(); - // Software Guide : EndCodeSnippet // Software Guide : BeginLatex @@ -211,7 +209,6 @@ int main( int argc, char* argv[] ) sample->SetImage(cReader->GetOutput()); - // Software Guide : EndCodeSnippet // Software Guide : BeginLatex @@ -230,7 +227,7 @@ int main( int argc, char* argv[] ) typedef otb::SVMModel< InputPixelType, LabelPixelType > ModelType; ModelType::Pointer model = svmEstimator->GetModel(); - + //model->LoadModel(outputModelFileName); // Software Guide : EndCodeSnippet @@ -260,13 +257,17 @@ int main( int argc, char* argv[] ) // Software Guide : EndLatex // Software Guide : BeginCodeSnippet - - int numberOfClasses = model->GetNumberOfClasses(); + std::cout << "GNC" << std::endl; + int numberOfClasses = model->GetNumberOfClasses(); + std::cout << "SNC = "<< numberOfClasses << std::endl; classifier->SetNumberOfClasses(numberOfClasses) ; + std::cout << "SM" << std::endl; classifier->SetModel( model ); + std::cout << "SS" << std::endl; classifier->SetSample(sample.GetPointer()) ; + std::cout << "Up" << std::endl; classifier->Update() ; - + std::cout << "---" << std::endl; // Software Guide : EndCodeSnippet // Software Guide : BeginLatex @@ -288,7 +289,7 @@ int main( int argc, char* argv[] ) typedef otb::Image< OutputPixelType, Dimension > OutputImageType; OutputImageType::Pointer outputImage = OutputImageType::New(); - + std::cout << "---" << std::endl; // Software Guide : EndCodeSnippet // Software Guide : BeginLatex @@ -320,7 +321,7 @@ int main( int argc, char* argv[] ) outputImage->SetRegions( region ); outputImage->Allocate(); - + std::cout << "---" << std::endl; // Software Guide : EndCodeSnippet // Software Guide : BeginLatex @@ -360,14 +361,14 @@ int main( int argc, char* argv[] ) // Software Guide : BeginCodeSnippet - + std::cout << "---" << std::endl; while (m_iter != m_last && !outIt.IsAtEnd()) { outIt.Set(m_iter.GetClassLabel()); ++m_iter ; ++outIt; } - + std::cout << "---" << std::endl; // Software Guide : EndCodeSnippet // Software Guide : BeginLatex @@ -388,12 +389,12 @@ int main( int argc, char* argv[] ) FileImageType > RescalerType; RescalerType::Pointer rescaler = RescalerType::New(); - + std::cout << "---" << std::endl; rescaler->SetOutputMinimum( itk::NumericTraits< unsigned char >::min()); rescaler->SetOutputMaximum( itk::NumericTraits< unsigned char >::max()); rescaler->SetInput( outputImage ); - + std::cout << "---" << std::endl; // Software Guide : EndCodeSnippet // Software Guide : BeginLatex -- GitLab