diff --git a/db.go b/db.go index ccab03e..abe3ad5 100644 --- a/db.go +++ b/db.go @@ -32,6 +32,8 @@ func init() { log.Printf("failed to connect database, got error %v\n", err) } + sqlDB.SetMaxOpenConns(1) + RunMigrations() if DB.Dialector.Name() == "sqlite" { DB.Exec("PRAGMA foreign_keys = ON") @@ -69,7 +71,9 @@ func OpenTestConnection() (db *gorm.DB, err error) { db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) default: log.Println("testing sqlite3...") - db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{ + PrepareStmt: true, // set this to false and the test passes + }) } if debug := os.Getenv("DEBUG"); debug == "true" { diff --git a/go.mod b/go.mod index f194634..a0ae2bf 100644 --- a/go.mod +++ b/go.mod @@ -30,12 +30,9 @@ require ( golang.org/x/crypto v0.29.0 // indirect golang.org/x/mod v0.22.0 // indirect golang.org/x/sync v0.9.0 // indirect - golang.org/x/sys v0.27.0 // indirect golang.org/x/text v0.20.0 // indirect golang.org/x/tools v0.27.0 // indirect gorm.io/datatypes v1.2.4 // indirect gorm.io/hints v1.1.2 // indirect gorm.io/plugin/dbresolver v1.5.3 // indirect ) - -replace gorm.io/gorm => ./gorm diff --git a/main_test.go b/main_test.go index 60a388f..0697f03 100644 --- a/main_test.go +++ b/main_test.go @@ -1,9 +1,14 @@ package main import ( + "context" + "sync" "testing" + "time" ) +const concurrentReads = 40 + // GORM_REPO: https://github.com/go-gorm/gorm.git // GORM_BRANCH: master // TEST_DRIVERS: sqlite, mysql, postgres, sqlserver @@ -13,8 +18,60 @@ func TestGORM(t *testing.T) { DB.Create(&user) - var result User - if err := DB.First(&result, user.ID).Error; err != nil { - t.Errorf("Failed, got error: %v", err) + testRunSuccessful := false + wgSuccess := sync.WaitGroup{} + wgSuccess.Add(concurrentReads) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + start := make(chan struct{}) + + for i := 0; i < concurrentReads/2; i++ { + go func() { + t.Logf("Entered routine 1-%d", i) + var result User + + <-start + transaction := DB.Begin() + if err := transaction.First(&result, "id = ? ", 1).Error; err != nil { + transaction.Rollback() + return + } + transaction.Commit() + t.Log("Got User from routine 1") + wgSuccess.Done() + }() } + + for i := 0; i < concurrentReads/2; i++ { + go func() { + t.Logf("Entered routine 2-%d", i) + var result User + + <-start + if err := DB.First(&result, "id = ? ", 1).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + return + } + t.Log("Got User from routine 2") + wgSuccess.Done() + }() + } + + time.Sleep(200 * time.Millisecond) + close(start) + t.Log("Started routines") + + go func() { + wgSuccess.Wait() + testRunSuccessful = true + }() + + <-ctx.Done() + if !testRunSuccessful { + t.Fatalf("Test failed") + } + + t.Logf("Test completed") }