Skip to content

Commit

Permalink
Some review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
srinidhis94 committed Apr 17, 2021
1 parent 9d9a498 commit 3743db2
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 81 deletions.
2 changes: 1 addition & 1 deletion drivers/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ type Driver interface {
Describe(table string, db *sql.DB) ([]FieldDescriptor, error)
MultiDescribe(tables []string, db *sql.DB) (map[string][]FieldDescriptor, []string, error)
GetLatestColumnValue(table, column string, db *sql.DB) (interface{}, error)
GetTestCase(name string) (TestCase, error)
}

type Testable interface {
GetTestCase(name string) (TestCase, error)
TestTable(conn *sql.DB, testCase, table string) error
}

Expand Down
2 changes: 1 addition & 1 deletion drivers/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ func (p Postgres) Describe(table string, db *sql.DB) ([]FieldDescriptor, error)
}

func (p Postgres) GetLatestColumnValue(table, column string, db *sql.DB) (interface{}, error) {
query := fmt.Sprintf("select %v from %v order by %v desc limit 1", column, table, column)
query := fmt.Sprintf("select %s from %s order by %s desc limit 1", column, table, column)
rows, err := db.Query(query)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion drivers/sqlcommons.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func getInsertionOrder(tablesToFieldsMap map[string][]FieldDescriptor) ([]string
return tablesVisitOrder, nil
}

func testTable(db *sql.DB, testCase, table string, d Driver) error {
func testTable(db *sql.DB, testCase, table string, d Testable) error {
test, err := d.GetTestCase(testCase)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func main() {
f := flags.Get()
gofakeit.Seed(0)
driver := drivers.New(f.Driver)
db := connector.Connection(driver)
db := connector.Connection(driver, f)
defer db.Close()

var tables []string
Expand Down
12 changes: 6 additions & 6 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestFuzz(t *testing.T) {
gofakeit.Seed(0)
driver := drivers.New(f.Driver)
testable := drivers.NewTestable(f.Driver)
db := connector.Connection(driver)
db := connector.Connection(driver, f)
defer db.Close()
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -86,7 +86,7 @@ func TestFuzzPostgres(t *testing.T) {
gofakeit.Seed(0)
driver := drivers.New(f.Driver)
testable := drivers.NewTestable(f.Driver)
db := connector.Connection(driver)
db := connector.Connection(driver, f)
defer db.Close()
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -153,11 +153,11 @@ func TestMysqlMultiInsert(t *testing.T) {
gofakeit.Seed(0)
driver := drivers.New(f.Driver)
testable := drivers.NewTestable(f.Driver)
test, err := driver.GetTestCase("multi")
test, err := testable.GetTestCase("multi")
if err != nil {
t.Error(fmt.Sprintf("postgres : error fetching test case for multi. %v", err.Error()))
}
db := connector.Connection(driver)
db := connector.Connection(driver, f)
defer db.Close()
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -194,11 +194,11 @@ func TestPostgresMultiInsert(t *testing.T) {
gofakeit.Seed(0)
driver := drivers.New(f.Driver)
testable := drivers.NewTestable(f.Driver)
test, err := driver.GetTestCase("multi")
test, err := testable.GetTestCase("multi")
if err != nil {
t.Error(fmt.Sprintf("postgres : error fetching test case for multi. %v", err.Error()))
}
db := connector.Connection(driver)
db := connector.Connection(driver, f)
defer db.Close()
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil {
t.Fatal(err)
Expand Down
79 changes: 53 additions & 26 deletions pkg/action/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package action
import (
"database/sql"
"encoding/base64"
"errors"
"fmt"
"log"
"math"
Expand All @@ -17,15 +18,42 @@ import (
"github.com/rs/xid"
)

func InsertMulti(args ...interface{}) error {
db := args[0].(*sql.DB)
driver := args[1].(drivers.Driver)
tableToFieldsMap := args[2].(map[string][]drivers.FieldDescriptor)
insertionOrder := args[3].([]string)
//func InsertMulti(db *sql.DB, driver drivers.Driver, tableToFieldsMap map[string][]drivers.FieldDescriptor, insertionOrder []string) error {
type SingleInsertParams struct {
DB *sql.DB
Driver drivers.Driver
Table string
Fields []drivers.FieldDescriptor
}

type MultiInsertParams struct {
DB *sql.DB
Driver drivers.Driver
InsertionOrder []string
TableToFieldsMap map[string][]drivers.FieldDescriptor
}

type SQLInsertInput struct {
SingleInsertParams *SingleInsertParams
MultiInsertParams *MultiInsertParams
}

func (sqlInsertInput SQLInsertInput) Insert() error {
if sqlInsertInput.SingleInsertParams != nil {
return sqlInsertInput.singleInsert()
} else if sqlInsertInput.MultiInsertParams != nil {
return sqlInsertInput.multiInsert()
}
return errors.New("action: error in sql insert input. Both single and multi insert arguments are not initialised")
}

func (sqlInsertInput SQLInsertInput) multiInsert() error {
multiInsertParams := sqlInsertInput.MultiInsertParams
if multiInsertParams == nil {
return errors.New("action : error during multi insert. Could not find necessary arguments")
}
tableFieldValuesMap := make(map[string]map[string]interface{})
for _, table := range insertionOrder {
if fields, ok := tableToFieldsMap[table]; ok {
for _, table := range multiInsertParams.InsertionOrder {
if fields, ok := multiInsertParams.TableToFieldsMap[table]; ok {
var f = make([]string, 0, len(fields))
var values []interface{}
for _, field := range fields {
Expand All @@ -41,19 +69,19 @@ func InsertMulti(args ...interface{}) error {
continue
}
}
val, err := driver.GetLatestColumnValue(field.ForeignKeyDescriptor.ForeignTableName, field.ForeignKeyDescriptor.ForeignColumnName, db)
val, err := multiInsertParams.Driver.GetLatestColumnValue(field.ForeignKeyDescriptor.ForeignTableName, field.ForeignKeyDescriptor.ForeignColumnName, multiInsertParams.DB)
if err != nil {
return err
}
data = val
// Get from table. If no value present in table as well, throw error.
} else {
data = generateData(driver, field)
data = generateData(multiInsertParams.Driver, field)
}
values = append(values, data)
}
query := driver.Insert(f, table)
_, err := db.Exec(query, values...)
query := multiInsertParams.Driver.Insert(f, table)
_, err := multiInsertParams.DB.Exec(query, values...)
if err != nil {
return err
}
Expand All @@ -62,26 +90,25 @@ func InsertMulti(args ...interface{}) error {
return nil
}

// Insert is inserting a random generated data into the chosen table
func Insert(args ...interface{}) error {
//func Insert(db *sql.DB, fields []drivers.FieldDescriptor, driver drivers.Driver, table string) error {
db := args[0].(*sql.DB)
fields := args[1].([]drivers.FieldDescriptor)
driver := args[2].(drivers.Driver)
table := args[3].(string)
var f = make([]string, 0, len(fields))
var values = make([]interface{}, 0, len(fields))
for _, field := range fields {
// singleInsert is inserting a random generated data into the chosen table
func (sqlInsertInput SQLInsertInput) singleInsert() error {
insertParams := sqlInsertInput.SingleInsertParams
if insertParams == nil {
return errors.New("action : error during insert. Could not find necessary arguments")
}
var f = make([]string, 0, len(insertParams.Fields))
var values = make([]interface{}, 0, len(insertParams.Fields))
for _, field := range insertParams.Fields {
// Has default value. No need to insert this field manually.
if field.HasDefaultValue {
continue
}
f = append(f, field.Field)
values = append(values, generateData(driver, field))
values = append(values, generateData(insertParams.Driver, field))
}
query := driver.Insert(f, table)
query := insertParams.Driver.Insert(f, insertParams.Table)

_, err := db.Exec(query, values...)
_, err := insertParams.DB.Exec(query, values...)
return err
}

Expand Down Expand Up @@ -128,7 +155,7 @@ func generateData(driver drivers.Driver, fieldDescriptor drivers.FieldDescriptor
)
case drivers.Time:
return time.Date(
gofakeit.Number(1970, 2038),
gofakeit.Number(1980, 2028),
time.Month(gofakeit.Number(0, 12)),
gofakeit.Day(),
gofakeit.Hour(),
Expand Down
13 changes: 6 additions & 7 deletions pkg/connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,22 @@ package connector

import (
"database/sql"
"log"
"time"

"github.com/PumpkinSeed/sqlfuzz/drivers"
"github.com/PumpkinSeed/sqlfuzz/pkg/flags"
_ "github.com/lib/pq"
"log"
)

// Connection building a singleton connection to the database for give driver
func Connection(d drivers.Driver) *sql.DB {
func Connection(d drivers.Driver, f flags.Flags) *sql.DB {
db, err := connect(d)
if err != nil {
log.Fatal(err)
return nil
}
db.SetConnMaxLifetime(100 * time.Second)
db.SetMaxIdleConns(1000)
db.SetMaxOpenConns(200)
db.SetConnMaxLifetime(f.ConnMaxLifetimeInSec)
db.SetMaxIdleConns(f.MaxIdleConns)
db.SetMaxOpenConns(f.MaxOpenConns)
return db
}

Expand Down
12 changes: 10 additions & 2 deletions pkg/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package flags

import (
"flag"

"github.com/PumpkinSeed/sqlfuzz/drivers"
"time"
)

var f Flags
Expand All @@ -15,7 +15,12 @@ type Flags struct {
Num int
Workers int
Table string
Parsed bool

ConnMaxLifetimeInSec time.Duration
MaxIdleConns int
MaxOpenConns int

Parsed bool
}

// Get the parsed flags and parsing them if needed
Expand All @@ -39,6 +44,9 @@ func parse() {
flag.StringVar(&f.Table, "t", "", "Table for fuzzing")
flag.IntVar(&f.Num, "n", 1000, "Number of rows")
flag.IntVar(&f.Workers, "w", 20, "Number of workers")
flag.IntVar(&f.MaxIdleConns, "i", 200, "Number of max sql db idle connections")
flag.IntVar(&f.MaxOpenConns, "o", 1000, "Number of max sql db open connections")
flag.DurationVar(&f.ConnMaxLifetimeInSec, "l", 100*time.Second, "Maximum lifetime of each open connection")
flag.Parse()
}

Expand Down
67 changes: 31 additions & 36 deletions pkg/fuzzer/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@ import (

func getDriverAndDB(f flags.Flags) (drivers.Driver, *sql.DB) {
driver := drivers.New(f.Driver)
db := connector.Connection(driver)
db := connector.Connection(driver, f)
db.SetMaxOpenConns(f.Workers)
db.SetMaxIdleConns(f.Workers)
return driver, db
}

func runHelper(f flags.Flags, exec func(...interface{}) error, args []interface{}) error {
func runHelper(f flags.Flags, input action.SQLInsertInput) error {
numJobs := f.Num
workers := f.Workers
jobs := make(chan struct{}, numJobs)
wg := &sync.WaitGroup{}
wg.Add(workers)
for w := 0; w < workers; w++ {
go newWorker(jobs, wg, f, exec, args)
go worker(jobs, wg, f, input)
}

for j := 0; j < numJobs; j++ {
Expand All @@ -39,58 +39,53 @@ func runHelper(f flags.Flags, exec func(...interface{}) error, args []interface{
return nil
}

// Run the commands in a worker pool
func Run(fields []drivers.FieldDescriptor, f flags.Flags) error {
driver, db := getDriverAndDB(f)
func worker(jobs <-chan struct{}, wg *sync.WaitGroup, f flags.Flags, input action.SQLInsertInput) {
defer wg.Done()
driver := drivers.New(f.Driver)
db := connector.Connection(driver, f)
defer func() {
if err := db.Close(); err != nil {
log.Print(err)
}
}()
return runHelper(f, action.Insert, []interface{}{db, fields, driver, f.Table})
}

func RunMulti(tableToFieldsMap map[string][]drivers.FieldDescriptor, insertionOrder []string, f flags.Flags) error {
driver, db := getDriverAndDB(f)
defer func() {
if err := db.Close(); err != nil {
log.Print(err)
for range jobs {
if err := input.Insert(); err != nil {
log.Println(err)
}
}()
return runHelper(f, action.InsertMulti, []interface{}{db, driver, tableToFieldsMap, insertionOrder})
}
}

func newWorker(jobs <-chan struct{}, wg *sync.WaitGroup, f flags.Flags, exec func(...interface{}) error, args []interface{}) {
defer wg.Done()
driver := drivers.New(f.Driver)
db := connector.Connection(driver)
// Run the commands in a worker pool
func Run(fields []drivers.FieldDescriptor, f flags.Flags) error {
driver, db := getDriverAndDB(f)
defer func() {
if err := db.Close(); err != nil {
log.Print(err)
}
}()
for range jobs {
if err := exec(args...); err != nil {
log.Println(err)
}
sqlInsertInput := action.SQLInsertInput{
SingleInsertParams: &action.SingleInsertParams{
DB: db,
Driver: driver,
Table: f.Table,
Fields: fields,
},
}
return runHelper(f, sqlInsertInput)
}

// worker of the worker pool, executing the command, logging if fails
func worker(jobs <-chan struct{}, fields []drivers.FieldDescriptor, wg *sync.WaitGroup, f flags.Flags) {
driver := drivers.New(f.Driver)
db := connector.Connection(driver)
func RunMulti(tableToFieldsMap map[string][]drivers.FieldDescriptor, insertionOrder []string, f flags.Flags) error {
driver, db := getDriverAndDB(f)
defer func() {
if err := db.Close(); err != nil {
log.Print(err)
}
}()

for range jobs {
if err := action.Insert(db, fields, driver, f.Table); err != nil {
log.Println(err)
}
}

wg.Done()
sqlInsertInput := action.SQLInsertInput{MultiInsertParams: &action.MultiInsertParams{
DB: db,
Driver: driver,
InsertionOrder: insertionOrder,
TableToFieldsMap: tableToFieldsMap,
}}
return runHelper(f, sqlInsertInput)
}

0 comments on commit 3743db2

Please sign in to comment.