diff --git a/sql-delta-import/src/main/scala/com/razorpay/spark/jdbc/JDBCImport.scala b/sql-delta-import/src/main/scala/com/razorpay/spark/jdbc/JDBCImport.scala
index ac40c37e3..b200d8e17 100644
--- a/sql-delta-import/src/main/scala/com/razorpay/spark/jdbc/JDBCImport.scala
+++ b/sql-delta-import/src/main/scala/com/razorpay/spark/jdbc/JDBCImport.scala
@@ -20,9 +20,12 @@ import com.razorpay.spark.jdbc.common.Constants
 import org.apache.spark.sql.functions.{col, from_unixtime, lit, substring}
 import org.apache.spark.sql.types.{IntegerType, LongType}
 import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
+import org.slf4j.{Logger, LoggerFactory}
+
 import scala.sys.process._
 import java.util.Properties
 
+
 /**
  * Class that contains JDBC source, read parallelism params and target table name
  *
@@ -54,10 +57,15 @@ case class ImportConfig(
   val escapeCharacter = if (dbType == Constants.MYSQL) {
     "`"
   } else if (dbType == Constants.POSTGRESQL) {
-    """""""
+    ""
+  }
+
+  var inputTableEscaped: String = escapeCharacter + inputTable + escapeCharacter
+
+  if (dbType == Constants.POSTGRESQL) {
+    inputTableEscaped = schema.get + "." + inputTableEscaped
   }
-  val inputTableEscaped: String = escapeCharacter + inputTable + escapeCharacter
 
   val boundsSql: String = boundaryQuery.getOrElse(
     s"(select min($splitColumn) as min, max($splitColumn) as max from $inputTableEscaped) as bounds"
   )
@@ -81,6 +89,8 @@ class JDBCImport(
 )(implicit val spark: SparkSession) {
   import spark.implicits._
 
+  val logger: Logger = LoggerFactory.getLogger(this.getClass)
+
   def createDbIfNotExists(outputDbName: String): Unit = {
     val s3Bucket = Credentials.getSecretValue("SQOOP_S3_BUCKET")
 
@@ -130,11 +140,7 @@ class JDBCImport(
     val database = importConfig.database
     val schema = importConfig.schema
 
-    var connectionUrl = s"jdbc:$dbType://$host:$port/$database"
-
-    if (dbType == Constants.POSTGRESQL && schema.isDefined) {
-      connectionUrl = s"jdbc:$dbType://$host:$port/$database?currentSchema=${schema.get}"
-    }
+    val connectionUrl = s"jdbc:$dbType://$host:$port/$database"
 
     connectionUrl
   }
@@ -150,6 +156,14 @@ class JDBCImport(
     if (importConfig.splitBy.nonEmpty) {
      val defaultString = "0"
 
+      val dbType = Credentials.getSecretValue(s"${databricksScope}_DB_TYPE")
+      val schema = importConfig.schema
+
+      var dbTable = importConfig.jdbcQuery
+
+      logger.info(s"JDBC 1: jdbcUrl $buildJdbcUrl and dbTable $dbTable")
+
+
       val (lower, upper) = spark.read
         .jdbc(buildJdbcUrl, importConfig.boundsSql, jdbcParams)
         .selectExpr("cast(min as string) min", "cast(max as string) max")
@@ -160,11 +174,13 @@
 
       val jdbcUsername = Credentials.getSecretValue(s"${databricksScope}_DB_USERNAME")
       val jdbcPassword = Credentials.getSecretValue(s"${databricksScope}_DB_PASSWORD")
+      val driverType = DriverType.getJdbcDriver(dbType)
 
       spark.read
         .format("jdbc")
+        .option("driver", driverType)
         .option("url", buildJdbcUrl)
-        .option("dbtable", importConfig.jdbcQuery)
+        .option("dbtable", dbTable)
         .option("user", jdbcUsername)
         .option("password", jdbcPassword)
         .option("partitionColumn", importConfig.splitColumn)
@@ -341,4 +357,14 @@ object Credentials {
   }
 }
 
+object DriverType {
+  def getJdbcDriver(dbtype: String): String = {
+    dbtype.toLowerCase match {
+      case Constants.MYSQL => "com.mysql.cj.jdbc.Driver"
+      case Constants.POSTGRESQL => "org.postgresql.Driver"
+      case _ => throw new IllegalArgumentException(s"Unsupported dbtype: $dbtype")
+    }
+  }
+
+}
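
The `ImportConfig` hunk replaces the `?currentSchema=` connection parameter (removed further down in the diff) with a schema-qualified table name, but `schema.get` will throw for a Postgres source with no schema configured, and the `if/else if` building `escapeCharacter` has no final `else`. A minimal standalone sketch of the same idea with both cases handled explicitly; the helper name `qualifiedTable` and the lowercase values of `Constants.MYSQL`/`Constants.POSTGRESQL` are assumptions, not the PR's code:

```scala
// Hypothetical sketch: builds the escaped, optionally schema-qualified
// table name that boundsSql interpolates into its min/max query.
def qualifiedTable(dbType: String, inputTable: String, schema: Option[String]): String = {
  // Assumes Constants.MYSQL == "mysql" and Constants.POSTGRESQL == "postgresql"
  val escape = dbType match {
    case "mysql" => "`" // backtick-quote MySQL identifiers
    case _       => ""  // the PR drops identifier quoting for Postgres
  }
  val escaped = s"$escape$inputTable$escape"
  (dbType, schema) match {
    case ("postgresql", Some(s)) => s"$s.$escaped" // e.g. "analytics.payments"
    case _                       => escaped        // no schema prefix otherwise
  }
}
```

Matching on `(dbType, schema)` avoids the `NoSuchElementException` that `schema.get` raises when `schema` is `None`.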
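
For context on the read path this diff touches: Spark's JDBC source turns `partitionColumn`, `lowerBound`, `upperBound`, and `numPartitions` into one range predicate per partition, which is why `boundsSql` first fetches the split column's min and max. A self-contained sketch of that read, with placeholder connection values standing in for the project's secrets:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

// Sketch of the partitioned JDBC read configured in the diff. Spark issues
// numPartitions parallel queries, each with a WHERE clause covering a slice
// of [lowerBound, upperBound] on the partition column.
def partitionedRead(spark: SparkSession): DataFrame =
  spark.read
    .format("jdbc")
    .option("driver", "org.postgresql.Driver")            // value from DriverType
    .option("url", "jdbc:postgresql://db-host:5432/app")  // placeholder URL
    .option("dbtable", "(select * from analytics.payments) as t") // placeholder query
    .option("user", "reader")                             // placeholder user
    .option("password", sys.env.getOrElse("DB_PASSWORD", "")) // placeholder secret
    .option("partitionColumn", "id") // importConfig.splitColumn in the PR
    .option("lowerBound", "0")       // min returned by boundsSql
    .option("upperBound", "1000000") // max returned by boundsSql
    .option("numPartitions", "8")    // read parallelism
    .load()
```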
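
Hypothetical usage of the new `DriverType` helper; pinning the `driver` option stops Spark from inferring the driver class from the URL, which matters when both the MySQL and Postgres JDBC drivers sit on the classpath:

```scala
// Assumes Constants.MYSQL/POSTGRESQL hold the lowercase names "mysql"/"postgresql".
val mysqlDriver = DriverType.getJdbcDriver("MySQL")      // "com.mysql.cj.jdbc.Driver"
val pgDriver    = DriverType.getJdbcDriver("postgresql") // "org.postgresql.Driver"
// DriverType.getJdbcDriver("oracle") throws IllegalArgumentException
```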