@@ -546,6 +546,10 @@ void Decl::transpile(std::ostream &out, int tab) const
546
546
<< " = "
547
547
<< " _g." ;
548
548
549
+ if (this ->DataType == TypeSpecifier::INT){
550
+ out << " _scalar" ;
551
+ }
552
+
549
553
switch (this ->GradType )
550
554
{
551
555
case GradSpecifier::CNS:
@@ -571,8 +575,9 @@ void InitDeclarator::transpile(std::ostream &out, int tab) const
571
575
}
572
576
573
577
if (this ->initializer != nullptr )
574
- {
575
- out << " , " ;
578
+ {
579
+ if (!this ->declarator ->Dimensions .empty ())
580
+ out << " , " ;
576
581
this ->initializer ->transpile (out, tab);
577
582
}
578
583
}
@@ -651,16 +656,22 @@ void Expr::transpile(std::ostream &out, int tab) const
651
656
void GradStmt::transpile (std::ostream &out, int tab) const
652
657
{
653
658
if (this ->grad_type == GradType::GRAD)
654
- {
655
- out << std::string (" \t " , tab) << this ->name << " ->gradient.print();" << std::endl;
659
+ {
660
+ SymTabItem *item = search (root->symbolTable , this ->name );
661
+ if (item->type != " Tensor" )
662
+ out << std::string (" \t " , tab) << " std::cout << " << this ->name << " ->scalar_gradient" << " << std::endl;" << std::endl;
663
+ else
664
+ out << std::string (" \t " , tab) << this ->name << " ->gradient.print();" << std::endl;
665
+ }
666
+ else if (this ->grad_type == GradType::PRINT){
667
+ SymTabItem *item = search (root->symbolTable , this ->name );
668
+ if (item->type != " Tensor" )
669
+ out << std::string (" \t " , tab) << " std::cout << " << this ->name <<" ->ddata " <<" << std::endl;" << std::endl;
670
+ else
671
+ out << std::string (" \t " , tab) << this ->name << " ->" <<" data.print();" << std::endl;
656
672
}
657
673
else
658
674
{
659
675
out << std::string (" \t " , tab) << " _g." << GradTypeMapCpp[this ->grad_type ] << " (" << this ->name << " );" << std::endl;
660
676
}
661
- }
662
-
663
- // int main()
664
- // {
665
- // return 0;
666
- // }
677
+ }
0 commit comments