Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 84 additions & 23 deletions inst/Classification/ClassificationGAM.m
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@
##
## Prior probability for each class
##
## A 2-element numeric vector specifying the prior probabilities for each
## class. The order of the elements in @qcode{Prior} corresponds to the
## order of the classes in @qcode{ClassNames}. This property is read-only.
## A numeric vector specifying the prior probabilities for each class. The
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to change the property docstring. It is always a 2-element numeric vector.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I investigated matlab's working with Prior, so modified for MATLAB compatibility.

## order of the elements in @qcode{Prior} corresponds to the order of the
## classes in @qcode{ClassNames}. This property is read-only.
##
## @end deftp
Prior = [];
Expand Down Expand Up @@ -432,7 +432,7 @@ function disp (this)
endif
switch (s.subs)
case 'Cost'
this.Cost = setCost (this, val);
this.Cost = val;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep my previous change utilizing a priveate method so it can also be called from subsasgn

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried using that, but when I tested, it was causing errors, I tried 2-3 ways but it didnt work, so came up with this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it doesn't work because my previous code was wrong. Instead of

this.Cost = setCost (this, val);

it should have been

this = setCost (this, val);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So should I keep the current changes, since it is working well right?

case 'ScoreTransform'
name = "ClassificationGAM";
this.ScoreTransform = parseScoreTransform (val, name);
Expand Down Expand Up @@ -567,6 +567,7 @@ function disp (this)
Formula = [];
Interactions = [];
ClassNames = [];
Prior = "empirical";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prior must be a 2-element numeric vector. Don't initialize it like this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same mentioned reason of MATLAB compatibility.

DoF = ones (1, ndims_X) * 8;
Order = ones (1, ndims_X) * 3;
Knots = ones (1, ndims_X) * 5;
Expand Down Expand Up @@ -625,6 +626,20 @@ function disp (this)
endif
endif

case "prior"
Prior = varargin{2};
if (! (isnumeric (Prior) || ischar (Prior)))
error (strcat ("ClassificationGAM: 'Prior' must be", ...
" a numeric vector or a string."));
endif
if (ischar (Prior) && ! any (strcmpi (Prior, {"empirical", "uniform"})))
error (strcat ("ClassificationGAM: 'Prior' must be", ...
" 'empirical', 'uniform', or a numeric vector."));
endif
if (isnumeric (Prior) && numel (Prior) != 2)
error ("ClassificationGAM: 'Prior' must be a 2-element vector.");
endif

case "cost"
Cost = varargin{2};
if (! (isnumeric (Cost) && issquare (Cost)))
Expand Down Expand Up @@ -773,6 +788,19 @@ function disp (this)
error ("ClassificationGAM: can only be used for binary classification.");
endif

## Calculate prior probabilities
if (ischar (Prior))
if (strcmpi (Prior, "uniform"))
this.Prior = [0.5, 0.5];
elseif (strcmpi (Prior, "empirical"))
counts = histc (gY, 1:2);
this.Prior = counts / sum (counts);
endif
else
## Numeric prior - normalize to sum to 1
this.Prior = Prior / sum (Prior);
endif

## Force Y into numeric
if (! isnumeric (Y))
Y = gY - 1;
Expand All @@ -784,9 +812,16 @@ function disp (this)
## Assign the number of original predictors to the ClassificationGAM object
this.NumPredictors = ndims_X;

## Assign Cost and compute Prior (FIXME: not used)
this = setCost (this, Cost, gnY);
this.Prior = [sum(gY == 1), sum(gY == 2)];
if (isempty (Cost))
this.Cost = cast (! eye (numel (gnY)), "double");
else
if (numel (gnY) != sqrt (numel (Cost)))
error (strcat ("ClassificationGAM: the number of rows", ...
" and columns in 'Cost' must correspond", ...
" to selected classes in Y."));
endif
this.Cost = Cost;
endif

## Assign remaining optional parameters
this.Formula = Formula;
Expand Down Expand Up @@ -1338,22 +1373,6 @@ function savemodel (this, fname)
RSS = sum (res .^ 2);
endfunction

function this = setCost (this, Cost, gnY = [])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to keep this private method. See my comment above

Copy link
Contributor Author

@Sonu0305 Sonu0305 Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same above mentioned reason for removing it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see my comment above

if (isempty (gnY))
[~, gnY, gY] = unique (this.Y(this.RowsUsed));
endif
if (isempty (Cost))
this.Cost = cast (! eye (numel (gnY)), "double");
else
if (numel (gnY) != sqrt (numel (Cost)))
error (strcat ("ClassificationGAM: the number", ...
" of rows and columns in 'Cost' must", ...
" correspond to selected classes in Y."));
endif
this.Cost = Cost;
endif
endfunction

endmethods

endclassdef
Expand Down Expand Up @@ -1435,6 +1454,48 @@ function savemodel (this, fname)
%! assert (a.DoF, [7, 7, 7])
%! assert (a.BaseModel.Intercept, 0.4055, 1e-1)

## Test Prior calculation
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You also need to document this name-value argument in the constructor's help docstring

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    ## @item @qcode{'Prior'} @tab @tab A numeric vector specifying the prior
    ## probabilities for each class.  The order of the elements in @qcode{Prior}
    ## corresponds to the order of the classes in @qcode{ClassNames}.
    ## Alternatively, you can specify @qcode{"empirical"} to use the empirical
    ## class probabilities or @qcode{"uniform"} to assume equal class
    ## probabilities.

Is this what you are talking about? If yes, then it is already added in the file.

%!test
%! ## Test uniform prior
%! x = [1, 2; 3, 4; 5, 6; 7, 8];
%! y = [0; 0; 1; 1];
%! a = ClassificationGAM (x, y, 'Prior', 'uniform');
%! assert (a.Prior, [0.5, 0.5], 1e-6);
%!test
%! ## Test empirical prior
%! x = [1, 2; 3, 4; 5, 6; 7, 8; 9, 10];
%! y = [0; 0; 0; 1; 1];
%! a = ClassificationGAM (x, y, 'Prior', 'empirical');
%! assert (a.Prior, [0.6; 0.4], 1e-6);
%!test
%! ## Test numeric prior
%! x = [1, 2; 3, 4; 5, 6; 7, 8];
%! y = [0; 0; 1; 1];
%! a = ClassificationGAM (x, y, 'Prior', [0.7, 0.3]);
%! assert (a.Prior, [0.7, 0.3], 1e-6);
%!test
%! ## Test default prior (empirical)
%! x = [1, 2; 3, 4; 5, 6; 7, 8; 9, 10; 11, 12];
%! y = [0; 0; 0; 1; 1; 1];
%! a = ClassificationGAM (x, y);
%! assert (a.Prior, [0.5; 0.5], 1e-6);
%!test
%! ## Test prior normalization
%! x = [1, 2; 3, 4; 5, 6; 7, 8];
%! y = [0; 0; 1; 1];
%! a = ClassificationGAM (x, y, 'Prior', [2, 1]);
%! assert (a.Prior, [2/3, 1/3], 1e-6);

## Test input validation for Prior
%!error<ClassificationGAM: 'Prior' must be a 2-element vector.> ...
%! ClassificationGAM (ones(4,2), ones(4,1), "Prior", [1])
%!error<ClassificationGAM: 'Prior' must be a 2-element vector.> ...
%! ClassificationGAM (ones(4,2), ones(4,1), "Prior", [1, 2, 3])
%!error<ClassificationGAM: 'Prior' must be a numeric vector or a string.> ...
%! ClassificationGAM (ones(4,2), ones(4,1), "Prior", {1, 2})
%!error<ClassificationGAM: 'Prior' must be> ...
%! ClassificationGAM (ones(4,2), ones(4,1), "Prior", "invalid")

## Test input validation for constructor
%!error<ClassificationGAM: too few input arguments.> ClassificationGAM ()
%!error<ClassificationGAM: too few input arguments.> ...
Expand Down