如何使用gorm进行单元测试

huangapple go评论79阅读模式
英文:

How to do Unit Testing with gorm

问题

你好!以下是你的代码的中文翻译:

我是 `Go` 和`单元测试`的新手。在我的项目中,我使用 `Go` 和 `gorm` 连接 `mysql` 数据库。

我的问题是如何对我的代码进行单元测试

**我的代码如下(main.go):**

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"strconv"
	"time"

	"github.com/gorilla/mux"
	"github.com/jinzhu/gorm"
	_ "github.com/jinzhu/gorm/dialects/mysql"
)

// Jobs is the gorm model behind the "jobs" table queried by the job handlers,
// and also the JSON shape those handlers write to clients.
type Jobs struct {
	// JobID is the primary key; the gorm tag marks it auto-incrementing.
	JobID                  uint   `json:"jobId" gorm:"primary_key;auto_increment"`
	SourcePath             string `json:"sourcePath"`
	Priority               int64  `json:"priority"`
	InternalPriority       string `json:"internalPriority"`
	// ExecutionEnvironmentID is matched against
	// ExecutionEnvironment.ExecutionEnvironmentId (the string key) in the
	// handlers' JOINs, not against ExecutionEnvironment.ID.
	ExecutionEnvironmentID string `json:"executionEnvironmentID"`
}

// ExecutionEnvironment is the gorm model for an environment a job can run in.
// The handlers join it to Jobs on the execution_environment_id column.
type ExecutionEnvironment struct {
	// ID is the primary key; the gorm tag marks it auto-incrementing.
	ID                     uint      `json:"id" gorm:"primary_key;auto_increment"`
	// ExecutionEnvironmentId is the string key that Jobs.ExecutionEnvironmentID
	// refers to; it is separate from the numeric ID.
	ExecutionEnvironmentId string    `json:"executionEnvironmentID"`
	CloudProviderType      string    `json:"cloudProviderType"`
	InfrastructureType     string    `json:"infrastructureType"`
	// CloudRegion is the value GetJobCloudRegion reports for a job.
	CloudRegion            string    `json:"cloudRegion"`
	// NOTE(review): presumably filled in by gorm on insert (gorm v1 treats a
	// field named CreatedAt specially) — confirm for the gorm version in use.
	CreatedAt              time.Time `json:"createdAt"`
}

// db is the package-wide gorm handle used by every HTTP handler. It is
// assigned by initDB / initDBNamed and must be set before serving requests.
var db *gorm.DB

// initDB connects to the local MySQL server, selects the "test" database and
// auto-migrates the Jobs and ExecutionEnvironment tables. It panics if the
// connection or the migration fails.
func initDB() {
	initDBNamed("test")
}

// initDBNamed is initDB for an arbitrary database name, so tests can point the
// handlers at a separate database (e.g. initDBNamed("test111")).
func initDBNamed(name string) {
	// Select the database in the DSN rather than with db.Exec("USE ..."):
	// gorm.DB wraps a database/sql connection pool, and USE only switches the
	// single pooled connection that happens to execute it.
	dataSourceName := fmt.Sprintf("root:@tcp(localhost:3306)/%s?parseTime=True", name)

	var err error
	db, err = gorm.Open("mysql", dataSourceName)
	if err != nil {
		fmt.Println(err)
		panic("failed to connect database")
	}
	db.LogMode(true)
	if err := db.AutoMigrate(&Jobs{}, &ExecutionEnvironment{}).Error; err != nil {
		fmt.Println(err)
		panic("failed to migrate database")
	}
}

// GetAllJobs handles GET /GetAllJobs. It joins every job with its execution
// environment and writes the matching jobs as a JSON array, or the JSON string
// "No data found" when there are none. A database error yields HTTP 500.
func GetAllJobs(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	fmt.Println("Executing Get All Jobs function")

	var jobs []Jobs
	err := db.Select("jobs.*, execution_environments.*").
		Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").
		Find(&jobs).Error
	if err != nil {
		// Report the failure instead of masking it as an empty result.
		fmt.Println(err)
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}

	var payload interface{} = jobs
	if len(jobs) == 0 {
		payload = "No data found"
	}
	if err := json.NewEncoder(w).Encode(payload); err != nil {
		fmt.Println(err)
	}
}

// createJob handles POST /createJob. It decodes a Jobs value from the JSON
// request body, inserts it, and writes the stored value back as JSON.
// Malformed JSON yields HTTP 400; a failed insert yields HTTP 500.
func createJob(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	fmt.Println("Executing Create Jobs function")

	var job Jobs
	if err := json.NewDecoder(r.Body).Decode(&job); err != nil {
		// Without this check a bad or empty body inserts an all-zero row.
		http.Error(w, "invalid job JSON: "+err.Error(), http.StatusBadRequest)
		return
	}
	if err := db.Create(&job).Error; err != nil {
		fmt.Println(err)
		http.Error(w, "could not create job", http.StatusInternalServerError)
		return
	}
	if err := json.NewEncoder(w).Encode(job); err != nil {
		fmt.Println(err)
	}
}

// GetJobById handles GET /GetJobById/{jobId}. It selects the job with the
// given job_id, joined with its execution environment (so a job without a
// matching environment is not returned), and writes the result as a JSON
// array, or the JSON string "No data found". A database error yields HTTP 500.
func GetJobById(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	jobId := mux.Vars(r)["jobId"]

	// Only the jobs are sent to the client, so a single Find is all that is
	// needed; no second Scan of the same query.
	var jobs []Jobs
	err := db.Table("jobs").
		Select("jobs.*, execution_environments.*").
		Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").
		Where("job_id =?", jobId).
		Find(&jobs).Error
	if err != nil {
		fmt.Println(err)
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}

	var payload interface{} = jobs
	if len(jobs) == 0 {
		payload = "No data found"
	}
	if err := json.NewEncoder(w).Encode(payload); err != nil {
		fmt.Println(err)
	}
}

// DeleteJobById handles DELETE /DeleteJobById/{jobId}. It deletes the job
// whose job_id matches the path parameter and replies with a JSON string:
//   - 400 "Invalid JobId" when the id is not an unsigned integer,
//   - 404 "Invalid JobId" when no such job exists,
//   - 500 on a database error,
//   - 200 "Job deleted successfully" otherwise.
func DeleteJobById(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	jobId := mux.Vars(r)["jobId"]

	// reply sets the status BEFORE writing the body; net/http ignores a
	// WriteHeader call made after the body has started.
	reply := func(status int, msg string) {
		w.WriteHeader(status)
		if err := json.NewEncoder(w).Encode(msg); err != nil {
			fmt.Println(err)
		}
	}

	// Validate the id before querying. Otherwise a value MySQL may coerce
	// (e.g. "1abc") could find a row while ParseUint fails, and the delete
	// would silently target job_id 0. bitSize 0 parses into the platform's
	// uint size, so the conversion below cannot truncate.
	id64, err := strconv.ParseUint(jobId, 10, 0)
	if err != nil {
		reply(http.StatusBadRequest, "Invalid JobId")
		return
	}
	id := uint(id64)

	var jobs []Jobs
	if err := db.Table("jobs").Select("jobs.*").Where("job_id=?", id).Find(&jobs).Error; err != nil {
		fmt.Println(err)
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}
	if len(jobs) == 0 {
		reply(http.StatusNotFound, "Invalid JobId")
		return
	}

	if err := db.Where("job_id = ?", id).Delete(&Jobs{}).Error; err != nil {
		fmt.Println(err)
		http.Error(w, "could not delete job", http.StatusInternalServerError)
		return
	}
	// 200 rather than 204: a 204 response must not carry the message body.
	reply(http.StatusOK, "Job deleted successfully")
}

// createEnvironments handles POST /createEnvironments. It decodes an
// ExecutionEnvironment from the JSON request body, inserts it, and writes the
// stored value back as JSON. Malformed JSON yields HTTP 400; a failed insert
// yields HTTP 500.
func createEnvironments(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	fmt.Println("Executing Create Execution Environments function")

	var env ExecutionEnvironment
	if err := json.NewDecoder(r.Body).Decode(&env); err != nil {
		// Without this check a bad or empty body inserts an all-zero row.
		http.Error(w, "invalid execution environment JSON: "+err.Error(), http.StatusBadRequest)
		return
	}
	if err := db.Create(&env).Error; err != nil {
		fmt.Println(err)
		http.Error(w, "could not create execution environment", http.StatusInternalServerError)
		return
	}
	if err := json.NewEncoder(w).Encode(env); err != nil {
		fmt.Println(err)
	}
}

// GetJobCloudRegion handles GET /GetJobCloudRegion/{jobId}. It writes a JSON
// array with the CloudRegion of every execution environment joined to the
// given job; the array is empty ([], not null) when nothing matches. A
// database error yields HTTP 500.
func GetJobCloudRegion(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	fmt.Println("Executing Get Job Cloud Region function")

	jobId := mux.Vars(r)["jobId"]

	var environments []ExecutionEnvironment
	err := db.Table("jobs").
		Select("execution_environments.*").
		Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").
		Where("jobs.job_id =?", jobId).
		Find(&environments).Error
	if err != nil {
		fmt.Println(err)
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}

	// A non-nil slice so an empty result encodes as [] rather than null.
	regions := make([]string, 0, len(environments))
	for _, env := range environments {
		regions = append(regions, env.CloudRegion)
	}
	if err := json.NewEncoder(w).Encode(regions); err != nil {
		fmt.Println(err)
	}
}

// main registers the HTTP routes, connects to the database and serves the API
// on port 8000. It panics if the server stops with an error.
func main() {
	router := mux.NewRouter()
	router.HandleFunc("/GetAllJobs", GetAllJobs).Methods("GET")
	router.HandleFunc("/createJob", createJob).Methods("POST")
	router.HandleFunc("/GetJobById/{jobId}", GetJobById).Methods("GET")
	router.HandleFunc("/DeleteJobById/{jobId}", DeleteJobById).Methods("DELETE")

	router.HandleFunc("/createEnvironments", createEnvironments).Methods("POST")
	router.HandleFunc("/GetJobCloudRegion/{jobId}", GetJobCloudRegion).Methods("GET")

	// Connect before serving so no request can see a nil db.
	initDB()

	// An explicit http.Server allows timeouts. The bare http.ListenAndServe
	// has none, so slow clients could hold connections open indefinitely.
	srv := &http.Server{
		Addr:         ":8000",
		Handler:      router,
		ReadTimeout:  15 * time.Second,
		WriteTimeout: 15 * time.Second,
		IdleTimeout:  60 * time.Second,
	}

	fmt.Printf("Starting server at 8000 \n")
	// ListenAndServe always returns a non-nil error; don't drop it.
	if err := srv.ListenAndServe(); err != nil {
		panic(err)
	}
}

我尝试创建了如下的单元测试文件,但它无法运行,显示如下:

main_test.go:

package main

import (
	"log"
	"os"
	"testing"

	"github.com/jinzhu/gorm"
	_ "github.com/jinzhu/gorm/dialects/mysql"
)

func TestinitDB(m *testing.M) {
	dataSourceName := "root:@tcp(localhost:3306)/?parseTime=True"
	db, err := gorm.Open("mysql", dataSourceName)

	if err != nil {
		log.Fatal("failed to connect database")
	}
	//db.Exec("CREATE DATABASE test")
	db.LogMode(true)
	db.Exec("USE test111")
	os.Exit(m.Run())
}

请帮我编写单元测试文件。

英文:

I'm new to Go and unit testing. In my project I am using Go with gorm to connect to a MySQL database.

My question is how to unit test my code:

My code is below (main.go):

package main
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
)
// Jobs is the gorm model behind the "jobs" table queried by the job handlers,
// and also the JSON shape those handlers write to clients.
//
// Struct tags must be written key:"value" with no space after the colon. With
// `json: "jobId"` reflect.StructTag.Get cannot parse the tag at all, so the
// JSON names and the gorm primary_key/auto_increment settings would be
// silently ignored.
type Jobs struct {
	JobID                  uint   `json:"jobId" gorm:"primary_key;auto_increment"`
	SourcePath             string `json:"sourcePath"`
	Priority               int64  `json:"priority"`
	InternalPriority       string `json:"internalPriority"`
	ExecutionEnvironmentID string `json:"executionEnvironmentID"`
}
// ExecutionEnvironment is the gorm model for an environment a job can run in.
// The handlers join it to Jobs on the execution_environment_id column.
//
// Struct tags must be written key:"value" with no space after the colon;
// otherwise reflect.StructTag.Get cannot parse them and the JSON names and
// gorm settings are ignored.
type ExecutionEnvironment struct {
	ID                     uint      `json:"id" gorm:"primary_key;auto_increment"`
	ExecutionEnvironmentId string    `json:"executionEnvironmentID"`
	CloudProviderType      string    `json:"cloudProviderType"`
	InfrastructureType     string    `json:"infrastructureType"`
	CloudRegion            string    `json:"cloudRegion"`
	CreatedAt              time.Time `json:"createdAt"`
}
// db is the package-wide gorm handle used by every HTTP handler. It is
// assigned by initDB / initDBNamed and must be set before serving requests.
var db *gorm.DB

// initDB connects to the local MySQL server, selects the "test" database and
// auto-migrates the Jobs and ExecutionEnvironment tables. It panics if the
// connection or the migration fails.
func initDB() {
	initDBNamed("test")
}

// initDBNamed is initDB for an arbitrary database name, so tests can point the
// handlers at a separate database (e.g. initDBNamed("test111")).
func initDBNamed(name string) {
	// Select the database in the DSN rather than with db.Exec("USE ..."):
	// gorm.DB wraps a database/sql connection pool, and USE only switches the
	// single pooled connection that happens to execute it.
	dataSourceName := fmt.Sprintf("root:@tcp(localhost:3306)/%s?parseTime=True", name)

	var err error
	db, err = gorm.Open("mysql", dataSourceName)
	if err != nil {
		fmt.Println(err)
		panic("failed to connect database")
	}
	db.LogMode(true)
	if err := db.AutoMigrate(&Jobs{}, &ExecutionEnvironment{}).Error; err != nil {
		fmt.Println(err)
		panic("failed to migrate database")
	}
}
// GetAllJobs handles GET /GetAllJobs. It joins every job with its execution
// environment and writes the matching jobs as a JSON array, or the JSON string
// "No data found" when there are none. A database error yields HTTP 500.
func GetAllJobs(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	fmt.Println("Executing Get All Jobs function")

	var jobs []Jobs
	err := db.Select("jobs.*, execution_environments.*").
		Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").
		Find(&jobs).Error
	if err != nil {
		// Report the failure instead of masking it as an empty result.
		fmt.Println(err)
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}

	var payload interface{} = jobs
	if len(jobs) == 0 {
		payload = "No data found"
	}
	if err := json.NewEncoder(w).Encode(payload); err != nil {
		fmt.Println(err)
	}
}
// createJob handles POST /createJob. It decodes a Jobs value from the JSON
// request body, inserts it, and writes the stored value back as JSON.
// Malformed JSON yields HTTP 400; a failed insert yields HTTP 500.
func createJob(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	fmt.Println("Executing Create Jobs function")

	var job Jobs
	if err := json.NewDecoder(r.Body).Decode(&job); err != nil {
		// Without this check a bad or empty body inserts an all-zero row.
		http.Error(w, "invalid job JSON: "+err.Error(), http.StatusBadRequest)
		return
	}
	if err := db.Create(&job).Error; err != nil {
		fmt.Println(err)
		http.Error(w, "could not create job", http.StatusInternalServerError)
		return
	}
	if err := json.NewEncoder(w).Encode(job); err != nil {
		fmt.Println(err)
	}
}
// GetJobById handles GET /GetJobById/{jobId}. It selects the job with the
// given job_id, joined with its execution environment (so a job without a
// matching environment is not returned), and writes the result as a JSON
// array, or the JSON string "No data found". A database error yields HTTP 500.
func GetJobById(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	jobId := mux.Vars(r)["jobId"]

	// Only the jobs are sent to the client, so a single Find is all that is
	// needed; no second Scan of the same query.
	var jobs []Jobs
	err := db.Table("jobs").
		Select("jobs.*, execution_environments.*").
		Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").
		Where("job_id =?", jobId).
		Find(&jobs).Error
	if err != nil {
		fmt.Println(err)
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}

	var payload interface{} = jobs
	if len(jobs) == 0 {
		payload = "No data found"
	}
	if err := json.NewEncoder(w).Encode(payload); err != nil {
		fmt.Println(err)
	}
}
// DeleteJobById handles DELETE /DeleteJobById/{jobId}. It deletes the job
// whose job_id matches the path parameter and replies with a JSON string:
//   - 400 "Invalid JobId" when the id is not an unsigned integer,
//   - 404 "Invalid JobId" when no such job exists,
//   - 500 on a database error,
//   - 200 "Job deleted successfully" otherwise.
func DeleteJobById(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	jobId := mux.Vars(r)["jobId"]

	// reply sets the status BEFORE writing the body; net/http ignores a
	// WriteHeader call made after the body has started.
	reply := func(status int, msg string) {
		w.WriteHeader(status)
		if err := json.NewEncoder(w).Encode(msg); err != nil {
			fmt.Println(err)
		}
	}

	// Validate the id before querying. Otherwise a value MySQL may coerce
	// (e.g. "1abc") could find a row while ParseUint fails, and the delete
	// would silently target job_id 0. bitSize 0 parses into the platform's
	// uint size, so the conversion below cannot truncate.
	id64, err := strconv.ParseUint(jobId, 10, 0)
	if err != nil {
		reply(http.StatusBadRequest, "Invalid JobId")
		return
	}
	id := uint(id64)

	var jobs []Jobs
	if err := db.Table("jobs").Select("jobs.*").Where("job_id=?", id).Find(&jobs).Error; err != nil {
		fmt.Println(err)
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}
	if len(jobs) == 0 {
		reply(http.StatusNotFound, "Invalid JobId")
		return
	}

	if err := db.Where("job_id = ?", id).Delete(&Jobs{}).Error; err != nil {
		fmt.Println(err)
		http.Error(w, "could not delete job", http.StatusInternalServerError)
		return
	}
	// 200 rather than 204: a 204 response must not carry the message body.
	reply(http.StatusOK, "Job deleted successfully")
}
// createEnvironments handles POST /createEnvironments. It decodes an
// ExecutionEnvironment from the JSON request body, inserts it, and writes the
// stored value back as JSON. Malformed JSON yields HTTP 400; a failed insert
// yields HTTP 500.
func createEnvironments(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	fmt.Println("Executing Create Execution Environments function")

	var env ExecutionEnvironment
	if err := json.NewDecoder(r.Body).Decode(&env); err != nil {
		// Without this check a bad or empty body inserts an all-zero row.
		http.Error(w, "invalid execution environment JSON: "+err.Error(), http.StatusBadRequest)
		return
	}
	if err := db.Create(&env).Error; err != nil {
		fmt.Println(err)
		http.Error(w, "could not create execution environment", http.StatusInternalServerError)
		return
	}
	if err := json.NewEncoder(w).Encode(env); err != nil {
		fmt.Println(err)
	}
}
// GetJobCloudRegion handles GET /GetJobCloudRegion/{jobId}. It writes a JSON
// array with the CloudRegion of every execution environment joined to the
// given job; the array is empty ([], not null) when nothing matches. A
// database error yields HTTP 500.
func GetJobCloudRegion(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	fmt.Println("Executing Get Job Cloud Region function")

	jobId := mux.Vars(r)["jobId"]

	var environments []ExecutionEnvironment
	err := db.Table("jobs").
		Select("execution_environments.*").
		Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").
		Where("jobs.job_id =?", jobId).
		Find(&environments).Error
	if err != nil {
		fmt.Println(err)
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}

	// A non-nil slice so an empty result encodes as [] rather than null.
	regions := make([]string, 0, len(environments))
	for _, env := range environments {
		regions = append(regions, env.CloudRegion)
	}
	if err := json.NewEncoder(w).Encode(regions); err != nil {
		fmt.Println(err)
	}
}
// main registers the HTTP routes, connects to the database and serves the API
// on port 8000. It panics if the server stops with an error.
func main() {
	router := mux.NewRouter()
	router.HandleFunc("/GetAllJobs", GetAllJobs).Methods("GET")
	router.HandleFunc("/createJob", createJob).Methods("POST")
	router.HandleFunc("/GetJobById/{jobId}", GetJobById).Methods("GET")
	router.HandleFunc("/DeleteJobById/{jobId}", DeleteJobById).Methods("DELETE")
	router.HandleFunc("/createEnvironments", createEnvironments).Methods("POST")
	router.HandleFunc("/GetJobCloudRegion/{jobId}", GetJobCloudRegion).Methods("GET")

	// Connect before serving so no request can see a nil db.
	initDB()

	// An explicit http.Server allows timeouts. The bare http.ListenAndServe
	// has none, so slow clients could hold connections open indefinitely.
	srv := &http.Server{
		Addr:         ":8000",
		Handler:      router,
		ReadTimeout:  15 * time.Second,
		WriteTimeout: 15 * time.Second,
		IdleTimeout:  60 * time.Second,
	}

	fmt.Printf("Starting server at 8000 \n")
	// ListenAndServe always returns a non-nil error; don't drop it.
	if err := srv.ListenAndServe(); err != nil {
		panic(err)
	}
}

I tried to create the unit test file below, but it does not run; it shows this:
如何使用gorm进行单元测试

main_test.go:

package main
import (
"log"
"os"
"testing"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
)
func TestinitDB(m *testing.M) {
dataSourceName := "root:@tcp(localhost:3306)/?parseTime=True"
db, err := gorm.Open("mysql", dataSourceName)
if err != nil {
log.Fatal("failed to connect database")
}
//db.Exec("CREATE DATABASE test")
db.LogMode(true)
db.Exec("USE test111")
os.Exit(m.Run())
}

Please help me write the unit test file.

答案1

得分: 6

"如何进行单元测试"是一个非常广泛的问题,因为它取决于你想要测试什么。在你的例子中,你正在处理与数据库的远程连接,这通常是在单元测试中模拟的内容。不清楚这是否是你要寻找的,也没有必要这样做。根据你使用不同的数据库,我预计你的意图不是模拟。

首先看一下这个帖子,它已经回答了你关于 TestMain 和 testing.M 如何工作的问题。

你当前的代码(如果你的测试名称正确命名为TestMain)是在你的其他测试周围添加一个方法来进行设置和拆卸,然而你没有其他测试来使用这个设置和拆卸,因此你会得到结果no tests to run

这不是你问题的一部分,但我建议在你对测试Go代码感到自信之前尽量避免使用testing.M。使用testing.T和测试单独的单元可能更容易理解。你可以通过在你的测试中调用initDB()并使初始化器接受一个参数来实现几乎相同的效果。

func initDB(dbToUse string) {
    // ...
    db.Exec("USE "+dbToUse)
}

然后你可以从你的主文件中调用initDB("test"),从你的测试中调用initDB("test111")
你可以在这里阅读关于Go的测试包的更多信息,你也会找到testing.Ttesting.M之间的区别。

下面是一个更简短的示例,其中包含一些基本的测试,不需要任何设置或拆卸,并且使用testing.T而不是testing.M

main.go

package main

import "fmt"

// main prints the sum of 1 and 2.
func main() {
	sum := add(1, 2)
	fmt.Println(sum)
}

// add returns the sum of its two integer arguments.
func add(a, b int) int {
	total := a + b
	return total
}

main_test.go

package main

import "testing"

// TestAdd checks that add(2, 2) returns 4. The check runs as a named subtest
// so the verbose output shows which case ran.
func TestAdd(t *testing.T) {
	t.Run("add 2 + 2", func(t *testing.T) {
		want := 4

		// Call the function under test.
		got := add(2, 2)

		// Report both values; a bare t.Fail() marks the test failed without
		// saying why.
		if got != want {
			t.Errorf("add(2, 2) = %d, want %d", got, want)
		}
	})
}

这个测试将测试你的add方法,并确保当你传入2, 2作为参数时,它返回正确的值。t.Run的使用是可选的,但它为你创建了一个子测试,使得阅读输出更容易。

由于 Go 的测试以包为单位运行,如果你不使用递归包含所有包的三点格式(`./...`),就需要指定要测试的包。

要运行上面的示例测试,请指定你的包和-v以获取详细输出。

$ go test ./ -v
=== RUN   TestAdd
=== RUN   TestAdd/add_2_+_2
--- PASS: TestAdd (0.00s)
    --- PASS: TestAdd/add_2_+_2 (0.00s)
PASS
ok      x       (cached)

还有很多关于这个主题的内容需要学习,比如测试框架和测试模式。例如,测试框架testify可以帮助你进行断言,并在测试失败时打印出漂亮的输出,而表驱动测试是Go中一个相当常见的模式。

你还在编写一个 HTTP 服务器,这通常需要额外的测试设置才能充分测试。幸运的是,标准库中的 http 包带有一个名为 httptest 的子包。它可以记录处理程序写出的响应,也可以启动一个本地服务器来接收真实的 HTTP 请求。你还可以用手动构造的请求直接调用处理程序来测试它们。

它看起来像这样。

// TestSomeHandler shows how to unit test an HTTP handler without starting a
// server: build a request, record the response with httptest, and check the
// status code. SomeHandler stands for whichever handler you want to test.
func TestSomeHandler(t *testing.T) {
	// The request has no body, so nil is passed as the third argument.
	req, err := http.NewRequest("GET", "/some-endpoint", nil)
	if err != nil {
		t.Fatal(err)
	}

	// httptest.ResponseRecorder implements http.ResponseWriter and keeps what
	// the handler writes so it can be inspected afterwards.
	recorder := httptest.NewRecorder()

	// Wrapping the function in http.HandlerFunc gives it a ServeHTTP method,
	// so it can be driven directly.
	http.HandlerFunc(SomeHandler).ServeHTTP(recorder, req)

	// The handler is expected to answer 200 OK.
	if got, want := recorder.Code, http.StatusOK; got != want {
		t.Errorf("handler returned wrong status code: got %v want %v",
			got, want)
	}
}

现在,为了测试你的一些代码,我们可以运行初始化方法,并使用响应记录器调用任何一个你的服务。

package main

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestGetAllJobs calls the GetAllJobs handler directly through an
// httptest.ResponseRecorder against the "test111" database. It checks that
// the handler answers 200 with a non-empty JSON list of jobs, each having a
// SourcePath.
//
// NOTE(review): relies on initDB taking the database name (initDB(dbToUse
// string)) and on test111 already holding at least one job joined to an
// execution environment. With no data the handler returns the string
// "No data found" and this test fails.
func TestGetAllJobs(t *testing.T) {
	// Initialize the DB.
	initDB("test111")

	req, err := http.NewRequest("GET", "/GetAllJobs", nil)
	if err != nil {
		t.Fatal(err)
	}

	rr := httptest.NewRecorder()
	handler := http.HandlerFunc(GetAllJobs)

	handler.ServeHTTP(rr, req)

	// Stop on a wrong status: an error body is not a job list, so the checks
	// below would only add noise.
	if status := rr.Code; status != http.StatusOK {
		t.Fatalf("handler returned wrong status code: got %v want %v",
			status, http.StatusOK)
	}

	var response []Jobs
	if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil {
		t.Fatalf("got invalid response, expected list of jobs, got: %v", rr.Body.String())
	}

	if len(response) < 1 {
		t.Errorf("expected at least 1 job, got %v", len(response))
	}

	for _, job := range response {
		if job.SourcePath == "" {
			t.Errorf("expected job id %d to  have a source path, was empty", job.JobID)
		}
	}
}
英文:

"How to unit test" is a very broad question since it depends on what you want to test. In your example you're working with remote connections to a database which is usually something that is mocked in unit testing. It's not clear if that's what you're looking for and it's not a requirement to do so either. By seeing you use different databases I would expect the intention is not to mock.

Start by looking at this post that has already answered your question around how TestMain and testing.M is intended to work.

If your test function were properly named TestMain, your current code would wrap your other tests with setup and teardown. However, you don't have any other tests that use this setup and teardown, so you'll get the result "no tests to run".

This isn't part of your question, but I would suggest avoiding testing.M until you feel confident testing Go code. Using testing.T and testing separate units might be easier to understand. You could achieve pretty much the same thing by calling initDB() in your test and making the initializer take an argument.

func initDB(dbToUse string) {
    // ...
    db.Exec(&quot;USE &quot;+dbToUse)
}

You would then call initDB("test") from your main file and initDB("test111") from your test.
You can read about the testing package for Go at pkg.go.dev/testing where you'll also find the differences between testing.T and testing.M.

Here's a shorter example with some basic testing that does not require any setup or teardown and that uses testing.T instead of testing.M.

main.go

package main

import "fmt"

// main prints the result of add(1, 2).
func main() {
	fmt.Println(add(1, 2))
}

// add returns a + b.
func add(a, b int) int {
	return a + b
}

main_test.go

package main

import &quot;testing&quot;

func TestAdd(t *testing.T) {
	t.Run(&quot;add 2 + 2&quot;, func(t *testing.T) {
		want := 4

		// Call the function you want to test.
		got := add(2, 2)

		// Assert that you got your expected response
		if got != want {
			t.Fail()
		}
	})
}

This test will test your method add and ensure it returns the right value when you pass 2, 2 as argument. The use of t.Run is optional but it creates a sub test for you which makes reading the output a bit easier.

Since Go tests run at the package level, you need to specify which package to test unless you use the `./...` pattern, which includes every package recursively.

To run the test in the example above, specify your package and -v for verbose output.

$ go test ./ -v
=== RUN   TestAdd
=== RUN   TestAdd/add_2_+_2
--- PASS: TestAdd (0.00s)
    --- PASS: TestAdd/add_2_+_2 (0.00s)
PASS
ok      x       (cached)

There is a lot more to learn around this topic, such as testing frameworks and testing patterns. For example, the testing framework testify helps you write assertions and prints clear output when tests fail, and table-driven tests are a common pattern in Go.

You're also writing an HTTP server, which usually requires additional setup to test properly. Luckily, the standard library's http package comes with a subpackage named httptest. It can record the responses your handlers write, or start a local server that receives real HTTP requests. You can also test your handlers by calling them directly with a manually constructed request.

It would look something like this.

func TestSomeHandler(t *testing.T) {
    // Create a request to pass to our handler. We don&#39;t have any query parameters for now, so we&#39;ll
    // pass &#39;nil&#39; as the third parameter.
    req, err := http.NewRequest(&quot;GET&quot;, &quot;/some-endpoint&quot;, nil)
    if err != nil {
        t.Fatal(err)
    }

    // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
    rr := httptest.NewRecorder()
    handler := http.HandlerFunc(SomeHandler)

    // Our handlers satisfy http.Handler, so we can call their ServeHTTP method 
    // directly and pass in our Request and ResponseRecorder.
    handler.ServeHTTP(rr, req)

    // Check the status code is what we expect.
    if status := rr.Code; status != http.StatusOK {
        t.Errorf(&quot;handler returned wrong status code: got %v want %v&quot;,
            status, http.StatusOK)
    }

Now, to test some of your code, we can run the init method and call any of your handlers with a response recorder.

package main

import (
	&quot;encoding/json&quot;
	&quot;net/http&quot;
	&quot;net/http/httptest&quot;
	&quot;testing&quot;
)

func TestGetAllJobs(t *testing.T) {
	// Initialize the DB
	initDB(&quot;test111&quot;)

	req, err := http.NewRequest(&quot;GET&quot;, &quot;/GetAllJobs&quot;, nil)
	if err != nil {
		t.Fatal(err)
	}

	rr := httptest.NewRecorder()
	handler := http.HandlerFunc(GetAllJobs)

	handler.ServeHTTP(rr, req)

	// Check the status code is what we expect.
	if status := rr.Code; status != http.StatusOK {
		t.Errorf(&quot;handler returned wrong status code: got %v want %v&quot;,
			status, http.StatusOK)
	}

	var response []Jobs
	if err := json.Unmarshal(rr.Body.Bytes(), &amp;response); err != nil {
		t.Errorf(&quot;got invalid response, expected list of jobs, got: %v&quot;, rr.Body.String())
	}

	if len(response) &lt; 1 {
		t.Errorf(&quot;expected at least 1 job, got %v&quot;, len(response))
	}

	for _, job := range response {
		if job.SourcePath == &quot;&quot; {
			t.Errorf(&quot;expected job id %d to  have a source path, was empty&quot;, job.JobID)
		}
	}
}

答案2

得分: 1

你可以使用go-sqlmock库来进行测试:

package main

import (
	"database/sql"
	"regexp"
	"testing"

	"gopkg.in/DATA-DOG/go-sqlmock.v1"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

type Student struct {
	//*gorm.Model
	Name string
	ID   string
}

type v2Suite struct {
	db      *gorm.DB
	mock    sqlmock.Sqlmock
	student Student
}

func TestGORMV2(t *testing.T) {
	s := &v2Suite{}
	var (
		db  *sql.DB
		err error
	)

	db, s.mock, err = sqlmock.New()
	if err != nil {
		t.Errorf("Failed to open mock sql db, got error: %v", err)
	}

	if db == nil {
		t.Error("mock db is null")
	}

	if s.mock == nil {
		t.Error("sqlmock is null")
	}

	dialector := postgres.New(postgres.Config{
		DSN:                  "sqlmock_db_0",
		DriverName:           "postgres",
		Conn:                 db,
		PreferSimpleProtocol: true,
	})
	s.db, err = gorm.Open(dialector, &gorm.Config{})
	if err != nil {
		t.Errorf("Failed to open gorm v2 db, got error: %v", err)
	}

	if s.db == nil {
		t.Error("gorm db is null")
	}

	s.student = Student{
		ID:   "123456",
		Name: "Test 1",
	}

	defer db.Close()

	s.mock.MatchExpectationsInOrder(false)
	s.mock.ExpectBegin()

	s.mock.ExpectQuery(regexp.QuoteMeta(
		`INSERT INTO "students" ("id","name")
					VALUES ($1,$2) RETURNING "students"."id"`)).
		WithArgs(s.student.ID, s.student.Name).
		WillReturnRows(sqlmock.NewRows([]string{"id"}).
			AddRow(s.student.ID))

	s.mock.ExpectCommit()

	if err = s.db.Create(&s.student).Error; err != nil {
		t.Errorf("Failed to insert to gorm db, got error: %v", err)
	}

	err = s.mock.ExpectationsWereMet()
	if err != nil {
		t.Errorf("Failed to meet expectations, got error: %v", err)
	}
}

这是一个使用go-sqlmock库进行测试的示例代码。你可以根据自己的需求进行修改和使用。

英文:

You can use go-sqlmock:

    package main

import (
	"database/sql"
	"regexp"
	"testing"

	"gopkg.in/DATA-DOG/go-sqlmock.v1"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

// Student is the model inserted by TestGORMV2.
type Student struct {
	//*gorm.Model
	Name string
	ID   string
}

// v2Suite groups the gorm handle, the sqlmock controller and the fixture
// record used by TestGORMV2.
type v2Suite struct {
	db      *gorm.DB
	mock    sqlmock.Sqlmock
	student Student
}

// TestGORMV2 runs gorm v2's Create against a go-sqlmock connection, through
// the postgres dialector. It checks that Create begins a transaction, issues
// the expected INSERT with the student's fields, and commits.
func TestGORMV2(t *testing.T) {
	s := &v2Suite{}
	var (
		db  *sql.DB
		err error
	)

	// Setup failures are fatal: with t.Errorf the test would keep going and
	// dereference a nil db or mock further down.
	db, s.mock, err = sqlmock.New()
	if err != nil {
		t.Fatalf("Failed to open mock sql db, got error: %v", err)
	}
	if db == nil {
		t.Fatal("mock db is null")
	}
	if s.mock == nil {
		t.Fatal("sqlmock is null")
	}
	// Close as soon as the handle is known to exist, so every exit path
	// releases it.
	defer db.Close()

	dialector := postgres.New(postgres.Config{
		DSN:                  "sqlmock_db_0",
		DriverName:           "postgres",
		Conn:                 db,
		PreferSimpleProtocol: true,
	})
	s.db, err = gorm.Open(dialector, &gorm.Config{})
	if err != nil {
		t.Fatalf("Failed to open gorm v2 db, got error: %v", err)
	}
	if s.db == nil {
		t.Fatal("gorm db is null")
	}

	s.student = Student{
		ID:   "123456",
		Name: "Test 1",
	}

	s.mock.MatchExpectationsInOrder(false)
	s.mock.ExpectBegin()

	// NOTE(review): the column list, argument order and RETURNING clause must
	// match exactly what this gorm/postgres version generates for Student;
	// check gorm's logged SQL if the expectation is not met.
	s.mock.ExpectQuery(regexp.QuoteMeta(
		`INSERT INTO "students" ("id","name")
VALUES ($1,$2) RETURNING "students"."id"`)).
		WithArgs(s.student.ID, s.student.Name).
		WillReturnRows(sqlmock.NewRows([]string{"id"}).
			AddRow(s.student.ID))

	s.mock.ExpectCommit()

	if err = s.db.Create(&s.student).Error; err != nil {
		t.Errorf("Failed to insert to gorm db, got error: %v", err)
	}

	if err = s.mock.ExpectationsWereMet(); err != nil {
		t.Errorf("Failed to meet expectations, got error: %v", err)
	}
}

huangapple
  • 本文由 发表于 2021年9月23日 20:20:19
  • 转载请务必保留本文链接:https://go.coder-hub.com/69299894.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定