@@ -752,3 +752,79 @@ function reflectance_surf(fname::String; band::Int=0, mtl::String="", save::Stri
752
752
(save != " " ) && gdaltranslate (G, dest= save)
753
753
return (save != " " ) ? nothing : G
754
754
end
755
+
756
+ # ----------------------------------------------------------------------------------------------------------
757
+ """
758
+ I = classify(cube::GItype, train::Union{Vector{<:GMTdataset}, String}) -> GMTimage
759
+
760
+ - `cube`: The cube wtih band data to classify.
761
+ - `train`: A vector of GMTdatasets or a file name of one containing the polygons used to train the model.
762
+ NOTE: The individual datasets MUST have associated an attribute called "class" containing the class name as a string.
763
+ This can be achieved for text data in the form of a GMT multi-segment file (one where segments are separated by the '>'
764
+ symbol) if the multi-segment separator line contains the text ``Attrib(class=name)``
765
+
766
+ Returns an image with the classification results where each class name was assigned a different integer number.
767
+ That colorized image can plotted with ``viz(I, colorbar=true)``.
768
+ """
769
+ function classify (cube:: GItype , train:: Union{Vector{<:GMTdataset}, String} )
770
+ model, classes = train_raster (cube, train)
771
+ I = classify (cube, model)
772
+ cpt = makecpt (cmap= :categorical , range= classes);
773
+ image_cpt! (I, cpt)
774
+ return I
775
+ end
776
+
777
+ # ----------------------------------------------------------------------------------------------------------
778
+ """
779
+ I = classify(cube::GItype, model; class_names::Union{String, Vector{String}}="") -> GMTimage
780
+
781
+ - `cube`: The cube wtih band data to classify.
782
+ - `model`: The trained model obtained from the `train_raster` function.
783
+ - `class_names`: A vector of strings with the class names to be used in the categorical colorbar or a
784
+ comma separated single with those class names. The number of class names must match the number used
785
+ when training the model with `train_raster`.
786
+ """
787
+ function classify (cube:: GItype , model; class_names:: Union{String, Vector{String}} = " " )
788
+ nr, nc = size (cube)[1 : 2 ];
789
+ mat = Vector {UInt8} (undef, nr* nc);
790
+ t = permutedims (cube. z, (1 ,3 ,2 ));
791
+ i1 = 1 ; i2 = nr;
792
+ for k = 1 : nc # Loop over columns
793
+ mat[i1: i2] = DecisionTree. predict (model, Float64 .(t[:,:,k]))
794
+ i1 = i2 + 1
795
+ i2 = i1 + nr - 1
796
+ end
797
+ I = mat2img (reshape (mat, nc,nr), cube);
798
+ (class_names == " " ) && return I # No class names, no CPT
799
+ classes = isa (class_names, Vector) ? join (class_names, " ," ) : class_names
800
+ cpt = makecpt (cmap= :categorical , range= classes);
801
+ image_cpt! (I, cpt)
802
+ return I
803
+ end
804
+
805
+ # ----------------------------------------------------------------------------------------------------------
806
+ """
807
+ model, classes = train_raster(cube::GItype, train::Union{Vector{<:GMTdataset}, String}; np::Int=0, density=0.1)
808
+
809
+ - `cube`: The cube wtih band data to classify.
810
+ - `train`: A vector of GMTdatasets or a file name of one containing the polygons used to train the model.
811
+ NOTE: The individual datasets MUST have associated an attribute called "class" containing the class name as a string.
812
+ This can be achieved for text data in the form of a GMT multi-segment file (one where segments are separated by the '>'
813
+ symbol) if the multi-segment separator line contains the text ``Attrib(class=name)``
814
+ - `np`: Number of points per polygon to be determined by ``randinpolygon``
815
+ - `density`: Alternative to `np`. See also the help of the ``randinpolygon`` function.
816
+
817
+ Returns the trained model and the class names.
818
+ """
819
+ function train_raster (cube:: GItype , train:: Union{Vector{<:GMTdataset}, String} ; np:: Int = 0 , density= 0.1 )
820
+ samples = isa (train, String) ? gmtread (train) : train
821
+ pts = randinpolygon (samples, np= np, density= density);
822
+ LCsamp = grdinterpolate (cube, S= pts, nocoords= true );
823
+ features = GMT. ds2ds (LCsamp);
824
+ labels = parse .(UInt8, vcat ([fill (LCsamp[k]. attrib[" id" ], size (LCsamp[k],1 )) for k= 1 : length (LCsamp)]. .. ));
825
+
826
+ model = DecisionTree. DecisionTreeClassifier (max_depth= 3 );
827
+ DecisionTree. fit! (model, features, labels)
828
+ classes = join (unique (GMT. make_attrtbl (samples, false )[1 ][:,1 ]), " ," )
829
+ return model, classes
830
+ end
0 commit comments