Skip to content

Commit f3c52b2

Browse files
authored
Merge pull request #75 from GenericMappingTools/classifications
Add functions to do assisted classification.
2 parents 52dea69 + 7add8e5 commit f3c52b2

File tree

4 files changed

+85
-2
lines changed

4 files changed

+85
-2
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
99
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1010
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
12+
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
1213

1314
[weakdeps]
1415
SatelliteToolboxTle = "7ff27aeb-5fff-4337-a9ee-a9fe6b7ed35e"
@@ -34,3 +35,4 @@ PrecompileTools = "1.0"
3435
SatelliteToolboxTle = "1"
3536
SatelliteToolboxPropagators = "0.3"
3637
SatelliteToolboxTransformations = "0.1"
38+
DecisionTree = "0.12"

src/RemoteS.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module RemoteS
22

3-
using GMT, Printf, Statistics, Dates#, Requires
3+
using GMT, Printf, Statistics, Dates
44
using PrecompileTools
5+
using DecisionTree
56

67
const SCENE_HALFW = Dict("AQUA" => 1163479, "TERRA" => 1163479, "LANDSAT8" => 92500) # half widths
78

@@ -11,8 +12,9 @@ end
1112

1213
export
1314
cutcube, subcube, dn2temperature, dn2radiance, dn2reflectance, reflectance_surf, grid_at_sensor, truecolor,
14-
clg, clre, evi, evi2, gndvi, mndwi, mtci, mcari, msavi, mbri, ndvi, ndwi, ndwi2, ndrei1,
15+
clg, clre, evi, evi2, gndvi, mndwi, mtci, mcari, msavi, nbri, ndvi, ndwi, ndwi2, ndrei1,
1516
ndrei2, satvi, savi, slavi,
17+
classify, train_raster,
1618
clip_orbits, findscenes, sat_scenes, sat_tracks, reportbands
1719

1820
include("grid_at_sensor.jl")

src/spectral_indices.jl

+3
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ Normalised Burn Ratio Index. Garcia 1991
208208
NBRI = (nir - swir2) / (nir + swir2)
209209
"""
210210
nbri(nir, swir2; kw...) = sp_indices(swir2, nir; index="NBRI", kw...)
211+
nbri(cube::String;
212+
bands::Vector{Int}=Int[], layers::Vector{Int}=Int[], bandnames::Vector{String}=String[], kw...) =
213+
helper_si_method(cube, "NBRI"; bands=bands, layers=layers, bandnames=bandnames, defbandnames=["nir", "swir2"], kw...)
211214
nbri(cube::Union{GMT.GMTimage{UInt16, 3}, AbstractArray{<:AbstractFloat, 3}};
212215
bands::Vector{Int}=Int[], layers::Vector{Int}=Int[], bandnames::Vector{String}=String[], kw...) =
213216
helper_si_method(cube, "NBRI"; bands=bands, layers=layers, bandnames=bandnames, defbandnames=["nir", "swir2"], kw...)

src/utils.jl

+76
Original file line numberDiff line numberDiff line change
@@ -752,3 +752,79 @@ function reflectance_surf(fname::String; band::Int=0, mtl::String="", save::Stri
752752
(save != "") && gdaltranslate(G, dest=save)
753753
return (save != "") ? nothing : G
754754
end
755+
756+
# ----------------------------------------------------------------------------------------------------------
757+
"""
758+
I = classify(cube::GItype, train::Union{Vector{<:GMTdataset}, String}) -> GMTimage
759+
760+
- `cube`: The cube wtih band data to classify.
761+
- `train`: A vector of GMTdatasets or a file name of one containing the polygons used to train the model.
762+
NOTE: The individual datasets MUST have associated an attribute called "class" containing the class name as a string.
763+
This can be achieved for text data in the form of a GMT multi-segment file (one where segments are separated by the '>'
764+
symbol) if the multi-segment separator line contains the text ``Attrib(class=name)``
765+
766+
Returns an image with the classification results where each class name was assigned a different integer number.
767+
That colorized image can plotted with ``viz(I, colorbar=true)``.
768+
"""
769+
function classify(cube::GItype, train::Union{Vector{<:GMTdataset}, String})
770+
model, classes = train_raster(cube, train)
771+
I = classify(cube, model)
772+
cpt = makecpt(cmap=:categorical, range=classes);
773+
image_cpt!(I, cpt)
774+
return I
775+
end
776+
777+
# ----------------------------------------------------------------------------------------------------------
778+
"""
779+
I = classify(cube::GItype, model; class_names::Union{String, Vector{String}}="") -> GMTimage
780+
781+
- `cube`: The cube wtih band data to classify.
782+
- `model`: The trained model obtained from the `train_raster` function.
783+
- `class_names`: A vector of strings with the class names to be used in the categorical colorbar or a
784+
comma separated single with those class names. The number of class names must match the number used
785+
when training the model with `train_raster`.
786+
"""
787+
function classify(cube::GItype, model; class_names::Union{String, Vector{String}}="")
788+
nr, nc = size(cube)[1:2];
789+
mat = Vector{UInt8}(undef, nr*nc);
790+
t = permutedims(cube.z, (1,3,2));
791+
i1 = 1; i2 = nr;
792+
for k = 1:nc # Loop over columns
793+
mat[i1:i2] = DecisionTree.predict(model, Float64.(t[:,:,k]))
794+
i1 = i2 + 1
795+
i2 = i1 + nr - 1
796+
end
797+
I = mat2img(reshape(mat, nc,nr), cube);
798+
(class_names == "") && return I # No class names, no CPT
799+
classes = isa(class_names, Vector) ? join(class_names, ",") : class_names
800+
cpt = makecpt(cmap=:categorical, range=classes);
801+
image_cpt!(I, cpt)
802+
return I
803+
end
804+
805+
# ----------------------------------------------------------------------------------------------------------
806+
"""
807+
model, classes = train_raster(cube::GItype, train::Union{Vector{<:GMTdataset}, String}; np::Int=0, density=0.1)
808+
809+
- `cube`: The cube wtih band data to classify.
810+
- `train`: A vector of GMTdatasets or a file name of one containing the polygons used to train the model.
811+
NOTE: The individual datasets MUST have associated an attribute called "class" containing the class name as a string.
812+
This can be achieved for text data in the form of a GMT multi-segment file (one where segments are separated by the '>'
813+
symbol) if the multi-segment separator line contains the text ``Attrib(class=name)``
814+
- `np`: Number of points per polygon to be determined by ``randinpolygon``
815+
- `density`: Alternative to `np`. See also the help of the ``randinpolygon`` function.
816+
817+
Returns the trained model and the class names.
818+
"""
819+
function train_raster(cube::GItype, train::Union{Vector{<:GMTdataset}, String}; np::Int=0, density=0.1)
820+
samples = isa(train, String) ? gmtread(train) : train
821+
pts = randinpolygon(samples, np=np, density=density);
822+
LCsamp = grdinterpolate(cube, S=pts, nocoords=true);
823+
features = GMT.ds2ds(LCsamp);
824+
labels = parse.(UInt8, vcat([fill(LCsamp[k].attrib["id"], size(LCsamp[k],1)) for k=1:length(LCsamp)]...));
825+
826+
model = DecisionTree.DecisionTreeClassifier(max_depth=3);
827+
DecisionTree.fit!(model, features, labels)
828+
classes = join(unique(GMT.make_attrtbl(samples, false)[1][:,1]), ",")
829+
return model, classes
830+
end

0 commit comments

Comments
 (0)