@@ -785,23 +785,46 @@ end
when training the model with `train_raster`.
"""
function classify(cube::GItype, model; class_names::Union{String, Vector{String}}="")
-	nr, nc = size(cube)[1:2];
-	mat = Vector{UInt8}(undef, nr*nc);
-	t = permutedims(cube.z, (1,3,2));
-	i1 = 1;		i2 = nr;
+	nr, nc = size(cube)[1:2]
+	mat = Vector{UInt8}(undef, nr*nc)
+	t = permutedims(cube.z, (1,3,2))
+	i1 = 1;		i2 = nr
	for k = 1:nc				# Loop over columns
		mat[i1:i2] = DecisionTree.predict(model, Float64.(t[:,:,k]))
		i1 = i2 + 1
		i2 = i1 + nr - 1
	end
-	I = mat2img(reshape(mat, nc,nr), cube);
+	I = mat2img(reshape(mat, nr,nc), cube)
	(class_names == "") && return I			# No class names, no CPT
	classes = isa(class_names, Vector) ? join(class_names, ",") : class_names
-	cpt = makecpt(cmap=:categorical, range=classes);
+	cpt = makecpt(cmap=:categorical, range=classes)
	image_cpt!(I, cpt)
	return I
end

+# ----------------------------------------------------------------------------------------------------------
+"""
+I = classification_proba(cube::GItype, model; class_number=1) -> GMTimage
+
+Returns an image with the probabilities assigned when classifying class number `class_number`.
+
+- `cube`: The cube with the band data to classify.
+- `model`: The model obtained from the `train_raster` function.
+- `class_number`: The class number whose classification probabilities are returned.
+"""
+function classification_proba(cube::GItype, model; class_number::Int=1)
+	nr, nc = size(cube)[1:2]
+	mat = Vector{UInt8}(undef, nr*nc)
+	t = permutedims(cube.z, (1,3,2))
+	i1 = 1;		i2 = nr
+	for k = 1:nc				# Loop over columns
+		mat[i1:i2] = round.(UInt8, DecisionTree.predict_proba(model, Float64.(t[:,:,k]))[:,class_number] * 255)
+		i1 = i2 + 1
+		i2 = i1 + nr - 1
+	end
+	mat2img(reshape(mat, nr,nc), cube)		# reshape rows-first, as in classify()
+end
+
# ----------------------------------------------------------------------------------------------------------
"""
model, classes = train_raster(cube::GItype, train::Union{Vector{<:GMTdataset}, String}; np::Int=0, density=0.1)
@@ -816,15 +839,47 @@ end
Returns the trained model and the class names.
"""
-function train_raster(cube::GItype, train::Union{Vector{<:GMTdataset}, String}; np::Int=0, density=0.1)
+function train_raster(cube::GItype, train::Union{Vector{<:GMTdataset}, String}; np::Int=0, density=0.1, max_depth=3)
	samples = isa(train, String) ? gmtread(train) : train
-	pts = randinpolygon(samples, np=np, density=density);
-	LCsamp = grdinterpolate(cube, S=pts, nocoords=true);
-	features = GMT.ds2ds(LCsamp);
+	pts = randinpolygon(samples, np=np, density=density)
+	LCsamp = grdinterpolate(cube, S=pts, nocoords=true)
+	features = GMT.ds2ds(LCsamp)
	labels = parse.(UInt8, vcat([fill(LCsamp[k].attrib["id"], size(LCsamp[k],1)) for k=1:length(LCsamp)]...));

-	model = DecisionTree.DecisionTreeClassifier(max_depth=3);
+	model = DecisionTree.DecisionTreeClassifier(max_depth=max_depth)
	DecisionTree.fit!(model, features, labels)
	classes = join(unique(GMT.make_attrtbl(samples, false)[1][:,1]), ",")
	return model, classes
end
+
+# ----------------------------------------------------------------------------------------------------------
+#=
+From https://rspatial.org/rs/5-supclassification.html
+
+TODO. Simplify this with a call to classify() and make it an example.
+
+samp = gmtread("samples.shp.zip");
+pts = randinpolygon(samp, density=0.02);
+C = gdalread("LC08_L1TP_20210525_02_cube.tiff");
+LCsamp = grdinterpolate(C, S=pts, nocoords=true);
+features = GMT.ds2ds(LCsamp);
+#labels = vcat([fill(LCsamp[k].attrib["class"], size(LCsamp[k],1)) for k=1:length(LCsamp)]...);
+labels = parse.(UInt8, vcat([fill(LCsamp[k].attrib["id"], size(LCsamp[k],1)) for k=1:length(LCsamp)]...));
+
+using DecisionTree
+model = DecisionTreeClassifier(max_depth=3);
+fit!(model, features, labels)
+nr, nc = size(C)[1:2];
+mat = Array{UInt8}(undef, nr*nc);
+t = permutedims(C.z, (1,3,2));
+i1 = 1;		i2 = nr;
+for k = 1:nc			# Loop over columns
+	mat[i1:i2] = predict(model, t[:,:,k])
+	i1 = i2 + 1
+	i2 = i1 + nr - 1
+end
+I = mat2img(reshape(mat, nc,nr), C);
+cpt = makecpt(cmap=:categorical, range="cropland,water,fallow,built,open");
+image_cpt!(I, cpt)
+viz(I, colorbar=true)
+=#
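
A minimal end-to-end sketch of how the three functions in this commit might be used together. The file names are taken from the commented example above; the exact call sequence is an untested assumption, not part of the commit itself:

    using GMT, DecisionTree

    # Train a decision tree from polygons carrying a per-class "id" attribute (sketch only)
    C = gdalread("LC08_L1TP_20210525_02_cube.tiff");
    model, classes = train_raster(C, "samples.shp.zip", density=0.02, max_depth=3);

    # Classify every pixel and attach a categorical CPT built from the class names
    I = classify(C, model, class_names=classes);
    viz(I, colorbar=true)

    # Probability (scaled to 0-255) of each pixel belonging to the first class
    P = classification_proba(C, model, class_number=1);
    viz(P)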