diff --git a/Makefile b/Makefile index 65d1f1d..e97af14 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,9 @@ GOFILES := $(shell find . -name "*.go" -type f ! -path "./vendor/*") GOFMT ?= gofmt -s -.PHONY: all +PACKAGES = $(shell go list ./... | grep -v /vendor/) + +.PHONY: all test all: slqfuzz_darwin_amd64 slqfuzz_windows_amd64 slqfuzz_linux_amd64 slqfuzz_linux_arm64 slqfuzz_darwin_amd64: @@ -24,4 +26,7 @@ clean: rm -rf sqlfuzz* fmt: - @$(GOFMT) -w ${GOFILES} \ No newline at end of file + @$(GOFMT) -w ${GOFILES} + +test: + @go test -v -coverprofile cover.out ${PACKAGES} \ No newline at end of file diff --git a/drivers/entry.go b/drivers/entry.go index 897d449..660eaaf 100644 --- a/drivers/entry.go +++ b/drivers/entry.go @@ -44,18 +44,33 @@ type Field struct { Enum []string } +type FKDescriptor struct { + ConstraintName string + TableName string + ColumnName string + ForeignTableName string + ForeignColumnName string +} + //FieldDescriptor represents a field described by the table in the SQL database type FieldDescriptor struct { - Field string - Type string - Null string - Key string - Length null.Int - Default null.String - Extra string - Precision null.Int - Scale null.Int - HasDefaultValue bool + Field string + Type string + Null string + Key string + Length null.Int + Default null.String + Extra string + Precision null.Int + Scale null.Int + HasDefaultValue bool + ForeignKeyDescriptor *FKDescriptor +} + +// TestCase has a map of table to its create table query and table creation order +type TestCase struct { + TableToCreateQueryMap map[string]string + TableCreationOrder []string } // Driver is the interface should satisfied by a certain driver @@ -65,11 +80,14 @@ type Driver interface { Driver() string Insert(fields []string, table string) string MapField(descriptor FieldDescriptor) Field - DescribeFields(table string, db *sql.DB) ([]FieldDescriptor, error) + Describe(table string, db *sql.DB) ([]FieldDescriptor, error) + MultiDescribe(tables []string, db *sql.DB) (map[string][]FieldDescriptor, []string, error) + GetLatestColumnValue(table, column string, db *sql.DB) (interface{}, error) } type Testable interface { - TestTable(conn *sql.DB, table string) error + GetTestCase(name string) (TestCase, error) + TestTable(conn *sql.DB, testCase, table string) error } // New creates a new driver instance based on the flags diff --git a/drivers/mysql.go b/drivers/mysql.go index 1aed3df..9d8825b 100644 --- a/drivers/mysql.go +++ b/drivers/mysql.go @@ -1,14 +1,40 @@ package drivers import ( - "context" "database/sql" + "errors" "fmt" "strings" ) const ( MySQLDescribeTableQuery = "SHOW TABLES;" + mysqlFKQuery = "SELECT CONSTRAINT_NAME,TABLE_NAME,COLUMN_NAME,REFERENCED_TABLE_NAME,REFERENCED_COLUMN_NAME from INFORMATION_SCHEMA.KEY_COLUMN_USAGE where REFERENCED_TABLE_NAME <> 'NULL' and REFERENCED_COLUMN_NAME <> 'NULL' and TABLE_NAME = '%s'" +) + +var ( + mySQLNameToTestCase = map[string]TestCase{ + "single": { + TableToCreateQueryMap: map[string]string{DefaultTableCreateQueryKey: `CREATE TABLE %s ( + id INT(6) UNSIGNED, + firstname VARCHAR(30), + lastname VARCHAR(30), + email VARCHAR(50), + reg_date TIMESTAMP + )`}, + TableCreationOrder: nil, + }, + "multi": { + TableToCreateQueryMap: map[string]string{ + "t_currency": "CREATE TABLE IF NOT EXISTS t_currency ( id int not null,shortcut char (3) not null,PRIMARY KEY (id));", + "t_location": "CREATE TABLE IF NOT EXISTS t_location ( id int not null,location_name text not null,PRIMARY KEY (id));", + "t_product": "CREATE TABLE IF NOT EXISTS t_product( id int not null,name text not null,currency_id int ,PRIMARY KEY (id), FOREIGN KEY (currency_id) REFERENCES t_currency(id));", + "t_product_desc": "CREATE TABLE IF NOT EXISTS t_product_desc (id int not null,product_id int , description text not null, PRIMARY KEY (id), FOREIGN KEY (product_id) REFERENCES t_currency(id) );", + "t_product_stock": "CREATE TABLE IF NOT EXISTS t_product_stock(product_id int , location_id int ,amount numeric not null, FOREIGN KEY (product_id) REFERENCES t_currency(id),FOREIGN KEY(location_id) REFERENCES t_location(id));", + }, + TableCreationOrder: []string{"t_currency", "t_location", "t_product", "t_product_desc", "t_product_stock"}, + }, + } ) // MySQL implementation of the Driver @@ -163,46 +189,87 @@ func (m MySQL) MapField(descriptor FieldDescriptor) Field { return Field{Type: Unknown, Length: -1} } -func (MySQL) DescribeFields(table string, db *sql.DB) ([]FieldDescriptor, error) { +func (MySQL) Describe(table string, db *sql.DB) ([]FieldDescriptor, error) { describeQuery := fmt.Sprintf("DESCRIBE %s;", table) results, err := db.Query(describeQuery) if err != nil { return nil, err } - return parseMySQLFields(results) + fkRows, err := db.Query(fmt.Sprintf(mysqlFKQuery, strings.ToLower(table))) + if err != nil { + return nil, err + } + return parseMySQLFields(results, fkRows) } -// TestTable only for test purposes -func (m MySQL) TestTable(db *sql.DB, table string) error { - query := `CREATE TABLE %s ( - id INT(6) UNSIGNED, - firstname VARCHAR(30), - lastname VARCHAR(30), - email VARCHAR(50), - reg_date TIMESTAMP - )` - - res, err := db.ExecContext(context.Background(), fmt.Sprintf(query, table)) +func (m MySQL) MultiDescribe(tables []string, db *sql.DB) (map[string][]FieldDescriptor, []string, error) { + processedTables := make(map[string]struct{}) + tableToDescriptorMap := make(map[string][]FieldDescriptor) + for { + newTableToDescriptorMap, newlyReferencedTables, err := multiDescribeHelper(tables, processedTables, db, m) + if err != nil { + return nil, nil, err + } + for key, val := range newTableToDescriptorMap { + tableToDescriptorMap[key] = val + } + if len(newlyReferencedTables) == 0 { + break + } + tables = newlyReferencedTables + } + insertionOrder, err := getInsertionOrder(tableToDescriptorMap) if err != nil { - return err + return nil, nil, err } + return tableToDescriptorMap, insertionOrder, nil +} - _, err = res.RowsAffected() +func (MySQL) GetLatestColumnValue(table, column string, db *sql.DB) (interface{}, error) { + query := fmt.Sprintf("select %v from %v order by %v desc limit 1", column, table, column) + rows, err := db.Query(query) if err != nil { - return err + return nil, err + } + var val interface{} + for rows.Next() { + rows.Scan(&val) } - return nil + return val, nil +} + +// TestTable only for test purposes +func (m MySQL) TestTable(db *sql.DB, testCase, table string) error { + return testTable(db, testCase, table, m) } -func parseMySQLFields(results *sql.Rows) ([]FieldDescriptor, error) { +func (MySQL) GetTestCase(name string) (TestCase, error) { + if val, ok := mySQLNameToTestCase[name]; ok { + return val, nil + } + return TestCase{}, errors.New(fmt.Sprintf("postgres: Error getting testcase with name %v", name)) +} + +func parseMySQLFields(results, fkRows *sql.Rows) ([]FieldDescriptor, error) { var fields []FieldDescriptor + columnToFKMap := make(map[string]FKDescriptor) + for fkRows.Next() { + var fk FKDescriptor + err := fkRows.Scan(&fk.ConstraintName, &fk.TableName, &fk.ColumnName, &fk.ForeignTableName, &fk.ForeignColumnName) + if err != nil { + return nil, err + } + columnToFKMap[fk.ColumnName] = fk + } for results.Next() { var d FieldDescriptor err := results.Scan(&d.Field, &d.Type, &d.Null, &d.Key, &d.Default, &d.Extra) if err != nil { return nil, err } - + if val, ok := columnToFKMap[d.Field]; ok { + d.ForeignKeyDescriptor = &val + } fields = append(fields, d) } return fields, nil diff --git a/drivers/postgres.go b/drivers/postgres.go index 0482442..bd83b14 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -1,8 +1,8 @@ package drivers import ( - "context" "database/sql" + "errors" "fmt" "log" "strings" @@ -69,6 +69,73 @@ const ( PSQLConnectionTemplate = "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable" PSQLInsertTemplate = `INSERT INTO %s("%s") VALUES(%s)` PSQLShowTablesQuery = "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema';" + psqlForeignKeysQuery = ` + SELECT + tc.constraint_name, + tc.table_name, + kcu.column_name, + ccu.table_name AS foreign_table_name, + ccu.column_name AS foreign_column_name +FROM + information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema +WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' + ` +) + +var ( + pgNameToTestCase = map[string]TestCase{ + "single": { + TableToCreateQueryMap: map[string]string{DefaultTableCreateQueryKey: CreateTable}, + TableCreationOrder: nil, + }, + "multi": { + TableToCreateQueryMap: map[string]string{"t_currency": `CREATE TABLE IF NOT EXISTS "t_currency" + ( + id int not null, + shortcut char (3) not null, + PRIMARY KEY (id) + ); + `, + "t_location": `CREATE TABLE IF NOT EXISTS "t_location" + ( + id int not null, + location_name text not null, + PRIMARY KEY (id) + ); + `, + "t_product": `CREATE TABLE IF NOT EXISTS "t_product" + ( + id int not null, + name text not null, + currency_id int REFERENCES t_currency (id) not null, + PRIMARY KEY (id) + ); + `, + "t_product_desc": `CREATE TABLE IF NOT EXISTS "t_product_desc" + ( + id int not null, + product_id int REFERENCES t_product (id) not null, + description text not null, + PRIMARY KEY (id) + ); + `, + "t_product_stock": `CREATE TABLE IF NOT EXISTS "t_product_stock" + ( + product_id int REFERENCES t_product (id) not null, + location_id int REFERENCES t_location (id) not null, + amount numeric not null + ); + `, + }, + TableCreationOrder: []string{"t_currency", "t_location", "t_product", "t_product_desc", "t_product_stock"}, + }, + } ) type Postgres struct { @@ -140,30 +207,77 @@ func (p Postgres) MapField(descriptor FieldDescriptor) Field { return field } -func (p Postgres) DescribeFields(table string, db *sql.DB) ([]FieldDescriptor, error) { +func (p Postgres) MultiDescribe(tables []string, db *sql.DB) (map[string][]FieldDescriptor, []string, error) { + processedTables := make(map[string]struct{}) + tableToDescriptorMap := make(map[string][]FieldDescriptor) + for { + newTableToDescriptorMap, newlyReferencedTables, err := multiDescribeHelper(tables, processedTables, db, p) + if err != nil { + return nil, nil, err + } + for key, val := range newTableToDescriptorMap { + tableToDescriptorMap[key] = val + } + if len(newlyReferencedTables) == 0 { + break + } + tables = newlyReferencedTables + } + insertionOrder, err := getInsertionOrder(tableToDescriptorMap) + if err != nil { + return nil, nil, err + } + return tableToDescriptorMap, insertionOrder, nil +} + +func (p Postgres) Describe(table string, db *sql.DB) ([]FieldDescriptor, error) { results, err := db.Query(fmt.Sprintf(PSQLDescribeTemplate, strings.ToLower(table))) if err != nil { return nil, err } - return parsePostgresFields(results) + fkResults, err := db.Query(fmt.Sprintf(psqlForeignKeysQuery, strings.ToLower(table))) + if err != nil { + return nil, err + } + return parsePostgresFields(results, fkResults) } -// TestTable only for test purposes -func (p Postgres) TestTable(db *sql.DB, table string) error { - res, err := db.ExecContext(context.Background(), fmt.Sprintf(CreateTable, table)) +func (p Postgres) GetLatestColumnValue(table, column string, db *sql.DB) (interface{}, error) { + query := fmt.Sprintf("select %s from %s order by %s desc limit 1", column, table, column) + rows, err := db.Query(query) if err != nil { - return err + return nil, err + } + var val interface{} + for rows.Next() { + rows.Scan(&val) } + return val, nil +} - _, err = res.RowsAffected() - if err != nil { - return err +// TestTable only for test purposes +func (p Postgres) TestTable(db *sql.DB, testCase, table string) error { + return testTable(db, testCase, table, p) +} + +func (Postgres) GetTestCase(name string) (TestCase, error) { + if val, ok := pgNameToTestCase[name]; ok { + return val, nil } - return nil + return TestCase{}, errors.New(fmt.Sprintf("postgres: Error getting testcase with name %v", name)) } -func parsePostgresFields(rows *sql.Rows) ([]FieldDescriptor, error) { +func parsePostgresFields(rows, fkRows *sql.Rows) ([]FieldDescriptor, error) { var tableFields []FieldDescriptor + columnToFKMap := make(map[string]FKDescriptor) + for fkRows.Next() { + var fk FKDescriptor + err := fkRows.Scan(&fk.ConstraintName, &fk.TableName, &fk.ColumnName, &fk.ForeignTableName, &fk.ForeignColumnName) + if err != nil { + return nil, err + } + columnToFKMap[fk.ColumnName] = fk + } for rows.Next() { var field FieldDescriptor err := rows.Scan(&field.Field, &field.Type, &field.Length, &field.Default, &field.Null, &field.Precision, &field.Scale) @@ -171,6 +285,9 @@ func parsePostgresFields(rows *sql.Rows) ([]FieldDescriptor, error) { if err != nil { return nil, err } + if val, ok := columnToFKMap[field.Field]; ok { + field.ForeignKeyDescriptor = &val + } tableFields = append(tableFields, field) } return tableFields, nil diff --git a/drivers/postgres_example_test.go b/drivers/postgres_example_test.go index 5729df2..9a44792 100644 --- a/drivers/postgres_example_test.go +++ b/drivers/postgres_example_test.go @@ -6,8 +6,8 @@ import ( "log" ) -func getConnection() (*sql.DB, error) { - connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable ", "127.0.0.1", "5432", "postgres", "password", "fuzzpostgres") +func getPostgresConnection() (*sql.DB, error) { + connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable ", "127.0.0.1", "5432", "test", "test", "test") db, err := sql.Open("postgres", connectionString) if err != nil { return nil, err @@ -16,7 +16,7 @@ func getConnection() (*sql.DB, error) { } func createTable() { - db, err := getConnection() + db, err := getPostgresConnection() if err != nil { return } @@ -25,7 +25,7 @@ func createTable() { func ExamplePostgres_ShowTables() { createTable() - db, err := getConnection() + db, err := getPostgresConnection() if err != nil { return } @@ -42,12 +42,12 @@ func ExamplePostgres_ShowTables() { func ExamplePostgres_DescribeFields() { createTable() - db, err := getConnection() + db, err := getPostgresConnection() if err != nil { return } driver := Postgres{} - fields, err := driver.DescribeFields("pg_data_types", db) + fields, err := driver.Describe("pg_data_types", db) if err != nil { log.Printf("Error describing table : %s", err.Error()) return diff --git a/drivers/postgres_test.go b/drivers/postgres_test.go index 8faac86..273d979 100644 --- a/drivers/postgres_test.go +++ b/drivers/postgres_test.go @@ -1,6 +1,8 @@ package drivers import ( + "encoding/json" + "fmt" "reflect" "testing" @@ -40,3 +42,33 @@ func TestPostgres_MapField(t *testing.T) { } } } + +func TestPostgres_MultiDescribe(t *testing.T) { + db, err := getPostgresConnection() + pgDriver := Postgres{} + + if err != nil { + t.Errorf("error getting postgres connection : %v", err.Error()) + } + testCase, err := pgDriver.GetTestCase("multi") + if err != nil { + t.Errorf("error getting multi test case : %v", err.Error()) + } + err = pgDriver.TestTable(db, "multi", "") + if err != nil { + t.Errorf("error initialising multi test case : %v", err.Error()) + } + tables := testCase.TableCreationOrder + tableFieldsMap, insertionOrder, err := Postgres{}.MultiDescribe(tables, db) + if err != nil { + t.Errorf("error descriving tables %v. Error : %v", tables, err) + } + if len(tableFieldsMap) == 0 || len(insertionOrder) != len(tableFieldsMap) || len(insertionOrder) != len(tables) { + t.Errorf("error receiving required fields count. input len %v described fields len %v insertion order length %v", len(tables), len(tableFieldsMap), len(insertionOrder)) + } + tableFieldMapStr, err := json.Marshal(tableFieldsMap) + if err != nil { + t.Error(err) + } + fmt.Println(string(tableFieldMapStr)) +} diff --git a/drivers/sqlcommons.go b/drivers/sqlcommons.go new file mode 100644 index 0000000..e1b0f58 --- /dev/null +++ b/drivers/sqlcommons.go @@ -0,0 +1,100 @@ +package drivers + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" +) + +const ( + DefaultTableCreateQueryKey = "" +) + +func multiDescribeHelper(tables []string, processedTables map[string]struct{}, db *sql.DB, d Driver) (map[string][]FieldDescriptor, []string, error) { + knownTables := make(map[string]bool) + tableDescriptorMap := make(map[string][]FieldDescriptor) + var newlyReferencedTables []string + for _, table := range tables { + knownTables[table] = true + } + for _, table := range tables { + fields, err := d.Describe(table, db) + if err != nil { + return nil, nil, err + } + for _, field := range fields { + if field.ForeignKeyDescriptor == nil { + continue + } + foreignTableName := field.ForeignKeyDescriptor.ForeignTableName + if _, ok := processedTables[foreignTableName]; ok && !knownTables[foreignTableName] { + newlyReferencedTables = append(newlyReferencedTables, foreignTableName) + knownTables[foreignTableName] = true + } + } + tableDescriptorMap[table] = fields + processedTables[table] = struct{}{} + } + return tableDescriptorMap, newlyReferencedTables, nil +} + +func getInsertionOrder(tablesToFieldsMap map[string][]FieldDescriptor) ([]string, error) { + var tablesVisitOrder []string + tablesVisited := make(map[string]struct{}) + for len(tablesVisitOrder) < len(tablesToFieldsMap) { + newInsertCount := 0 + for table, fields := range tablesToFieldsMap { + if _, ok := tablesVisited[table]; ok { + continue + } + canInsert := true + for _, field := range fields { + if field.ForeignKeyDescriptor == nil { + continue + } + if _, ok := tablesVisited[field.ForeignKeyDescriptor.ForeignTableName]; ok { + continue + } + // Necessary table is not yet visited. + canInsert = false + break + } + if canInsert { + newInsertCount++ + tablesVisited[table] = struct{}{} + tablesVisitOrder = append(tablesVisitOrder, table) + } + } + if newInsertCount == 0 { + return nil, errors.New("error generating insertion order. Maybe necessary dependencies are not met") + } + } + return tablesVisitOrder, nil +} + +func testTable(db *sql.DB, testCase, table string, d Testable) error { + test, err := d.GetTestCase(testCase) + if err != nil { + return err + } + if test.TableCreationOrder == nil { + if query, ok := test.TableToCreateQueryMap[DefaultTableCreateQueryKey]; ok { + if res, err := db.ExecContext(context.Background(), fmt.Sprintf(query, table)); err != nil { + return err + } else if _, err := res.RowsAffected(); err != nil { + return err + } + } + } else { + for _, table := range test.TableCreationOrder { + createCommand := test.TableToCreateQueryMap[table] + _, err := db.Query(strings.TrimSpace(createCommand)) + if err != nil { + return err + } + } + } + return nil +} diff --git a/main.go b/main.go index 3b17c1e..1a4bf9a 100644 --- a/main.go +++ b/main.go @@ -16,7 +16,7 @@ func main() { f := flags.Get() gofakeit.Seed(0) driver := drivers.New(f.Driver) - db := connector.Connection(driver) + db := connector.Connection(driver, f) defer db.Close() var tables []string @@ -31,7 +31,7 @@ func main() { } for _, table := range tables { f.Table = table - fields, err := driver.DescribeFields(f.Table, db) + fields, err := driver.Describe(f.Table, db) if err != nil { log.Fatal(err.Error()) } diff --git a/main_test.go b/main_test.go index dcdcb49..8062dc9 100644 --- a/main_test.go +++ b/main_test.go @@ -29,15 +29,15 @@ func TestFuzz(t *testing.T) { gofakeit.Seed(0) driver := drivers.New(f.Driver) testable := drivers.NewTestable(f.Driver) - db := connector.Connection(driver) + db := connector.Connection(driver, f) defer db.Close() if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil { t.Fatal(err) } - if err := testable.TestTable(db, f.Table); err != nil { + if err := testable.TestTable(db, "single", f.Table); err != nil { t.Fatal(err) } - fields, err := driver.DescribeFields(f.Table, db) + fields, err := driver.Describe(f.Table, db) if err != nil { t.Fatal(err.Error()) } @@ -86,15 +86,15 @@ func TestFuzzPostgres(t *testing.T) { gofakeit.Seed(0) driver := drivers.New(f.Driver) testable := drivers.NewTestable(f.Driver) - db := connector.Connection(driver) + db := connector.Connection(driver, f) defer db.Close() if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil { t.Fatal(err) } - if err := testable.TestTable(db, f.Table); err != nil { + if err := testable.TestTable(db, "single", f.Table); err != nil { t.Fatal(err) } - fields, err := driver.DescribeFields(f.Table, db) + fields, err := driver.Describe(f.Table, db) if err != nil { t.Fatal(err.Error()) } @@ -135,6 +135,88 @@ func TestFuzzPostgres(t *testing.T) { } } +func TestMysqlMultiInsert(t *testing.T) { + f := flags.Flags{} + f.Driver = drivers.Flags{ + Username: "test", + Password: "test", + Database: "test", + Host: "localhost", + Port: "3306", + Driver: "mysql", + } + f.Table = "Persons" + f.Parsed = true + f.Num = 10 + f.Workers = 2 + + gofakeit.Seed(0) + driver := drivers.New(f.Driver) + testable := drivers.NewTestable(f.Driver) + test, err := testable.GetTestCase("multi") + if err != nil { + t.Error(fmt.Sprintf("postgres : error fetching test case for multi. %v", err.Error())) + } + db := connector.Connection(driver, f) + defer db.Close() + if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil { + t.Fatal(err) + } + if err := testable.TestTable(db, "multi", f.Table); err != nil { + t.Fatal(err) + } + tables := test.TableCreationOrder + tableFieldMap, insertionOrder, err := driver.MultiDescribe(tables, db) + if err != nil { + t.Errorf("Error describing tables %v. Error %v", tables, err) + } + err = fuzzer.RunMulti(tableFieldMap, insertionOrder, f) + if err != nil { + t.Errorf("error during multi insert %v", err.Error()) + } +} + +func TestPostgresMultiInsert(t *testing.T) { + f := flags.Flags{} + f.Driver = drivers.Flags{ + Username: "test", + Password: "test", + Database: "test", + Host: "localhost", + Port: "5432", + Driver: "postgres", + } + f.Table = "Persons" + f.Parsed = true + f.Num = 10 + f.Workers = 2 + + gofakeit.Seed(0) + driver := drivers.New(f.Driver) + testable := drivers.NewTestable(f.Driver) + test, err := testable.GetTestCase("multi") + if err != nil { + t.Error(fmt.Sprintf("postgres : error fetching test case for multi. %v", err.Error())) + } + db := connector.Connection(driver, f) + defer db.Close() + if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil { + t.Fatal(err) + } + if err := testable.TestTable(db, "multi", f.Table); err != nil { + t.Fatal(err) + } + tables := test.TableCreationOrder + tableFieldMap, insertionOrder, err := driver.MultiDescribe(tables, db) + if err != nil { + t.Errorf("Error describing tables %v. Error %v", tables, err) + } + err = fuzzer.RunMulti(tableFieldMap, insertionOrder, f) + if err != nil { + t.Errorf("error during multi insert %v", err.Error()) + } +} + type testTable struct { id int firstname string diff --git a/pkg/action/action.go b/pkg/action/action.go index 09298c6..e02ff34 100644 --- a/pkg/action/action.go +++ b/pkg/action/action.go @@ -3,6 +3,7 @@ package action import ( "database/sql" "encoding/base64" + "errors" "fmt" "log" "math" @@ -17,21 +18,97 @@ import ( "github.com/rs/xid" ) -// Insert is inserting a random generated data into the chosen table -func Insert(db *sql.DB, fields []drivers.FieldDescriptor, driver drivers.Driver, table string) error { - var f = make([]string, 0, len(fields)) - var values = make([]interface{}, 0, len(fields)) - for _, field := range fields { +type SingleInsertParams struct { + DB *sql.DB + Driver drivers.Driver + Table string + Fields []drivers.FieldDescriptor +} + +type MultiInsertParams struct { + DB *sql.DB + Driver drivers.Driver + InsertionOrder []string + TableToFieldsMap map[string][]drivers.FieldDescriptor +} + +type SQLInsertInput struct { + SingleInsertParams *SingleInsertParams + MultiInsertParams *MultiInsertParams +} + +func (sqlInsertInput SQLInsertInput) Insert() error { + if sqlInsertInput.SingleInsertParams != nil { + return sqlInsertInput.singleInsert() + } else if sqlInsertInput.MultiInsertParams != nil { + return sqlInsertInput.multiInsert() + } + return errors.New("action: error in sql insert input. Both single and multi insert arguments are not initialised") +} + +func (sqlInsertInput SQLInsertInput) multiInsert() error { + multiInsertParams := sqlInsertInput.MultiInsertParams + if multiInsertParams == nil { + return errors.New("action : error during multi insert. Could not find necessary arguments") + } + tableFieldValuesMap := make(map[string]map[string]interface{}) + for _, table := range multiInsertParams.InsertionOrder { + if fields, ok := multiInsertParams.TableToFieldsMap[table]; ok { + var f = make([]string, 0, len(fields)) + var values []interface{} + for _, field := range fields { + f = append(f, field.Field) + if field.HasDefaultValue { + continue + } + var data interface{} + if field.ForeignKeyDescriptor != nil { + if foreignTableFields, ok := tableFieldValuesMap[field.ForeignKeyDescriptor.ForeignTableName]; ok { + if val, ok := foreignTableFields[field.ForeignKeyDescriptor.ForeignColumnName]; ok { + data = val + continue + } + } + val, err := multiInsertParams.Driver.GetLatestColumnValue(field.ForeignKeyDescriptor.ForeignTableName, field.ForeignKeyDescriptor.ForeignColumnName, multiInsertParams.DB) + if err != nil { + return err + } + data = val + // Get from table. If no value present in table as well, throw error. + } else { + data = generateData(multiInsertParams.Driver, field) + } + values = append(values, data) + } + query := multiInsertParams.Driver.Insert(f, table) + _, err := multiInsertParams.DB.Exec(query, values...) + if err != nil { + return err + } + } + } + return nil +} + +// singleInsert is inserting a random generated data into the chosen table +func (sqlInsertInput SQLInsertInput) singleInsert() error { + insertParams := sqlInsertInput.SingleInsertParams + if insertParams == nil { + return errors.New("action : error during insert. Could not find necessary arguments") + } + var f = make([]string, 0, len(insertParams.Fields)) + var values = make([]interface{}, 0, len(insertParams.Fields)) + for _, field := range insertParams.Fields { // Has default value. No need to insert this field manually. if field.HasDefaultValue { continue } f = append(f, field.Field) - values = append(values, generateData(driver, field)) + values = append(values, generateData(insertParams.Driver, field)) } - query := driver.Insert(f, table) + query := insertParams.Driver.Insert(f, insertParams.Table) - _, err := db.Exec(query, values...) + _, err := insertParams.DB.Exec(query, values...) return err } @@ -78,7 +155,7 @@ func generateData(driver drivers.Driver, fieldDescriptor drivers.FieldDescriptor ) case drivers.Time: return time.Date( - gofakeit.Number(1970, 2038), + gofakeit.Number(1980, 2028), time.Month(gofakeit.Number(0, 12)), gofakeit.Day(), gofakeit.Hour(), diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index a2bd95c..73f7b7f 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -2,20 +2,22 @@ package connector import ( "database/sql" - "log" - "github.com/PumpkinSeed/sqlfuzz/drivers" + "github.com/PumpkinSeed/sqlfuzz/pkg/flags" _ "github.com/lib/pq" + "log" ) // Connection building a singleton connection to the database for give driver -func Connection(d drivers.Driver) *sql.DB { +func Connection(d drivers.Driver, f flags.Flags) *sql.DB { db, err := connect(d) if err != nil { log.Fatal(err) return nil } - + db.SetConnMaxLifetime(f.ConnMaxLifetimeInSec) + db.SetMaxIdleConns(f.MaxIdleConns) + db.SetMaxOpenConns(f.MaxOpenConns) return db } diff --git a/pkg/flags/flags.go b/pkg/flags/flags.go index 807df44..2aecfcd 100644 --- a/pkg/flags/flags.go +++ b/pkg/flags/flags.go @@ -2,8 +2,8 @@ package flags import ( "flag" - "github.com/PumpkinSeed/sqlfuzz/drivers" + "time" ) var f Flags @@ -15,7 +15,12 @@ type Flags struct { Num int Workers int Table string - Parsed bool + + ConnMaxLifetimeInSec time.Duration + MaxIdleConns int + MaxOpenConns int + + Parsed bool } // Get the parsed flags and parsing them if needed @@ -39,6 +44,9 @@ func parse() { flag.StringVar(&f.Table, "t", "", "Table for fuzzing") flag.IntVar(&f.Num, "n", 1000, "Number of rows") flag.IntVar(&f.Workers, "w", 20, "Number of workers") + flag.IntVar(&f.MaxIdleConns, "i", 200, "Number of max sql db idle connections") + flag.IntVar(&f.MaxOpenConns, "o", 1000, "Number of max sql db open connections") + flag.DurationVar(&f.ConnMaxLifetimeInSec, "l", 100*time.Second, "Maximum lifetime of each open connection") flag.Parse() } diff --git a/pkg/fuzzer/runner.go b/pkg/fuzzer/runner.go index ac11872..fa30081 100644 --- a/pkg/fuzzer/runner.go +++ b/pkg/fuzzer/runner.go @@ -1,6 +1,7 @@ package fuzzer import ( + "database/sql" "log" "sync" @@ -11,15 +12,20 @@ import ( _ "github.com/lib/pq" ) -// Run the commands in a worker pool -func Run(fields []drivers.FieldDescriptor, f flags.Flags) error { +func getDriverAndDB(f flags.Flags) (drivers.Driver, *sql.DB) { + driver := drivers.New(f.Driver) + db := connector.Connection(driver, f) + return driver, db +} + +func runHelper(f flags.Flags, input action.SQLInsertInput) error { numJobs := f.Num workers := f.Workers jobs := make(chan struct{}, numJobs) wg := &sync.WaitGroup{} wg.Add(workers) for w := 0; w < workers; w++ { - go worker(jobs, fields, wg, f) + go worker(jobs, wg, f, input) } for j := 0; j < numJobs; j++ { @@ -31,21 +37,53 @@ func Run(fields []drivers.FieldDescriptor, f flags.Flags) error { return nil } -// worker of the worker pool, executing the command, logging if fails -func worker(jobs <-chan struct{}, fields []drivers.FieldDescriptor, wg *sync.WaitGroup, f flags.Flags) { +func worker(jobs <-chan struct{}, wg *sync.WaitGroup, f flags.Flags, input action.SQLInsertInput) { + defer wg.Done() driver := drivers.New(f.Driver) - db := connector.Connection(driver) + db := connector.Connection(driver, f) defer func() { if err := db.Close(); err != nil { log.Print(err) } }() - for range jobs { - if err := action.Insert(db, fields, driver, f.Table); err != nil { + if err := input.Insert(); err != nil { log.Println(err) } } +} + +// Run the commands in a worker pool +func Run(fields []drivers.FieldDescriptor, f flags.Flags) error { + driver, db := getDriverAndDB(f) + defer func() { + if err := db.Close(); err != nil { + log.Print(err) + } + }() + sqlInsertInput := action.SQLInsertInput{ + SingleInsertParams: &action.SingleInsertParams{ + DB: db, + Driver: driver, + Table: f.Table, + Fields: fields, + }, + } + return runHelper(f, sqlInsertInput) +} - wg.Done() +func RunMulti(tableToFieldsMap map[string][]drivers.FieldDescriptor, insertionOrder []string, f flags.Flags) error { + driver, db := getDriverAndDB(f) + defer func() { + if err := db.Close(); err != nil { + log.Print(err) + } + }() + sqlInsertInput := action.SQLInsertInput{MultiInsertParams: &action.MultiInsertParams{ + DB: db, + Driver: driver, + InsertionOrder: insertionOrder, + TableToFieldsMap: tableToFieldsMap, + }} + return runHelper(f, sqlInsertInput) }