diff --git a/app/Database/os.go b/app/Database/os.go index f8baefe..2e0c629 100644 --- a/app/Database/os.go +++ b/app/Database/os.go @@ -24,7 +24,7 @@ func GetOS(db *sql.DB) ([]OS, error) { defer func(rows *sql.Rows) { err := rows.Close() if err != nil { - + log.Println("Error closing rows stream") } }(rows) @@ -87,10 +87,7 @@ func GetVersionsByDistributionList(db *sql.DB, d string) ([]null.String, error) func checkIfOsExists(os OS, db *sql.DB) bool { row := db.QueryRow("Select distribution, version from dashboard_os where distribution = ? and version = ?", os.Distribution, os.Version) err := row.Scan(&os.Distribution, &os.Version) - if !errors.Is(err, sql.ErrNoRows) { - return true - } - return false + return !errors.Is(err, sql.ErrNoRows) } func CreateOS(os OS, db *sql.DB) error { diff --git a/app/Database/server.go b/app/Database/server.go index 5f09f5c..6e10d06 100644 --- a/app/Database/server.go +++ b/app/Database/server.go @@ -22,7 +22,7 @@ func GetServersList(db *sql.DB) ([]Server, error) { defer func(rows *sql.Rows) { err := rows.Close() if err != nil { - + log.Println("Error closing rows query", err) } }(rows) @@ -56,7 +56,7 @@ func GetServersbyOS(db *sql.DB, id int64) ([]Server, error) { defer func(rows *sql.Rows) { err := rows.Close() if err != nil { - + log.Println("Error closing rows query", err) } }(rows) @@ -74,10 +74,7 @@ func GetServersbyOS(db *sql.DB, id int64) ([]Server, error) { func checkIfServerExists(server Server, db *sql.DB) bool { row := db.QueryRow("Select hostname, os_id from dashboard_server where hostname = ? and os_id = ?", server.Hostname, server.OsId) err := row.Scan(&server.Hostname, &server.OsId) - if !errors.Is(err, sql.ErrNoRows) { - return true - } - return false + return !errors.Is(err, sql.ErrNoRows) } func CreateServer(server Server, db *sql.DB) error { diff --git a/app/Http/os.go b/app/Http/os.go index c282f83..fb93586 100644 --- a/app/Http/os.go +++ b/app/Http/os.go @@ -4,7 +4,6 @@ import ( "encoding/json" db "infra-dashboard/Database" "io" - "io/ioutil" "log" "net/http" "strconv" @@ -30,7 +29,11 @@ func GetOS(w http.ResponseWriter, r *http.Request) { } logRequest(t, r, status) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(list) + err = json.NewEncoder(w).Encode(list) + if err != nil { + log.Println("Error encoding OS list") + return + } } func GetOSbyID(w http.ResponseWriter, r *http.Request) { @@ -55,7 +58,11 @@ func GetOSbyID(w http.ResponseWriter, r *http.Request) { } logRequest(t, r, status) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(os) + err = json.NewEncoder(w).Encode(os) + if err != nil { + log.Println("Error getting OS by ID") + return + } } func GetDistributionList(w http.ResponseWriter, r *http.Request) { @@ -74,7 +81,11 @@ func GetDistributionList(w http.ResponseWriter, r *http.Request) { } logRequest(t, r, status) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(list) + err = json.NewEncoder(w).Encode(list) + if err != nil { + log.Println("Error getting distribution list") + return + } } func GetVersionsByDistributionList(w http.ResponseWriter, r *http.Request) { @@ -94,7 +105,11 @@ func GetVersionsByDistributionList(w http.ResponseWriter, r *http.Request) { } logRequest(t, r, status) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(list) + err = json.NewEncoder(w).Encode(list) + if err != nil { + log.Println("Error getting distribution list") + return + } } func CreateOS(w http.ResponseWriter, r *http.Request) { @@ -110,7 +125,11 @@ func CreateOS(w http.ResponseWriter, r *http.Request) { } params := make(map[string]null.String) - json.Unmarshal(body, ¶ms) + err = json.Unmarshal(body, ¶ms) + if err != nil { + log.Println("Error unmarshalling request body", err) + return + } os.Distribution = params["distribution"] os.Version = params["version"] os.EndOfSupport = params["end_of_support"] @@ -131,13 +150,17 @@ func DeleteOS(w http.ResponseWriter, r *http.Request) { var status int t := time.Now() - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { log.Println(err.Error(), "Error reading request body") } params := make(map[string]null.String) - json.Unmarshal(body, ¶ms) + err = json.Unmarshal(body, ¶ms) + if err != nil { + log.Println("Error unmarshalling request body", err) + return + } os.Distribution = params["distribution"] os.Version = params["version"] dbConn := db.GetDatabaseConnection() diff --git a/app/Http/os_test.go b/app/Http/os_test.go new file mode 100644 index 0000000..cdd3be1 --- /dev/null +++ b/app/Http/os_test.go @@ -0,0 +1,37 @@ +package http + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestGetOS(t *testing.T) { + // Create a request to pass to our handler. We don't have any query parameters for now, so we'll + // pass 'nil' as the third parameter. + req, err := http.NewRequest("GET", "/os", nil) + if err != nil { + t.Fatal(err) + } + + // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. + rr := httptest.NewRecorder() + handler := http.HandlerFunc(GetOS) + + // Our handlers satisfy http.Handler, so we can call their ServeHTTP method + // directly and pass in our Request and ResponseRecorder. + handler.ServeHTTP(rr, req) + + // Check the status code is what we expect. + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + // Check the response body is what we expect. + expected := `{"alive": true}` + if rr.Body.String() != expected { + t.Errorf("handler returned unexpected body: got %v want %v", + rr.Body.String(), expected) + } +} diff --git a/app/Http/package.go b/app/Http/package.go index 776df67..dbfa4d7 100644 --- a/app/Http/package.go +++ b/app/Http/package.go @@ -4,7 +4,7 @@ import ( "encoding/json" "gopkg.in/guregu/null.v4" db "infra-dashboard/Database" - "io/ioutil" + "io" "log" "net/http" "strconv" @@ -29,7 +29,11 @@ func GetAllPackages(w http.ResponseWriter, r *http.Request) { } logRequest(t, r, status) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(list) + err = json.NewEncoder(w).Encode(list) + if err != nil { + log.Println("Error encoding response") + return + } } func GetPackagebyID(w http.ResponseWriter, r *http.Request) { @@ -54,7 +58,11 @@ func GetPackagebyID(w http.ResponseWriter, r *http.Request) { } logRequest(t, r, status) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(pkg) + err = json.NewEncoder(w).Encode(pkg) + if err != nil { + log.Println("Error encoding response") + return + } } func CreatePackage(w http.ResponseWriter, r *http.Request) { @@ -63,13 +71,17 @@ func CreatePackage(w http.ResponseWriter, r *http.Request) { t := time.Now() status := 200 - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { log.Println(err.Error(), "Error reading request body") } params := make(map[string]null.String) - json.Unmarshal(body, ¶ms) + err = json.Unmarshal(body, ¶ms) + if err != nil { + log.Println(err.Error(), "Error reading request body") + return + } pkg.Name = params["name"] dbConn := db.GetDatabaseConnection() defer dbConn.Close() @@ -88,13 +100,17 @@ func DisablePackage(w http.ResponseWriter, r *http.Request) { var status int t := time.Now() - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { log.Println(err.Error(), "Error reading request body") } params := make(map[string]null.String) - json.Unmarshal(body, ¶ms) + err = json.Unmarshal(body, ¶ms) + if err != nil { + log.Println(err.Error(), "Error reading request body") + return + } pkg.Name = params["name"] dbConn := db.GetDatabaseConnection() defer dbConn.Close() @@ -112,13 +128,17 @@ func EnablePackage(w http.ResponseWriter, r *http.Request) { var status int t := time.Now() - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { log.Println(err.Error(), "Error reading request body") } params := make(map[string]null.String) - json.Unmarshal(body, ¶ms) + err = json.Unmarshal(body, ¶ms) + if err != nil { + log.Println(err.Error(), "Error reading request body") + return + } pkg.Name = params["name"] dbConn := db.GetDatabaseConnection() defer dbConn.Close() @@ -136,13 +156,17 @@ func DeletePackage(w http.ResponseWriter, r *http.Request) { var status int t := time.Now() - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { log.Println(err.Error(), "Error reading request body") } params := make(map[string]null.String) - json.Unmarshal(body, ¶ms) + err = json.Unmarshal(body, ¶ms) + if err != nil { + log.Println(err.Error(), "Error reading request body") + return + } pkg.Name = params["name"] dbConn := db.GetDatabaseConnection() defer dbConn.Close() diff --git a/app/Http/server.go b/app/Http/server.go index 831006c..dfd4a11 100644 --- a/app/Http/server.go +++ b/app/Http/server.go @@ -4,7 +4,7 @@ import ( "encoding/json" "gopkg.in/guregu/null.v4" db "infra-dashboard/Database" - "io/ioutil" + "io" "log" "net/http" "strconv" @@ -29,7 +29,11 @@ func GetServersList(w http.ResponseWriter, r *http.Request) { } logRequest(t, r, status) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(list) + err = json.NewEncoder(w).Encode(list) + if err != nil { + log.Println("Error getting OS list") + return + } } func GetServersbyID(w http.ResponseWriter, r *http.Request) { @@ -50,7 +54,11 @@ func GetServersbyID(w http.ResponseWriter, r *http.Request) { server = db.GetServersbyID(dbConn, int64(id)) logRequest(t, r, status) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(server) + err = json.NewEncoder(w).Encode(server) + if err != nil { + log.Println("Error getting OS list") + return + } } func GetServersbyOS(w http.ResponseWriter, r *http.Request) { @@ -75,7 +83,11 @@ func GetServersbyOS(w http.ResponseWriter, r *http.Request) { } logRequest(t, r, status) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(servers) + err = json.NewEncoder(w).Encode(servers) + if err != nil { + log.Println("Error getting OS list") + return + } } func CreateServer(w http.ResponseWriter, r *http.Request) { @@ -85,13 +97,17 @@ func CreateServer(w http.ResponseWriter, r *http.Request) { t := time.Now() status := 204 - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { log.Println(err.Error(), "Error reading request body") } params := make(map[string]null.String) - json.Unmarshal(body, ¶ms) + err = json.Unmarshal(body, ¶ms) + if err != nil { + log.Println(err.Error(), "Error parsing request body") + return + } server.Hostname = params["hostname"] server.OsId = params["os_id"] dbConn := db.GetDatabaseConnection() @@ -111,13 +127,17 @@ func DeleteServer(w http.ResponseWriter, r *http.Request) { var status int t := time.Now() - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { log.Println(err.Error(), "Error reading request body") } params := make(map[string]null.String) - json.Unmarshal(body, ¶ms) + err = json.Unmarshal(body, ¶ms) + if err != nil { + log.Println(err.Error(), "Error parsing request body") + return + } server.Hostname = params["hostname"] server.OsId = params["os_id"] dbConn := db.GetDatabaseConnection()