From 3743db25297bffd22aa7f658b0572ef2d7f59bb9 Mon Sep 17 00:00:00 2001 From: srinidhi Date: Sat, 17 Apr 2021 15:17:30 +0530 Subject: [PATCH] Some review changes --- drivers/entry.go | 2 +- drivers/postgres.go | 2 +- drivers/sqlcommons.go | 2 +- main.go | 2 +- main_test.go | 12 +++--- pkg/action/action.go | 79 +++++++++++++++++++++++++------------- pkg/connector/connector.go | 13 +++---- pkg/flags/flags.go | 12 +++++- pkg/fuzzer/runner.go | 67 +++++++++++++++----------------- 9 files changed, 110 insertions(+), 81 deletions(-) diff --git a/drivers/entry.go b/drivers/entry.go index 3abff26..660eaaf 100644 --- a/drivers/entry.go +++ b/drivers/entry.go @@ -83,10 +83,10 @@ type Driver interface { Describe(table string, db *sql.DB) ([]FieldDescriptor, error) MultiDescribe(tables []string, db *sql.DB) (map[string][]FieldDescriptor, []string, error) GetLatestColumnValue(table, column string, db *sql.DB) (interface{}, error) - GetTestCase(name string) (TestCase, error) } type Testable interface { + GetTestCase(name string) (TestCase, error) TestTable(conn *sql.DB, testCase, table string) error } diff --git a/drivers/postgres.go b/drivers/postgres.go index 95c9a84..bd83b14 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -243,7 +243,7 @@ func (p Postgres) Describe(table string, db *sql.DB) ([]FieldDescriptor, error) } func (p Postgres) GetLatestColumnValue(table, column string, db *sql.DB) (interface{}, error) { - query := fmt.Sprintf("select %v from %v order by %v desc limit 1", column, table, column) + query := fmt.Sprintf("select %s from %s order by %s desc limit 1", column, table, column) rows, err := db.Query(query) if err != nil { return nil, err diff --git a/drivers/sqlcommons.go b/drivers/sqlcommons.go index 57d0a99..e1b0f58 100644 --- a/drivers/sqlcommons.go +++ b/drivers/sqlcommons.go @@ -74,7 +74,7 @@ func getInsertionOrder(tablesToFieldsMap map[string][]FieldDescriptor) ([]string return tablesVisitOrder, nil } -func testTable(db *sql.DB, testCase, table string, d Driver) error { +func testTable(db *sql.DB, testCase, table string, d Testable) error { test, err := d.GetTestCase(testCase) if err != nil { return err diff --git a/main.go b/main.go index 7695acd..1a4bf9a 100644 --- a/main.go +++ b/main.go @@ -16,7 +16,7 @@ func main() { f := flags.Get() gofakeit.Seed(0) driver := drivers.New(f.Driver) - db := connector.Connection(driver) + db := connector.Connection(driver, f) defer db.Close() var tables []string diff --git a/main_test.go b/main_test.go index c94ce3c..8062dc9 100644 --- a/main_test.go +++ b/main_test.go @@ -29,7 +29,7 @@ func TestFuzz(t *testing.T) { gofakeit.Seed(0) driver := drivers.New(f.Driver) testable := drivers.NewTestable(f.Driver) - db := connector.Connection(driver) + db := connector.Connection(driver, f) defer db.Close() if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil { t.Fatal(err) @@ -86,7 +86,7 @@ func TestFuzzPostgres(t *testing.T) { gofakeit.Seed(0) driver := drivers.New(f.Driver) testable := drivers.NewTestable(f.Driver) - db := connector.Connection(driver) + db := connector.Connection(driver, f) defer db.Close() if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil { t.Fatal(err) @@ -153,11 +153,11 @@ func TestMysqlMultiInsert(t *testing.T) { gofakeit.Seed(0) driver := drivers.New(f.Driver) testable := drivers.NewTestable(f.Driver) - test, err := driver.GetTestCase("multi") + test, err := testable.GetTestCase("multi") if err != nil { t.Error(fmt.Sprintf("postgres : error fetching test case for multi. %v", err.Error())) } - db := connector.Connection(driver) + db := connector.Connection(driver, f) defer db.Close() if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil { t.Fatal(err) @@ -194,11 +194,11 @@ func TestPostgresMultiInsert(t *testing.T) { gofakeit.Seed(0) driver := drivers.New(f.Driver) testable := drivers.NewTestable(f.Driver) - test, err := driver.GetTestCase("multi") + test, err := testable.GetTestCase("multi") if err != nil { t.Error(fmt.Sprintf("postgres : error fetching test case for multi. %v", err.Error())) } - db := connector.Connection(driver) + db := connector.Connection(driver, f) defer db.Close() if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil { t.Fatal(err) diff --git a/pkg/action/action.go b/pkg/action/action.go index 0623078..e02ff34 100644 --- a/pkg/action/action.go +++ b/pkg/action/action.go @@ -3,6 +3,7 @@ package action import ( "database/sql" "encoding/base64" + "errors" "fmt" "log" "math" @@ -17,15 +18,42 @@ import ( "github.com/rs/xid" ) -func InsertMulti(args ...interface{}) error { - db := args[0].(*sql.DB) - driver := args[1].(drivers.Driver) - tableToFieldsMap := args[2].(map[string][]drivers.FieldDescriptor) - insertionOrder := args[3].([]string) - //func InsertMulti(db *sql.DB, driver drivers.Driver, tableToFieldsMap map[string][]drivers.FieldDescriptor, insertionOrder []string) error { +type SingleInsertParams struct { + DB *sql.DB + Driver drivers.Driver + Table string + Fields []drivers.FieldDescriptor +} + +type MultiInsertParams struct { + DB *sql.DB + Driver drivers.Driver + InsertionOrder []string + TableToFieldsMap map[string][]drivers.FieldDescriptor +} + +type SQLInsertInput struct { + SingleInsertParams *SingleInsertParams + MultiInsertParams *MultiInsertParams +} + +func (sqlInsertInput SQLInsertInput) Insert() error { + if sqlInsertInput.SingleInsertParams != nil { + return sqlInsertInput.singleInsert() + } else if sqlInsertInput.MultiInsertParams != nil { + return sqlInsertInput.multiInsert() + } + return errors.New("action: error in sql insert input. Both single and multi insert arguments are not initialised") +} + +func (sqlInsertInput SQLInsertInput) multiInsert() error { + multiInsertParams := sqlInsertInput.MultiInsertParams + if multiInsertParams == nil { + return errors.New("action : error during multi insert. Could not find necessary arguments") + } tableFieldValuesMap := make(map[string]map[string]interface{}) - for _, table := range insertionOrder { - if fields, ok := tableToFieldsMap[table]; ok { + for _, table := range multiInsertParams.InsertionOrder { + if fields, ok := multiInsertParams.TableToFieldsMap[table]; ok { var f = make([]string, 0, len(fields)) var values []interface{} for _, field := range fields { @@ -41,19 +69,19 @@ func InsertMulti(args ...interface{}) error { continue } } - val, err := driver.GetLatestColumnValue(field.ForeignKeyDescriptor.ForeignTableName, field.ForeignKeyDescriptor.ForeignColumnName, db) + val, err := multiInsertParams.Driver.GetLatestColumnValue(field.ForeignKeyDescriptor.ForeignTableName, field.ForeignKeyDescriptor.ForeignColumnName, multiInsertParams.DB) if err != nil { return err } data = val // Get from table. If no value present in table as well, throw error. } else { - data = generateData(driver, field) + data = generateData(multiInsertParams.Driver, field) } values = append(values, data) } - query := driver.Insert(f, table) - _, err := db.Exec(query, values...) + query := multiInsertParams.Driver.Insert(f, table) + _, err := multiInsertParams.DB.Exec(query, values...) if err != nil { return err } @@ -62,26 +90,25 @@ func InsertMulti(args ...interface{}) error { return nil } -// Insert is inserting a random generated data into the chosen table -func Insert(args ...interface{}) error { - //func Insert(db *sql.DB, fields []drivers.FieldDescriptor, driver drivers.Driver, table string) error { - db := args[0].(*sql.DB) - fields := args[1].([]drivers.FieldDescriptor) - driver := args[2].(drivers.Driver) - table := args[3].(string) - var f = make([]string, 0, len(fields)) - var values = make([]interface{}, 0, len(fields)) - for _, field := range fields { +// singleInsert is inserting a random generated data into the chosen table +func (sqlInsertInput SQLInsertInput) singleInsert() error { + insertParams := sqlInsertInput.SingleInsertParams + if insertParams == nil { + return errors.New("action : error during insert. Could not find necessary arguments") + } + var f = make([]string, 0, len(insertParams.Fields)) + var values = make([]interface{}, 0, len(insertParams.Fields)) + for _, field := range insertParams.Fields { // Has default value. No need to insert this field manually. if field.HasDefaultValue { continue } f = append(f, field.Field) - values = append(values, generateData(driver, field)) + values = append(values, generateData(insertParams.Driver, field)) } - query := driver.Insert(f, table) + query := insertParams.Driver.Insert(f, insertParams.Table) - _, err := db.Exec(query, values...) + _, err := insertParams.DB.Exec(query, values...) return err } @@ -128,7 +155,7 @@ func generateData(driver drivers.Driver, fieldDescriptor drivers.FieldDescriptor ) case drivers.Time: return time.Date( - gofakeit.Number(1970, 2038), + gofakeit.Number(1980, 2028), time.Month(gofakeit.Number(0, 12)), gofakeit.Day(), gofakeit.Hour(), diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index d3dbfe2..73f7b7f 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -2,23 +2,22 @@ package connector import ( "database/sql" - "log" - "time" - "github.com/PumpkinSeed/sqlfuzz/drivers" + "github.com/PumpkinSeed/sqlfuzz/pkg/flags" _ "github.com/lib/pq" + "log" ) // Connection building a singleton connection to the database for give driver -func Connection(d drivers.Driver) *sql.DB { +func Connection(d drivers.Driver, f flags.Flags) *sql.DB { db, err := connect(d) if err != nil { log.Fatal(err) return nil } - db.SetConnMaxLifetime(100 * time.Second) - db.SetMaxIdleConns(1000) - db.SetMaxOpenConns(200) + db.SetConnMaxLifetime(f.ConnMaxLifetimeInSec) + db.SetMaxIdleConns(f.MaxIdleConns) + db.SetMaxOpenConns(f.MaxOpenConns) return db } diff --git a/pkg/flags/flags.go b/pkg/flags/flags.go index 807df44..2aecfcd 100644 --- a/pkg/flags/flags.go +++ b/pkg/flags/flags.go @@ -2,8 +2,8 @@ package flags import ( "flag" - "github.com/PumpkinSeed/sqlfuzz/drivers" + "time" ) var f Flags @@ -15,7 +15,12 @@ type Flags struct { Num int Workers int Table string - Parsed bool + + ConnMaxLifetimeInSec time.Duration + MaxIdleConns int + MaxOpenConns int + + Parsed bool } // Get the parsed flags and parsing them if needed @@ -39,6 +44,9 @@ func parse() { flag.StringVar(&f.Table, "t", "", "Table for fuzzing") flag.IntVar(&f.Num, "n", 1000, "Number of rows") flag.IntVar(&f.Workers, "w", 20, "Number of workers") + flag.IntVar(&f.MaxIdleConns, "i", 200, "Number of max sql db idle connections") + flag.IntVar(&f.MaxOpenConns, "o", 1000, "Number of max sql db open connections") + flag.DurationVar(&f.ConnMaxLifetimeInSec, "l", 100*time.Second, "Maximum lifetime of each open connection") flag.Parse() } diff --git a/pkg/fuzzer/runner.go b/pkg/fuzzer/runner.go index df86e25..be652af 100644 --- a/pkg/fuzzer/runner.go +++ b/pkg/fuzzer/runner.go @@ -14,20 +14,20 @@ import ( func getDriverAndDB(f flags.Flags) (drivers.Driver, *sql.DB) { driver := drivers.New(f.Driver) - db := connector.Connection(driver) + db := connector.Connection(driver, f) db.SetMaxOpenConns(f.Workers) db.SetMaxIdleConns(f.Workers) return driver, db } -func runHelper(f flags.Flags, exec func(...interface{}) error, args []interface{}) error { +func runHelper(f flags.Flags, input action.SQLInsertInput) error { numJobs := f.Num workers := f.Workers jobs := make(chan struct{}, numJobs) wg := &sync.WaitGroup{} wg.Add(workers) for w := 0; w < workers; w++ { - go newWorker(jobs, wg, f, exec, args) + go worker(jobs, wg, f, input) } for j := 0; j < numJobs; j++ { @@ -39,58 +39,53 @@ func runHelper(f flags.Flags, exec func(...interface{}) error, args []interface{ return nil } -// Run the commands in a worker pool -func Run(fields []drivers.FieldDescriptor, f flags.Flags) error { - driver, db := getDriverAndDB(f) +func worker(jobs <-chan struct{}, wg *sync.WaitGroup, f flags.Flags, input action.SQLInsertInput) { + defer wg.Done() + driver := drivers.New(f.Driver) + db := connector.Connection(driver, f) defer func() { if err := db.Close(); err != nil { log.Print(err) } }() - return runHelper(f, action.Insert, []interface{}{db, fields, driver, f.Table}) -} - -func RunMulti(tableToFieldsMap map[string][]drivers.FieldDescriptor, insertionOrder []string, f flags.Flags) error { - driver, db := getDriverAndDB(f) - defer func() { - if err := db.Close(); err != nil { - log.Print(err) + for range jobs { + if err := input.Insert(); err != nil { + log.Println(err) } - }() - return runHelper(f, action.InsertMulti, []interface{}{db, driver, tableToFieldsMap, insertionOrder}) + } } -func newWorker(jobs <-chan struct{}, wg *sync.WaitGroup, f flags.Flags, exec func(...interface{}) error, args []interface{}) { - defer wg.Done() - driver := drivers.New(f.Driver) - db := connector.Connection(driver) +// Run the commands in a worker pool +func Run(fields []drivers.FieldDescriptor, f flags.Flags) error { + driver, db := getDriverAndDB(f) defer func() { if err := db.Close(); err != nil { log.Print(err) } }() - for range jobs { - if err := exec(args...); err != nil { - log.Println(err) - } + sqlInsertInput := action.SQLInsertInput{ + SingleInsertParams: &action.SingleInsertParams{ + DB: db, + Driver: driver, + Table: f.Table, + Fields: fields, + }, } + return runHelper(f, sqlInsertInput) } -// worker of the worker pool, executing the command, logging if fails -func worker(jobs <-chan struct{}, fields []drivers.FieldDescriptor, wg *sync.WaitGroup, f flags.Flags) { - driver := drivers.New(f.Driver) - db := connector.Connection(driver) +func RunMulti(tableToFieldsMap map[string][]drivers.FieldDescriptor, insertionOrder []string, f flags.Flags) error { + driver, db := getDriverAndDB(f) defer func() { if err := db.Close(); err != nil { log.Print(err) } }() - - for range jobs { - if err := action.Insert(db, fields, driver, f.Table); err != nil { - log.Println(err) - } - } - - wg.Done() + sqlInsertInput := action.SQLInsertInput{MultiInsertParams: &action.MultiInsertParams{ + DB: db, + Driver: driver, + InsertionOrder: insertionOrder, + TableToFieldsMap: tableToFieldsMap, + }} + return runHelper(f, sqlInsertInput) }