@@ -156,6 +156,113 @@ static size_t sizeof_dtype(const DTypeId dt) {
156156 };
157157};
158158
159+ static bool is_float (DTypeId t) {
160+ switch (t) {
161+ case DDPT::DTypeId::FLOAT64:
162+ case DDPT::DTypeId::FLOAT32:
163+ return true ;
164+ default :
165+ return false ;
166+ }
167+ }
168+
169+ static bool is_int (DTypeId t) {
170+ switch (t) {
171+ case DDPT::DTypeId::INT64:
172+ case DDPT::DTypeId::INT32:
173+ case DDPT::DTypeId::INT16:
174+ case DDPT::DTypeId::INT8:
175+ return true ;
176+ default :
177+ return false ;
178+ }
179+ }
180+
181+ static bool is_uint (DTypeId t) {
182+ switch (t) {
183+ case DDPT::DTypeId::UINT64:
184+ case DDPT::DTypeId::UINT32:
185+ case DDPT::DTypeId::UINT16:
186+ case DDPT::DTypeId::UINT8:
187+ case DDPT::DTypeId::BOOL:
188+ return true ;
189+ default :
190+ return false ;
191+ }
192+ }
193+
194+ static size_t dtype_bitwidth (DTypeId t) {
195+ switch (t) {
196+ case DDPT::DTypeId::FLOAT64:
197+ case DDPT::DTypeId::INT64:
198+ case DDPT::DTypeId::UINT64:
199+ return 64 ;
200+ case DDPT::DTypeId::FLOAT32:
201+ case DDPT::DTypeId::INT32:
202+ case DDPT::DTypeId::UINT32:
203+ return 32 ;
204+ case DDPT::DTypeId::INT16:
205+ case DDPT::DTypeId::UINT16:
206+ return 16 ;
207+ case DDPT::DTypeId::INT8:
208+ case DDPT::DTypeId::UINT8:
209+ return 8 ;
210+ case DDPT::DTypeId::BOOL:
211+ return 1 ;
212+ default :
213+ assert (!" Unknown DTypeId" );
214+ }
215+ }
216+
217+ static DTypeId get_float_dtype (size_t bitwidth) {
218+ switch (bitwidth) {
219+ case 64 :
220+ return DDPT::DTypeId::FLOAT64;
221+ case 32 :
222+ return DDPT::DTypeId::FLOAT32;
223+ default :
224+ assert (!" Unknown bitwidth" );
225+ }
226+ }
227+
228+ static DTypeId get_int_dtype (size_t bitwidth) {
229+ switch (bitwidth) {
230+ case 64 :
231+ return DDPT::DTypeId::INT64;
232+ case 32 :
233+ return DDPT::DTypeId::INT32;
234+ case 16 :
235+ return DDPT::DTypeId::INT16;
236+ case 8 :
237+ return DDPT::DTypeId::INT8;
238+ default :
239+ assert (!" Unknown bitwidth" );
240+ }
241+ }
242+
243+ static DTypeId promoted_dtype (DTypeId a, DTypeId b) {
244+ if ((is_float (a) && is_float (b)) || (is_int (a) && is_int (b)) ||
245+ (is_uint (a) && is_uint (b))) {
246+ return dtype_bitwidth (a) > dtype_bitwidth (b) ? a : b;
247+ }
248+ if (is_float (a) || is_float (b)) {
249+ return get_float_dtype (std::max (dtype_bitwidth (a), dtype_bitwidth (b)));
250+ }
251+ // mixed signed/unsigned int case
252+ size_t si_width, ui_width, max_width = 64 ;
253+ if (is_uint (a)) {
254+ ui_width = dtype_bitwidth (a);
255+ si_width = dtype_bitwidth (b);
256+ } else {
257+ ui_width = dtype_bitwidth (b);
258+ si_width = dtype_bitwidth (a);
259+ }
260+ if (ui_width < si_width) {
261+ return get_int_dtype (si_width);
262+ }
263+ return get_int_dtype (std::min (2 * ui_width, max_width));
264+ }
265+
159266using RedOpType = ReduceOpId;
160267
161268inline RedOpType red_op (const char *op) {
0 commit comments