Skip to content

Commit fde07f7

Browse files
authored
Merge pull request #77 from GenericMappingTools/classific-proba
Add a classification_proba() function.
2 parents 63b1cef + c45cf58 commit fde07f7

File tree

1 file changed

+66
-11
lines changed

1 file changed

+66
-11
lines changed

src/utils.jl

+66-11
Original file line numberDiff line numberDiff line change
@@ -785,23 +785,46 @@ end
785785
when training the model with `train_raster`.
786786
"""
787787
function classify(cube::GItype, model; class_names::Union{String, Vector{String}}="")
788-
nr, nc = size(cube)[1:2];
789-
mat = Vector{UInt8}(undef, nr*nc);
790-
t = permutedims(cube.z, (1,3,2));
791-
i1 = 1; i2 = nr;
788+
nr, nc = size(cube)[1:2]
789+
mat = Vector{UInt8}(undef, nr*nc)
790+
t = permutedims(cube.z, (1,3,2))
791+
i1 = 1; i2 = nr
792792
for k = 1:nc # Loop over columns
793793
mat[i1:i2] = DecisionTree.predict(model, Float64.(t[:,:,k]))
794794
i1 = i2 + 1
795795
i2 = i1 + nr - 1
796796
end
797-
I = mat2img(reshape(mat, nc,nr), cube);
797+
I = mat2img(reshape(mat, nr,nc), cube)
798798
(class_names == "") && return I # No class names, no CPT
799799
classes = isa(class_names, Vector) ? join(class_names, ",") : class_names
800-
cpt = makecpt(cmap=:categorical, range=classes);
800+
cpt = makecpt(cmap=:categorical, range=classes)
801801
image_cpt!(I, cpt)
802802
return I
803803
end
804804

805+
# ----------------------------------------------------------------------------------------------------------
806+
"""
807+
I = classification_proba(cube::GItype, model; class_number=1) -> GMTimage
808+
809+
Returns an image with the assigned probabilities when classifying the class number `class_number`
810+
811+
- `cube`: The cube wtih band data to classify
812+
- `model`: is the model obtained from the `train_raster` function
813+
- `class_number`: is the class number to be classified
814+
"""
815+
function classification_proba(cube::GItype, model; class_number::Int=1)
816+
nr, nc = size(cube)[1:2]
817+
mat = Vector{UInt8}(undef, nr*nc)
818+
t = permutedims(cube.z, (1,3,2))
819+
i1 = 1; i2 = nr
820+
for k = 1:nc # Loop over columns
821+
mat[i1:i2] = round.(UInt8, predict_proba(model, Float64.(t[:,:,k]))[:,class_number] * 255)
822+
i1 = i2 + 1
823+
i2 = i1 + nr - 1
824+
end
825+
mat2img(reshape(mat, nc,nr), cube)
826+
end
827+
805828
# ----------------------------------------------------------------------------------------------------------
806829
"""
807830
model, classes = train_raster(cube::GItype, train::Union{Vector{<:GMTdataset}, String}; np::Int=0, density=0.1)
@@ -816,15 +839,47 @@ end
816839
817840
Returns the trained model and the class names.
818841
"""
819-
function train_raster(cube::GItype, train::Union{Vector{<:GMTdataset}, String}; np::Int=0, density=0.1)
842+
function train_raster(cube::GItype, train::Union{Vector{<:GMTdataset}, String}; np::Int=0, density=0.1, max_depth=3)
820843
samples = isa(train, String) ? gmtread(train) : train
821-
pts = randinpolygon(samples, np=np, density=density);
822-
LCsamp = grdinterpolate(cube, S=pts, nocoords=true);
823-
features = GMT.ds2ds(LCsamp);
844+
pts = randinpolygon(samples, np=np, density=density)
845+
LCsamp = grdinterpolate(cube, S=pts, nocoords=true)
846+
features = GMT.ds2ds(LCsamp)
824847
labels = parse.(UInt8, vcat([fill(LCsamp[k].attrib["id"], size(LCsamp[k],1)) for k=1:length(LCsamp)]...));
825848

826-
model = DecisionTree.DecisionTreeClassifier(max_depth=3);
849+
model = DecisionTree.DecisionTreeClassifier(max_depth=max_depth)
827850
DecisionTree.fit!(model, features, labels)
828851
classes = join(unique(GMT.make_attrtbl(samples, false)[1][:,1]), ",")
829852
return model, classes
830853
end
854+
855+
# ----------------------------------------------------------------------------------------------------------
856+
#=
857+
From https://rspatial.org/rs/5-supclassification.html
858+
859+
TODO. Simplify this with a call to classify() and make it an example.
860+
861+
samp = gmtread("samples.shp.zip");
862+
pts = randinpolygon(samp, density=0.02);
863+
C = gdalread("LC08_L1TP_20210525_02_cube.tiff");
864+
LCsamp = grdinterpolate(C, S=pts, nocoords=true);
865+
features = GMT.ds2ds(LCsamp);
866+
#labels = vcat([fill(LCsamp[k].attrib["class"], size(LCsamp[k],1)) for k=1:length(LCsamp)]...);
867+
labels = parse.(UInt8, vcat([fill(LCsamp[k].attrib["id"], size(LCsamp[k],1)) for k=1:length(LCsamp)]...));
868+
869+
using DecisionTree
870+
model = DecisionTreeClassifier(max_depth=3);
871+
fit!(model, features, labels)
872+
nr, nc = size(C)[1:2];
873+
mat = Array{UInt8}(undef, nr*nc);
874+
t = permutedims(C.z, (1,3,2));
875+
i1 = 1; i2 = nr;
876+
for k = 1:nc # Loop over columns
877+
mat[i1:i2] = predict(model, t[:,:,k])
878+
i1 = i2 + 1
879+
i2 = i1 + nr - 1
880+
end
881+
I = mat2img(reshape(mat, nc,nr), C);
882+
cpt = makecpt(cmap=:categorical, range="cropland,water,fallow,built,open");
883+
image_cpt!(I, cpt)
884+
viz(I, colorbar=true)
885+
=#

0 commit comments

Comments
 (0)