@@ -240,7 +240,7 @@ namespace fir
240
240
}
241
241
242
242
243
- bool ClassType::isInParentHierarchy (Type* base)
243
+ bool ClassType::hasParent (Type* base)
244
244
{
245
245
auto target = dcast (ClassType, base);
246
246
if (!target) return false ;
@@ -294,41 +294,92 @@ namespace fir
294
294
this ->reverseVirtualMethodMap = this ->baseClass ->reverseVirtualMethodMap ;
295
295
}
296
296
297
- void ClassType::addVirtualMethod (Function* method)
297
+
298
+ // expects the self param to be removed already!!!
299
+ // note: this one doesn't check if the return types are compatible; we expect typechecking to have already
300
+ // verified that, and we don't store the return type in the class virtual method map anyway.
301
+ static bool _areTypeListsVirtuallyCompatible (const std::vector<Type*>& base, const std::vector<Type*>& fn)
298
302
{
299
- // * what this does is compare the arguments without the first parameter,
300
- // * since that's going to be the self parameter, and that's going to be different
301
- auto withoutself = [](std::vector<Type*> p) -> std::vector<Type*> {
302
- p.erase (p.begin ());
303
+ // parameters must be contravariant, ie. fn must take more general types than base
304
+ // return type must be covariant, ie. fn must return a more specific type than base.
303
305
304
- return p;
305
- };
306
+ // duh
307
+ if (base.size () != fn.size ())
308
+ return false ;
306
309
307
- auto matching = [&withoutself](const std::vector<Type*>& a, FunctionType* ft) -> bool {
308
- auto bp = withoutself (ft->getArgumentTypes ());
310
+ // drop the first argument.
311
+ for (auto [ base, derv ] : util::zip (base, fn))
312
+ {
313
+ if (base == derv)
314
+ continue ;
309
315
310
- // * note: we don't call withoutself on 'a' because we expect that to already have been done
311
- // * before it was added.
312
- return Type::areTypeListsEqual (a, bp);
313
- };
316
+ if (!derv->isPointerType () || !derv->getPointerElementType ()->isClassType ()
317
+ || !base->isPointerType () || !base->getPointerElementType ()->isClassType ())
318
+ {
319
+ return false ;
320
+ }
321
+
322
+ auto bc = base->getPointerElementType ()->toClassType ();
323
+ auto dc = derv->getPointerElementType ()->toClassType ();
324
+
325
+ if (!bc->hasParent (dc))
326
+ {
327
+ debuglogln (" %s is not a parent of %s" , dc->str (), bc->str ());
328
+ return false ;
329
+ }
330
+ }
331
+
332
+ return true ;
333
+ }
314
334
335
+ bool ClassType::areMethodsVirtuallyCompatible (FunctionType* base, FunctionType* fn)
336
+ {
337
+ bool ret = _areTypeListsVirtuallyCompatible (util::drop (base->getArgumentTypes (), 1 ), util::drop (fn->getArgumentTypes (), 1 ));
338
+
339
+ if (!ret)
340
+ return false ;
341
+
342
+ auto baseRet = base->getReturnType ();
343
+ auto fnRet = fn->getReturnType ();
344
+
345
+ // ok now check the return type.
346
+ if (baseRet == fnRet)
347
+ return true ;
348
+
349
+ if (baseRet->isPointerType () && baseRet->getPointerElementType ()->isClassType ()
350
+ && fnRet->isPointerType () && fnRet->getPointerElementType ()->isClassType ())
351
+ {
352
+ auto br = baseRet->getPointerElementType ()->toClassType ();
353
+ auto dr = fnRet->getPointerElementType ()->toClassType ();
354
+
355
+ return dr->hasParent (br);
356
+ }
357
+ else
358
+ {
359
+ return false ;
360
+ }
361
+ }
362
+
363
+ void ClassType::addVirtualMethod (Function* method)
364
+ {
315
365
// * note: the 'reverse' virtual method map is to allow us, at translation time, to easily create the vtable without
316
366
// * unnecessary searching. When we set a base class, we copy its 'reverse' map; thus, if we don't override anything,
317
367
// * our vtable will just refer to the methods in the base class.
318
368
319
369
// * but if we do override something, we just set the method in our 'reverse' map, which is what we'll use to build
320
370
// * the vtable. simple?
321
371
322
- auto list = method->getType ()->toFunctionType ()->getArgumentTypes ();
372
+ auto list = util::drop ( method->getType ()->toFunctionType ()->getArgumentTypes (), 1 );
323
373
324
374
// check every member of the current mapping -- not the fastest method i admit.
325
375
bool found = false ;
326
376
for (auto vm : this ->virtualMethodMap )
327
377
{
328
- if (vm.first .first == method->getName ().name && matching (vm.first .second , method->getType ()->toFunctionType ()))
378
+ if (vm.first .first == method->getName ().name
379
+ && _areTypeListsVirtuallyCompatible (vm.first .second , list))
329
380
{
330
381
found = true ;
331
- this ->virtualMethodMap [{ method->getName ().name , withoutself ( list) }] = vm.second ;
382
+ this ->virtualMethodMap [{ method->getName ().name , list }] = vm.second ;
332
383
this ->reverseVirtualMethodMap [vm.second ] = method;
333
384
break ;
334
385
}
@@ -337,7 +388,7 @@ namespace fir
337
388
if (!found)
338
389
{
339
390
// just make a new one.
340
- this ->virtualMethodMap [{ method->getName ().name , withoutself ( list) }] = this ->virtualMethodCount ;
391
+ this ->virtualMethodMap [{ method->getName ().name , list }] = this ->virtualMethodCount ;
341
392
this ->reverseVirtualMethodMap [this ->virtualMethodCount ] = method;
342
393
this ->virtualMethodCount ++;
343
394
}
0 commit comments