Skip to content

Commit

Permalink
Merge pull request #23 from razorpay/emr_sqoop_fix
Browse files Browse the repository at this point in the history
added Driver support for EMR DE-3061
  • Loading branch information
manishsingh-rzp authored Jul 23, 2024
2 parents 1280cca + fbe0d92 commit a771ba4
Showing 1 changed file with 34 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ import com.razorpay.spark.jdbc.common.Constants
import org.apache.spark.sql.functions.{col, from_unixtime, lit, substring}
import org.apache.spark.sql.types.{IntegerType, LongType}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.slf4j.{Logger, LoggerFactory}

import scala.sys.process._
import java.util.Properties


/**
* Class that contains JDBC source, read parallelism params and target table name
*
Expand Down Expand Up @@ -54,10 +57,15 @@ case class ImportConfig(
val escapeCharacter = if (dbType == Constants.MYSQL) {
"`"
} else if (dbType == Constants.POSTGRESQL) {
"""""""
""
}

var inputTableEscaped: String = escapeCharacter + inputTable + escapeCharacter

if (dbType == Constants.POSTGRESQL){
inputTableEscaped = schema.get + "."+ inputTableEscaped
}

val inputTableEscaped: String = escapeCharacter + inputTable + escapeCharacter

val boundsSql: String = boundaryQuery.getOrElse(
s"(select min($splitColumn) as min, max($splitColumn) as max from $inputTableEscaped) as bounds"
Expand All @@ -81,6 +89,8 @@ class JDBCImport(
)(implicit val spark: SparkSession) {

import spark.implicits._
val logger: Logger = LoggerFactory.getLogger(this.getClass)


def createDbIfNotExists(outputDbName: String): Unit = {
val s3Bucket = Credentials.getSecretValue("SQOOP_S3_BUCKET")
Expand Down Expand Up @@ -130,11 +140,7 @@ class JDBCImport(
val database = importConfig.database
val schema = importConfig.schema

var connectionUrl = s"jdbc:$dbType://$host:$port/$database"

if (dbType == Constants.POSTGRESQL && schema.isDefined) {
connectionUrl = s"jdbc:$dbType://$host:$port/$database?currentSchema=${schema.get}"
}
val connectionUrl = s"jdbc:$dbType://$host:$port/$database"

connectionUrl
}
Expand All @@ -150,6 +156,14 @@ class JDBCImport(

if (importConfig.splitBy.nonEmpty) {
val defaultString = "0"
val dbType = Credentials.getSecretValue(s"${databricksScope}_DB_TYPE")
val schema = importConfig.schema

var dbTable = importConfig.jdbcQuery

logger.error(s"JDBC 1: jdbcUrl $buildJdbcUrl and dbTable $dbTable")


val (lower, upper) = spark.read
.jdbc(buildJdbcUrl, importConfig.boundsSql, jdbcParams)
.selectExpr("cast(min as string) min", "cast(max as string) max")
Expand All @@ -160,11 +174,13 @@ class JDBCImport(

val jdbcUsername = Credentials.getSecretValue(s"${databricksScope}_DB_USERNAME")
val jdbcPassword = Credentials.getSecretValue(s"${databricksScope}_DB_PASSWORD")
val driverType = DriverType.getJdbcDriver(dbType)

spark.read
.format("jdbc")
.option("driver",driverType)
.option("url", buildJdbcUrl)
.option("dbtable", importConfig.jdbcQuery)
.option("dbtable", dbTable)
.option("user", jdbcUsername)
.option("password", jdbcPassword)
.option("partitionColumn", importConfig.splitColumn)
Expand Down Expand Up @@ -341,4 +357,14 @@ object Credentials {
}
}

/**
 * Maps a database type name to the JDBC driver class passed to Spark's
 * `driver` option.
 */
object DriverType {

  /**
   * Returns the fully qualified JDBC driver class name for `dbtype`.
   *
   * The value is lower-cased before being compared with `Constants.MYSQL`
   * and `Constants.POSTGRESQL`, so matching is case-insensitive.
   *
   * @param dbtype database type name; may be in any letter case
   * @return the driver class name for the matching database type
   * @throws IllegalArgumentException if `dbtype` is null or does not match a
   *                                  supported database type
   */
  def getJdbcDriver(dbtype: String): String =
    // Option(...) turns a null dbtype into None so it reaches the descriptive
    // error below instead of failing with a NullPointerException.
    // Locale.ROOT keeps the lower-casing independent of the JVM default
    // locale (e.g. Turkish dotless-i rules).
    Option(dbtype).map(_.toLowerCase(java.util.Locale.ROOT)) match {
      case Some(Constants.MYSQL) => "com.mysql.cj.jdbc.Driver"
      case Some(Constants.POSTGRESQL) => "org.postgresql.Driver"
      case _ => throw new IllegalArgumentException(s"Unsupported dbtype: $dbtype")
    }
}

0 comments on commit a771ba4

Please sign in to comment.