Skip to content

Commit f5e9a13

Browse files
committed
etter doc for classes
1 parent a5f2569 commit f5e9a13

File tree

1 file changed

+31
-30
lines changed

1 file changed

+31
-30
lines changed

ot/da.py

+31-30
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
151151
152152
The algorithm used for solving the problem is the block coordinate
153153
descent that alternates between updates of G (using conditionnal gradient)
154-
abd the update of L using a classical least square solver.
154+
and the update of L using a classical least square solver.
155155
156156
157157
Parameters
@@ -320,7 +320,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
320320
321321
The algorithm used for solving the problem is the block coordinate
322322
descent that alternates between updates of G (using conditionnal gradient)
323-
abd the update of L using a classical kernel least square solver.
323+
and the update of L using a classical kernel least square solver.
324324
325325
326326
Parameters
@@ -492,7 +492,15 @@ def df(G):
492492

493493

494494
class OTDA(object):
495-
"""Class for domain adaptation with optimal transport"""
495+
"""Class for domain adaptation with optimal transport as proposed in [5]
496+
497+
498+
References
499+
----------
500+
501+
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
502+
503+
"""
496504

497505
def __init__(self,metric='sqeuclidean'):
498506
""" Class initialization"""
@@ -504,8 +512,7 @@ def __init__(self,metric='sqeuclidean'):
504512

505513

506514
def fit(self,xs,xt,ws=None,wt=None):
507-
""" Fit domain adaptation between samples is xs and xt (with optional
508-
weights)"""
515+
""" Fit domain adaptation between samples is xs and xt (with optional weights)"""
509516
self.xs=xs
510517
self.xt=xt
511518

@@ -522,7 +529,7 @@ def fit(self,xs,xt,ws=None,wt=None):
522529
self.computed=True
523530

524531
def interp(self,direction=1):
525-
"""Barycentric interpolation for the source (1) or target (-1)
532+
"""Barycentric interpolation for the source (1) or target (-1) samples
526533
527534
This Barycentric interpolation solves for each source (resp target)
528535
sample xs (resp xt) the following optimization problem:
@@ -558,10 +565,16 @@ def interp(self,direction=1):
558565

559566

560567
def predict(self,x,direction=1):
561-
""" Out of sample mapping using the formulation from Ferradans
568+
""" Out of sample mapping using the formulation from [6]
569+
570+
For each sample x to map, it finds the nearest source sample xs and
571+
map the samle x to the position xst+(x-xs) wher xst is the barycentric
572+
interpolation of source sample xs.
573+
574+
References
575+
----------
562576
563-
It basically find the source sample the nearset to the nex sample and
564-
apply the difference to the displaced source sample.
577+
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
565578
566579
"""
567580
if direction>0: # >0 then source to target
@@ -582,8 +595,7 @@ class OTDA_sinkhorn(OTDA):
582595
"""Class for domain adaptation with optimal transport with entropic regularization"""
583596

584597
def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
585-
""" Fit domain adaptation between samples is xs and xt (with optional
586-
weights)"""
598+
""" Fit regularized domain adaptation between samples is xs and xt (with optional weights)"""
587599
self.xs=xs
588600
self.xt=xt
589601

@@ -601,12 +613,12 @@ def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
601613

602614

603615
class OTDA_lpl1(OTDA):
604-
"""Class for domain adaptation with optimal transport with entropic an group regularization"""
616+
"""Class for domain adaptation with optimal transport with entropic and group regularization"""
605617

606618

607619
def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
608-
""" Fit domain adaptation between samples is xs and xt (with optional
609-
weights)"""
620+
""" Fit regularized domain adaptation between samples is xs and xt (with optional weights),
621+
See ot.da.sinkhorn_lpl1_mm for fit parameters""""
610622
self.xs=xs
611623
self.xt=xt
612624

@@ -623,7 +635,7 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
623635
self.computed=True
624636

625637
class OTDA_mapping_linear(OTDA):
626-
"""Class for optimal transport with joint linear mapping estimation"""
638+
"""Class for optimal transport with joint linear mapping estimation as in [8]"""
627639

628640

629641
def __init__(self):
@@ -657,12 +669,7 @@ def mapping(self):
657669

658670

659671
def predict(self,x):
660-
""" Out of sample mapping using the formulation from Ferradans
661-
662-
It basically find the source sample the nearset to the nex sample and
663-
apply the difference to the displaced source sample.
664-
665-
"""
672+
""" Out of sample mapping estimated during the call to fit"""
666673
if self.computed:
667674
if self.bias:
668675
x=np.hstack((x,np.ones((x.shape[0],1))))
@@ -672,13 +679,12 @@ def predict(self,x):
672679
return None
673680

674681
class OTDA_mapping_kernel(OTDA_mapping_linear):
675-
"""Class for optimal transport with joint linear mapping estimation"""
682+
"""Class for optimal transport with joint nonlinear mapping estimation as in [8]"""
676683

677684

678685

679686
def fit(self,xs,xt,mu=1,eta=1,bias=False,kerneltype='gaussian',sigma=1,**kwargs):
680-
""" Fit domain adaptation between samples is xs and xt (with optional
681-
weights)"""
687+
""" Fit domain adaptation between samples is xs and xt """
682688
self.xs=xs
683689
self.xt=xt
684690
self.bias=bias
@@ -695,12 +701,7 @@ def fit(self,xs,xt,mu=1,eta=1,bias=False,kerneltype='gaussian',sigma=1,**kwargs)
695701

696702

697703
def predict(self,x):
698-
""" Out of sample mapping using the formulation from Ferradans
699-
700-
It basically find the source sample the nearset to the nex sample and
701-
apply the difference to the displaced source sample.
702-
703-
"""
704+
""" Out of sample mapping estimated during the call to fit"""
704705

705706
if self.computed:
706707
K=kernel(x,self.xs,method=self.kernel,sigma=self.sigma,**self.kwargs)

0 commit comments

Comments
 (0)