diff --git a/sec_check.go b/sec_check.go index dea36386..10fe693c 100644 --- a/sec_check.go +++ b/sec_check.go @@ -20,7 +20,7 @@ func checkConds(conds []clause.Expression) error { } var banClauses = map[string]bool{ - "INSERT": true, + // "INSERT": true, "VALUES": true, // "ON CONFLICT": true, "SELECT": true, @@ -44,6 +44,8 @@ func CheckClause(cond clause.Expression) error { return checkOnConflict(cond) case clause.Locking: return checkLocking(cond) + case clause.Insert: + return checkInsert(cond) case clause.Interface: if banClauses[cond.Name()] { return fmt.Errorf("clause %s is banned", cond.Name()) @@ -75,3 +77,38 @@ func checkLocking(c clause.Locking) error { } return nil } + +// checkInsert check if clause.Insert is safe +// https://dev.mysql.com/doc/refman/8.0/en/sql-statements.html#insert +func checkInsert(c clause.Insert) error { + if c.Table.Raw == true { + return errors.New("Table Raw cannot be true") + } + + if c.Modifier == "" { + return nil + } + + var priority, ignore string + if modifiers := strings.SplitN(strings.ToUpper(strings.TrimSpace(c.Modifier)), " ", 2); len(modifiers) == 2 { + priority, ignore = strings.TrimSpace(modifiers[0]), strings.TrimSpace(modifiers[1]) + } else { + ignore = strings.TrimSpace(modifiers[0]) + } + if priority != "" && !in(priority, "LOW_PRIORITY", "DELAYED", "HIGH_PRIORITY") { + return errors.New("invalid priority value") + } + if ignore != "" && ignore != "IGNORE" { + return errors.New("invalid modifiers value, should be IGNORE") + } + return nil +} + +func in(s string, v ...string) bool { + for _, vv := range v { + if vv == s { + return true + } + } + return false +}