From 570491705eec847ceca8588aafe967bb00566bdd Mon Sep 17 00:00:00 2001
From: aacebo <aacebowork@gmail.com>
Date: Fri, 22 Nov 2024 16:06:41 -0500
Subject: [PATCH] add join

---
 sqlx/join.go                                  | 67 ++++++++++++++++++
 sqlx/select.go                                | 16 +++++
 sqlx/select_test.go                           | 68 +++++++++++++++++++
 sqlx/testcases/select/left_join.sql           |  1 +
 sqlx/testcases/select/left_join_pretty.sql    |  6 ++
 sqlx/testcases/select/left_outer_join.sql     |  1 +
 .../select/left_outer_join_pretty.sql         |  6 ++
 7 files changed, 165 insertions(+)
 create mode 100644 sqlx/join.go
 create mode 100644 sqlx/testcases/select/left_join.sql
 create mode 100644 sqlx/testcases/select/left_join_pretty.sql
 create mode 100644 sqlx/testcases/select/left_outer_join.sql
 create mode 100644 sqlx/testcases/select/left_outer_join_pretty.sql

diff --git a/sqlx/join.go b/sqlx/join.go
new file mode 100644
index 0000000..6d8f577
--- /dev/null
+++ b/sqlx/join.go
@@ -0,0 +1,67 @@
+package sqlx
+
+import (
+	"fmt"
+	"strings"
+)
+
+type JoinClause struct {
+	method *string
+	table  string
+	where  *WhereClause
+}
+
+func LeftJoin(table string, predicate any) *JoinClause {
+	method := "LEFT"
+	return &JoinClause{&method, table, Where(predicate)}
+}
+
+func LeftOuterJoin(table string, predicate any) *JoinClause {
+	method := "LEFT OUTER"
+	return &JoinClause{&method, table, Where(predicate)}
+}
+
+func (self *JoinClause) And(predicate any) *JoinClause {
+	self.where.And(predicate)
+	return self
+}
+
+func (self *JoinClause) Or(predicate any) *JoinClause {
+	self.where.Or(predicate)
+	return self
+}
+
+func (self JoinClause) Sql() string {
+	parts := []string{}
+
+	if self.method != nil {
+		parts = append(parts, *self.method)
+	}
+
+	parts = append(parts, "JOIN", self.table, "ON")
+	parts = append(parts, self.where.Sql())
+	return strings.Join(parts, " ")
+}
+
+func (self JoinClause) SqlPretty(indent string) string {
+	parts := []string{}
+
+	if self.method != nil {
+		parts = append(parts, fmt.Sprintf("%s JOIN %s", *self.method, self.table))
+	} else {
+		parts = append(parts, fmt.Sprintf("JOIN %s", self.table))
+	}
+
+	lines := strings.Split(self.where.SqlPretty(indent), "\n")
+	parts = append(parts, indent+"ON "+lines[0])
+
+	for _, line := range lines[1:] {
+		parts = append(parts, indent+line)
+	}
+
+	return strings.Join(parts, "\n")
+}
+
+func (self *JoinClause) setDepth(_ uint) {
+
+}
diff --git a/sqlx/select.go b/sqlx/select.go
index ae0cf55..68a8ed1 100644
--- a/sqlx/select.go
+++ b/sqlx/select.go
@@ -10,6 +10,7 @@ type SelectStatement struct {
 	columns Columns
 	from    Sqlizer
 	where   *WhereClause
+	joins   []Sqlizer
 	groupBy Sqlizer
 	orderBy Sqlizer
 	limit   Sqlizer
@@ -26,6 +27,7 @@ func Select(columns ...any) *SelectStatement {
 	return &SelectStatement{
 		depth:   0,
 		columns: cols,
+		joins:   []Sqlizer{},
 	}
 }
 
@@ -39,6 +41,11 @@ func (self *SelectStatement) From(from any) *SelectStatement {
 	return self
 }
 
+func (self *SelectStatement) Join(join Sqlizer) *SelectStatement {
+	self.joins = append(self.joins, join)
+	return self
+}
+
 func (self *SelectStatement) Where(predicate any) *SelectStatement {
 	self.where = Where(predicate)
 	return self
@@ -93,6 +100,10 @@ func (self SelectStatement) Sql() string {
 		parts = append(parts, "FROM", self.from.Sql())
 	}
 
+	for _, join := range self.joins {
+		parts = append(parts, join.Sql())
+	}
+
 	if self.where != nil {
 		parts = append(parts, "WHERE", self.where.Sql())
 	}
@@ -144,6 +155,11 @@ func (self SelectStatement) SqlPretty(indent string) string {
 		parts = append(parts, lines[1:]...)
 	}
 
+	for _, join := range self.joins {
+		lines := strings.Split(join.SqlPretty(indent), "\n")
+		parts = append(parts, lines...)
+	}
+
 	if self.where != nil {
 		lines := strings.Split(self.where.SqlPretty(indent), "\n")
 		parts = append(parts, "WHERE "+lines[0])
diff --git a/sqlx/select_test.go b/sqlx/select_test.go
index 1ab4b66..12b13a0 100644
--- a/sqlx/select_test.go
+++ b/sqlx/select_test.go
@@ -166,6 +166,40 @@ func TestSelect(t *testing.T) {
 		})
 	})
 
+	t.Run("join", func(t *testing.T) {
+		t.Run("left", func(t *testing.T) {
+			expected, err := os.ReadFile("./testcases/select/left_join.sql")
+
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			sql := sqlx.Select("*").From("a").Join(
+				sqlx.LeftJoin("b", "a.id = b.id").And("b.deleted_at IS NULL"),
+			).Sql()
+
+			if sql != strings.TrimSuffix(string(expected), "\n") {
+				t.Fatalf(sql)
+			}
+		})
+
+		t.Run("left outer", func(t *testing.T) {
+			expected, err := os.ReadFile("./testcases/select/left_outer_join.sql")
+
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			sql := sqlx.Select("*").From("a").Join(
+				sqlx.LeftOuterJoin("b", "a.id = b.id").And("b.deleted_at IS NULL"),
+			).Sql()
+
+			if sql != strings.TrimSuffix(string(expected), "\n") {
+				t.Fatalf(sql)
+			}
+		})
+	})
+
 	t.Run("group by", func(t *testing.T) {
 		expected, err := os.ReadFile("./testcases/select/group_by.sql")
 
@@ -399,6 +433,40 @@ func TestSelect(t *testing.T) {
 			})
 		})
 
+		t.Run("join", func(t *testing.T) {
+			t.Run("left", func(t *testing.T) {
+				expected, err := os.ReadFile("./testcases/select/left_join_pretty.sql")
+
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				sql := sqlx.Select("*").From("a").Join(
+					sqlx.LeftJoin("b", "a.id = b.id").And("b.deleted_at IS NULL"),
+				).SqlPretty("    ")
+
+				if sql != strings.TrimSuffix(string(expected), "\n") {
+					t.Fatalf(sql)
+				}
+			})
+
+			t.Run("left outer", func(t *testing.T) {
+				expected, err := os.ReadFile("./testcases/select/left_outer_join_pretty.sql")
+
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				sql := sqlx.Select("*").From("a").Join(
+					sqlx.LeftOuterJoin("b", "a.id = b.id").And("b.deleted_at IS NULL"),
+				).SqlPretty("    ")
+
+				if sql != strings.TrimSuffix(string(expected), "\n") {
+					t.Fatalf(sql)
+				}
+			})
+		})
+
 		t.Run("group by", func(t *testing.T) {
 			expected, err := os.ReadFile("./testcases/select/group_by_pretty.sql")
 
diff --git a/sqlx/testcases/select/left_join.sql b/sqlx/testcases/select/left_join.sql
new file mode 100644
index 0000000..e0ea5c1
--- /dev/null
+++ b/sqlx/testcases/select/left_join.sql
@@ -0,0 +1 @@
+SELECT * FROM a LEFT JOIN b ON a.id = b.id AND b.deleted_at IS NULL;
diff --git a/sqlx/testcases/select/left_join_pretty.sql b/sqlx/testcases/select/left_join_pretty.sql
new file mode 100644
index 0000000..6600579
--- /dev/null
+++ b/sqlx/testcases/select/left_join_pretty.sql
@@ -0,0 +1,6 @@
+SELECT
+    *
+FROM a
+LEFT JOIN b
+    ON a.id = b.id
+    AND b.deleted_at IS NULL;
diff --git a/sqlx/testcases/select/left_outer_join.sql b/sqlx/testcases/select/left_outer_join.sql
new file mode 100644
index 0000000..39ba3b4
--- /dev/null
+++ b/sqlx/testcases/select/left_outer_join.sql
@@ -0,0 +1 @@
+SELECT * FROM a LEFT OUTER JOIN b ON a.id = b.id AND b.deleted_at IS NULL;
diff --git a/sqlx/testcases/select/left_outer_join_pretty.sql b/sqlx/testcases/select/left_outer_join_pretty.sql
new file mode 100644
index 0000000..e8b5fee
--- /dev/null
+++ b/sqlx/testcases/select/left_outer_join_pretty.sql
@@ -0,0 +1,6 @@
+SELECT
+    *
+FROM a
+LEFT OUTER JOIN b
+    ON a.id = b.id
+    AND b.deleted_at IS NULL;