|
15 | 15 | from .bregman import sinkhorn
|
16 | 16 | from .lp import emd
|
17 | 17 | from .utils import unif, dist, kernel, cost_normalization
|
18 |
| -from .utils import check_params, deprecated, BaseEstimator |
| 18 | +from .utils import check_params, BaseEstimator |
19 | 19 | from .optim import cg
|
20 | 20 | from .optim import gcg
|
21 | 21 |
|
@@ -740,288 +740,6 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
|
740 | 740 | return A, b
|
741 | 741 |
|
742 | 742 |
|
743 |
| -@deprecated("The class OTDA is deprecated in 0.3.1 and will be " |
744 |
| - "removed in 0.5" |
745 |
| - "\n\tfor standard transport use class EMDTransport instead.") |
746 |
| -class OTDA(object): |
747 |
| - |
748 |
| - """Class for domain adaptation with optimal transport as proposed in [5] |
749 |
| -
|
750 |
| -
|
751 |
| - References |
752 |
| - ---------- |
753 |
| -
|
754 |
| - .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, |
755 |
| - "Optimal Transport for Domain Adaptation," in IEEE Transactions on |
756 |
| - Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 |
757 |
| -
|
758 |
| - """ |
759 |
| - |
760 |
| - def __init__(self, metric='sqeuclidean', norm=None): |
761 |
| - """ Class initialization""" |
762 |
| - self.xs = 0 |
763 |
| - self.xt = 0 |
764 |
| - self.G = 0 |
765 |
| - self.metric = metric |
766 |
| - self.norm = norm |
767 |
| - self.computed = False |
768 |
| - |
769 |
| - def fit(self, xs, xt, ws=None, wt=None, max_iter=100000): |
770 |
| - """Fit domain adaptation between samples is xs and xt |
771 |
| - (with optional weights)""" |
772 |
| - self.xs = xs |
773 |
| - self.xt = xt |
774 |
| - |
775 |
| - if wt is None: |
776 |
| - wt = unif(xt.shape[0]) |
777 |
| - if ws is None: |
778 |
| - ws = unif(xs.shape[0]) |
779 |
| - |
780 |
| - self.ws = ws |
781 |
| - self.wt = wt |
782 |
| - |
783 |
| - self.M = dist(xs, xt, metric=self.metric) |
784 |
| - self.M = cost_normalization(self.M, self.norm) |
785 |
| - self.G = emd(ws, wt, self.M, max_iter) |
786 |
| - self.computed = True |
787 |
| - |
788 |
| - def interp(self, direction=1): |
789 |
| - """Barycentric interpolation for the source (1) or target (-1) samples |
790 |
| -
|
791 |
| - This Barycentric interpolation solves for each source (resp target) |
792 |
| - sample xs (resp xt) the following optimization problem: |
793 |
| -
|
794 |
| - .. math:: |
795 |
| - arg\min_x \sum_i \gamma_{k,i} c(x,x_i^t) |
796 |
| -
|
797 |
| - where k is the index of the sample in xs |
798 |
| -
|
799 |
| - For the moment only squared euclidean distance is provided but more |
800 |
| - metric could be used in the future. |
801 |
| -
|
802 |
| - """ |
803 |
| - if direction > 0: # >0 then source to target |
804 |
| - G = self.G |
805 |
| - w = self.ws.reshape((self.xs.shape[0], 1)) |
806 |
| - x = self.xt |
807 |
| - else: |
808 |
| - G = self.G.T |
809 |
| - w = self.wt.reshape((self.xt.shape[0], 1)) |
810 |
| - x = self.xs |
811 |
| - |
812 |
| - if self.computed: |
813 |
| - if self.metric == 'sqeuclidean': |
814 |
| - return np.dot(G / w, x) # weighted mean |
815 |
| - else: |
816 |
| - print( |
817 |
| - "Warning, metric not handled yet, using weighted average") |
818 |
| - return np.dot(G / w, x) # weighted mean |
819 |
| - return None |
820 |
| - else: |
821 |
| - print("Warning, model not fitted yet, returning None") |
822 |
| - return None |
823 |
| - |
824 |
| - def predict(self, x, direction=1): |
825 |
| - """ Out of sample mapping using the formulation from [6] |
826 |
| -
|
827 |
| - For each sample x to map, it finds the nearest source sample xs and |
828 |
| - map the samle x to the position xst+(x-xs) wher xst is the barycentric |
829 |
| - interpolation of source sample xs. |
830 |
| -
|
831 |
| - References |
832 |
| - ---------- |
833 |
| -
|
834 |
| - .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). |
835 |
| - Regularized discrete optimal transport. SIAM Journal on Imaging |
836 |
| - Sciences, 7(3), 1853-1882. |
837 |
| -
|
838 |
| - """ |
839 |
| - if direction > 0: # >0 then source to target |
840 |
| - xf = self.xt |
841 |
| - x0 = self.xs |
842 |
| - else: |
843 |
| - xf = self.xs |
844 |
| - x0 = self.xt |
845 |
| - |
846 |
| - D0 = dist(x, x0) # dist netween new samples an source |
847 |
| - idx = np.argmin(D0, 1) # closest one |
848 |
| - xf = self.interp(direction) # interp the source samples |
849 |
| - # aply the delta to the interpolation |
850 |
| - return xf[idx, :] + x - x0[idx, :] |
851 |
| - |
852 |
| - |
853 |
| -@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be" |
854 |
| - " removed in 0.5 \nUse class SinkhornTransport instead.") |
855 |
| -class OTDA_sinkhorn(OTDA): |
856 |
| - |
857 |
| - """Class for domain adaptation with optimal transport with entropic |
858 |
| - regularization |
859 |
| -
|
860 |
| -
|
861 |
| - """ |
862 |
| - |
863 |
| - def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs): |
864 |
| - """Fit regularized domain adaptation between samples is xs and xt |
865 |
| - (with optional weights)""" |
866 |
| - self.xs = xs |
867 |
| - self.xt = xt |
868 |
| - |
869 |
| - if wt is None: |
870 |
| - wt = unif(xt.shape[0]) |
871 |
| - if ws is None: |
872 |
| - ws = unif(xs.shape[0]) |
873 |
| - |
874 |
| - self.ws = ws |
875 |
| - self.wt = wt |
876 |
| - |
877 |
| - self.M = dist(xs, xt, metric=self.metric) |
878 |
| - self.M = cost_normalization(self.M, self.norm) |
879 |
| - self.G = sinkhorn(ws, wt, self.M, reg, **kwargs) |
880 |
| - self.computed = True |
881 |
| - |
882 |
| - |
883 |
| -@deprecated("The class OTDA_lpl1 is deprecated in 0.3.1 and will be" |
884 |
| - " removed in 0.5 \nUse class SinkhornLpl1Transport instead.") |
885 |
| -class OTDA_lpl1(OTDA): |
886 |
| - |
887 |
| - """Class for domain adaptation with optimal transport with entropic and |
888 |
| - group regularization""" |
889 |
| - |
890 |
| - def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs): |
891 |
| - """Fit regularized domain adaptation between samples is xs and xt |
892 |
| - (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit |
893 |
| - parameters""" |
894 |
| - self.xs = xs |
895 |
| - self.xt = xt |
896 |
| - |
897 |
| - if wt is None: |
898 |
| - wt = unif(xt.shape[0]) |
899 |
| - if ws is None: |
900 |
| - ws = unif(xs.shape[0]) |
901 |
| - |
902 |
| - self.ws = ws |
903 |
| - self.wt = wt |
904 |
| - |
905 |
| - self.M = dist(xs, xt, metric=self.metric) |
906 |
| - self.M = cost_normalization(self.M, self.norm) |
907 |
| - self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs) |
908 |
| - self.computed = True |
909 |
| - |
910 |
| - |
911 |
| -@deprecated("The class OTDA_l1L2 is deprecated in 0.3.1 and will be" |
912 |
| - " removed in 0.5 \nUse class SinkhornL1l2Transport instead.") |
913 |
| -class OTDA_l1l2(OTDA): |
914 |
| - |
915 |
| - """Class for domain adaptation with optimal transport with entropic |
916 |
| - and group lasso regularization""" |
917 |
| - |
918 |
| - def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs): |
919 |
| - """Fit regularized domain adaptation between samples is xs and xt |
920 |
| - (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit |
921 |
| - parameters""" |
922 |
| - self.xs = xs |
923 |
| - self.xt = xt |
924 |
| - |
925 |
| - if wt is None: |
926 |
| - wt = unif(xt.shape[0]) |
927 |
| - if ws is None: |
928 |
| - ws = unif(xs.shape[0]) |
929 |
| - |
930 |
| - self.ws = ws |
931 |
| - self.wt = wt |
932 |
| - |
933 |
| - self.M = dist(xs, xt, metric=self.metric) |
934 |
| - self.M = cost_normalization(self.M, self.norm) |
935 |
| - self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs) |
936 |
| - self.computed = True |
937 |
| - |
938 |
| - |
939 |
| -@deprecated("The class OTDA_mapping_linear is deprecated in 0.3.1 and will be" |
940 |
| - " removed in 0.5 \nUse class MappingTransport instead.") |
941 |
| -class OTDA_mapping_linear(OTDA): |
942 |
| - |
943 |
| - """Class for optimal transport with joint linear mapping estimation as in |
944 |
| - [8] |
945 |
| - """ |
946 |
| - |
947 |
| - def __init__(self): |
948 |
| - """ Class initialization""" |
949 |
| - |
950 |
| - self.xs = 0 |
951 |
| - self.xt = 0 |
952 |
| - self.G = 0 |
953 |
| - self.L = 0 |
954 |
| - self.bias = False |
955 |
| - self.computed = False |
956 |
| - self.metric = 'sqeuclidean' |
957 |
| - |
958 |
| - def fit(self, xs, xt, mu=1, eta=1, bias=False, **kwargs): |
959 |
| - """ Fit domain adaptation between samples is xs and xt (with optional |
960 |
| - weights)""" |
961 |
| - self.xs = xs |
962 |
| - self.xt = xt |
963 |
| - self.bias = bias |
964 |
| - |
965 |
| - self.ws = unif(xs.shape[0]) |
966 |
| - self.wt = unif(xt.shape[0]) |
967 |
| - |
968 |
| - self.G, self.L = joint_OT_mapping_linear( |
969 |
| - xs, xt, mu=mu, eta=eta, bias=bias, **kwargs) |
970 |
| - self.computed = True |
971 |
| - |
972 |
| - def mapping(self): |
973 |
| - return lambda x: self.predict(x) |
974 |
| - |
975 |
| - def predict(self, x): |
976 |
| - """ Out of sample mapping estimated during the call to fit""" |
977 |
| - if self.computed: |
978 |
| - if self.bias: |
979 |
| - x = np.hstack((x, np.ones((x.shape[0], 1)))) |
980 |
| - return x.dot(self.L) # aply the delta to the interpolation |
981 |
| - else: |
982 |
| - print("Warning, model not fitted yet, returning None") |
983 |
| - return None |
984 |
| - |
985 |
| - |
986 |
| -@deprecated("The class OTDA_mapping_kernel is deprecated in 0.3.1 and will be" |
987 |
| - " removed in 0.5 \nUse class MappingTransport instead.") |
988 |
| -class OTDA_mapping_kernel(OTDA_mapping_linear): |
989 |
| - |
990 |
| - """Class for optimal transport with joint nonlinear mapping |
991 |
| - estimation as in [8]""" |
992 |
| - |
993 |
| - def fit(self, xs, xt, mu=1, eta=1, bias=False, kerneltype='gaussian', |
994 |
| - sigma=1, **kwargs): |
995 |
| - """ Fit domain adaptation between samples is xs and xt """ |
996 |
| - self.xs = xs |
997 |
| - self.xt = xt |
998 |
| - self.bias = bias |
999 |
| - |
1000 |
| - self.ws = unif(xs.shape[0]) |
1001 |
| - self.wt = unif(xt.shape[0]) |
1002 |
| - self.kernel = kerneltype |
1003 |
| - self.sigma = sigma |
1004 |
| - self.kwargs = kwargs |
1005 |
| - |
1006 |
| - self.G, self.L = joint_OT_mapping_kernel( |
1007 |
| - xs, xt, mu=mu, eta=eta, bias=bias, **kwargs) |
1008 |
| - self.computed = True |
1009 |
| - |
1010 |
| - def predict(self, x): |
1011 |
| - """ Out of sample mapping estimated during the call to fit""" |
1012 |
| - |
1013 |
| - if self.computed: |
1014 |
| - K = kernel( |
1015 |
| - x, self.xs, method=self.kernel, sigma=self.sigma, |
1016 |
| - **self.kwargs) |
1017 |
| - if self.bias: |
1018 |
| - K = np.hstack((K, np.ones((x.shape[0], 1)))) |
1019 |
| - return K.dot(self.L) |
1020 |
| - else: |
1021 |
| - print("Warning, model not fitted yet, returning None") |
1022 |
| - return None |
1023 |
| - |
1024 |
| - |
1025 | 743 | def distribution_estimation_uniform(X):
|
1026 | 744 | """estimates a uniform distribution from an array of samples X
|
1027 | 745 |
|
|
0 commit comments