@@ -2895,7 +2895,24 @@ at::Tensor viewStorage(const at::Tensor input, const c10::IntArrayRef sizes, con
2895
2895
if (st != -1 ) st *= sizes[i - 1 ];
2896
2896
}
2897
2897
}
2898
- return fromPreAllocated (input.data_ptr () + storageOffset * input.itemsize (), sizes, stridesVec, input.options ());
2898
+
2899
+ // when shape[0]=-1, fill data
2900
+ std::vector<int64_t > sizeVec (sizes.size (), 1 );
2901
+ std::copy (sizes.begin (), sizes.end (), sizeVec.begin ());
2902
+ if (!sizes.empty () && sizes[0 ] == -1 ) {
2903
+ bool flag = true ;
2904
+ for (auto i : sizes) {
2905
+ if (!flag && i < 0 ) {
2906
+ TORCH_CHECK (false , " more than one -1, sizes=" , sizes);
2907
+ }
2908
+ if (i < 0 ) {
2909
+ flag = false ;
2910
+ }
2911
+ }
2912
+ int count = std::accumulate (sizeVec.begin () + 1 , sizeVec.end (), 1 , std::multiplies<int >());
2913
+ sizeVec[0 ] = input.numel () / count;
2914
+ }
2915
+ return fromPreAllocated (input.data_ptr () + storageOffset * input.itemsize (), sizeVec, stridesVec, input.options ());
2899
2916
}
2900
2917
2901
2918
c10::List<c10::optional<at::Tensor>> castIntIndicesToLongIndices (const c10::List<c10::optional<at::Tensor>>& indices) {
@@ -3057,7 +3074,11 @@ at::Tensor wrapper__transpose(const at::Tensor& self, int64_t dim0, int64_t dim1
3057
3074
}
3058
3075
3059
3076
at::Scalar wrapper___local_scalar_dense (const at::Tensor& self) { return at_npu::native::NPUNativeFunctions::_local_scalar_dense (self); }
3077
+ at::Tensor& wrapper_out_mm_out (const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) { return acl_op::mm_out (self, mat2, out); }
3060
3078
3079
+ at::Tensor& wrapper_source_Tensor_set_ (at::Tensor& self, const at::Tensor& source) { return at_npu::native::NPUNativeFunctions::set_ (self, source); }
3080
+ at::Tensor& wrapper_out_bmm_out (const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) { return acl_op::bmm_out (self, mat2, out); }
3081
+ at::Tensor wrapper__dot (const at::Tensor& self, const at::Tensor& tensor) { return acl_op::dot (self, tensor); }
3061
3082
} // namespace
3062
3083
3063
3084
namespace at {
@@ -3092,6 +3113,11 @@ TORCH_LIBRARY_IMPL(aten, XLA, m) {
3092
3113
m.impl (" repeat" , TORCH_FN (wrapper__repeat));
3093
3114
m.impl (" transpose.int" , TORCH_FN (wrapper__transpose));
3094
3115
m.impl (" _local_scalar_dense" , TORCH_FN (wrapper___local_scalar_dense));
3116
+ m.impl (" cat" , TORCH_FN (wrapper__cat));
3117
+ m.impl (" mm.out" , TORCH_FN (wrapper_out_mm_out));
3118
+ m.impl (" set_.source_Tensor" , TORCH_FN (wrapper_source_Tensor_set_));
3119
+ m.impl (" dot" , TORCH_FN (wrapper__dot));
3120
+ m.impl (" bmm.out" , TORCH_FN (wrapper_out_bmm_out));
3095
3121
};
3096
3122
3097
3123
TORCH_LIBRARY_IMPL (_, XLA, m) { m.fallback (torch::CppFunction::makeFromBoxedFunction<&ascend_diopi_fallback>()); }
0 commit comments