15 Star 88 Fork 24

konyshe / gogo

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
UtilsSQL.go 15.02 KB
一键复制 编辑 原始数据 按行查看 历史
konyshe 提交于 2021-05-11 16:35 . 1. 增加支持postgres数据库
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559
package gogo
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
_ "github.com/bmizerany/pq"
_ "github.com/go-sql-driver/mysql"
)
// mDB is the package-wide database handle. It stays nil until SQLInit succeeds.
var mDB *sql.DB

// mDriverName records the driver name passed to SQLInit (e.g. "mysql" or
// "postgres"). Other functions in this file branch on it to adapt SQL dialect details.
var mDriverName string
// SQLInit creates the package-wide database handle on first use.
// driverName string: a registered database/sql driver name; this file links
// the "mysql" (github.com/go-sql-driver/mysql) and "postgres"
// (github.com/bmizerany/pq) drivers.
// dataSourceName string: the driver-specific connection string.
// maxOpenConns int: passed to SetMaxOpenConns; the limit includes idle connections.
// maxIdleConns int: passed to SetMaxIdleConns.
// Returns the error from sql.Open unchanged. If a handle already exists the
// call does nothing and returns nil, so only the first successful call's
// arguments take effect. No locking is performed around the global handle.
func SQLInit(driverName, dataSourceName string, maxOpenConns, maxIdleConns int) error {
	if mDB != nil {
		return nil
	}
	db, err := sql.Open(driverName, dataSourceName)
	if err != nil {
		return err
	}
	db.SetMaxOpenConns(maxOpenConns)
	db.SetMaxIdleConns(maxIdleConns)
	mDB = db
	mDriverName = driverName
	return nil
}
// SQLInsert inserts one row per JSON object found in data.
// tableName string: the target table; it must not contain ';'.
// data []byte: a JSON array of objects, e.g. [{"name":"a","age":3}]. Each
// object becomes one INSERT whose columns are the object's keys.
// String, number and boolean values are sent as bound query parameters and
// are never spliced into the SQL text. Keys whose value has any other JSON
// type (null, array, object) are skipped, and a message is printed to stdout.
// Returns int64: LastInsertId of the last inserted row. It is always 0 for
// postgres, and 0 when data is an empty array.
// Returns error: the original JSON or database error. An error is also
// returned if SQLInit was not called, or if an object has no insertable fields.
func SQLInsert(tableName string, data []byte) (int64, error) {
	if err := sqlCheckParam(tableName); err != nil {
		return 0, err
	}
	if mDB == nil {
		return 0, errors.New("gogo sql not init")
	}
	var records []map[string]interface{}
	if err := json.Unmarshal(data, &records); err != nil {
		return 0, err
	}
	var res sql.Result
	for _, record := range records {
		columns := make([]string, 0, len(record))
		marks := make([]string, 0, len(record))
		args := make([]interface{}, 0, len(record))
		for k, v := range record {
			// encoding/json decodes numbers as float64, so these three cases
			// cover every scalar except null.
			switch v.(type) {
			case string, float64, bool:
			default:
				fmt.Println(k, "is of a type I don't know how to handle")
				continue
			}
			// Column names cannot be bound as parameters; apply the same
			// guard used for the table name.
			if err := sqlCheckParam(k); err != nil {
				return 0, err
			}
			columns = append(columns, k)
			args = append(args, v)
			// lib/pq uses numbered placeholders; the MySQL driver uses "?".
			if mDriverName == "postgres" {
				marks = append(marks, "$"+strconv.Itoa(len(args)))
			} else {
				marks = append(marks, "?")
			}
		}
		if len(columns) == 0 {
			return 0, errors.New("gogo sql insert: no insertable fields")
		}
		// INSERT INTO ... VALUES is accepted by both MySQL and Postgres,
		// unlike MySQL's "INSERT t SET ..." form.
		query := "INSERT INTO " + tableName + " (" + strings.Join(columns, ",") +
			") VALUES (" + strings.Join(marks, ",") + ")"
		var err error
		// DB.Exec prepares and releases the statement itself, so nothing leaks.
		if res, err = mDB.Exec(query, args...); err != nil {
			return 0, err
		}
	}
	// res is nil when data was an empty array; postgres has no LastInsertId.
	if res == nil || mDriverName == "postgres" {
		return 0, nil
	}
	return res.LastInsertId()
}
// SQLUpdate sets the columns given in data on every row of tableName that
// matches where.
// tableName string: the table to update.
// where string: the text after WHERE. It is spliced verbatim, so it must not
// come from untrusted input. Together with tableName it must not contain ';'.
// data []byte: a JSON object such as {"name":"a","age":3}. String, number and
// boolean values are sent as bound parameters. Keys with any other JSON type
// (null, array, object) are skipped, and a message is printed to stdout.
// Returns int64: the number of affected rows.
// Returns error: the original JSON or database error. An error is also
// returned if SQLInit was not called, or if data has no updatable fields.
func SQLUpdate(tableName, where string, data []byte) (int64, error) {
	if err := sqlCheckParam(tableName + where); err != nil {
		return 0, err
	}
	if mDB == nil {
		return 0, errors.New("gogo sql not init")
	}
	var fields map[string]interface{}
	// The decode error used to be overwritten unchecked; report it now.
	if err := json.Unmarshal(data, &fields); err != nil {
		return 0, err
	}
	assignments := make([]string, 0, len(fields))
	args := make([]interface{}, 0, len(fields))
	for k, v := range fields {
		// encoding/json decodes numbers as float64, so these three cases
		// cover every scalar except null.
		switch v.(type) {
		case string, float64, bool:
		default:
			fmt.Println(k, "is of a type I don't know how to handle")
			continue
		}
		// Column names cannot be bound as parameters; guard them like tableName.
		if err := sqlCheckParam(k); err != nil {
			return 0, err
		}
		args = append(args, v)
		// lib/pq uses numbered placeholders; the MySQL driver uses "?".
		mark := "?"
		if mDriverName == "postgres" {
			mark = "$" + strconv.Itoa(len(args))
		}
		assignments = append(assignments, k+"="+mark)
	}
	if len(assignments) == 0 {
		return 0, errors.New("gogo sql update: no updatable fields")
	}
	query := "UPDATE " + tableName + " set " + strings.Join(assignments, ",") + " where " + where
	// DB.Exec prepares and releases the statement itself, so nothing leaks.
	res, err := mDB.Exec(query, args...)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// SQLDelete deletes the rows of tableName that match where.
// tableName string: the table to delete from.
// where string: the text after WHERE. It is spliced verbatim, so it must not
// come from untrusted input. Together with tableName it must not contain ';'.
// Returns int64: the number of deleted rows.
// Returns error: the original database error, or an error if SQLInit was not called.
func SQLDelete(tableName, where string) (int64, error) {
	if err := sqlCheckParam(tableName + where); err != nil {
		return 0, err
	}
	if mDB == nil {
		return 0, errors.New("gogo sql not init")
	}
	// DB.Exec replaces the old Prepare+Exec pair, whose *sql.Stmt was never
	// closed and therefore leaked a prepared statement on every call.
	res, err := mDB.Exec("DELETE from " + tableName + " where " + where)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// SQLQuery runs a caller-supplied SQL statement as-is. It is a thin
// pass-through to (*sql.DB).Query.
// sqlstr string: the complete statement, e.g. "select count(*) from table_name".
// It is logged via LogDebug and then executed without any validation, so it
// must not be built from untrusted input.
// Returns the *sql.Rows and error exactly as database/sql produces them, or
// an error if SQLInit has not been called. The caller must Close the rows.
func SQLQuery(sqlstr string) (*sql.Rows, error) {
	if mDB != nil {
		LogDebug(sqlstr)
		return mDB.Query(sqlstr)
	}
	return nil, errors.New("gogo sql not init !")
}
// SQLQueryRows builds a SELECT statement from its parts and runs it.
// feilds string: the select list; "*" is used when it is empty.
// tableName string: the table to query.
// where string: the text after WHERE; the clause is omitted when empty.
// order string: the ORDER BY column. A leading "-" sorts DESC; a leading "+"
// or no prefix sorts ASC. The clause is omitted when empty.
// offset int, count int: rendered as "limit <count> offset <offset>" when
// offset >= 0 and count > 0. Otherwise no limit is applied.
// The concatenated fragments must not contain ';'. Otherwise they are spliced
// verbatim, so they must not come from untrusted input.
// Returns the raw *sql.Rows and error from database/sql. The caller must Close the rows.
func SQLQueryRows(feilds, tableName, where, order string, offset, count int) (*sql.Rows, error) {
	if mDB == nil {
		return nil, errors.New("gogo sql not init")
	}
	// Same statement-stacking guard as the other exported helpers.
	if err := sqlCheckParam(feilds + tableName + where + order); err != nil {
		return nil, err
	}
	if feilds == "" {
		feilds = "*"
	}
	sqlstr := "select " + feilds + " from " + tableName
	if where != "" {
		sqlstr += " where " + where
	}
	if order != "" {
		switch {
		case strings.HasPrefix(order, "-"):
			sqlstr += " order by " + order[1:] + " desc"
		case strings.HasPrefix(order, "+"):
			sqlstr += " order by " + order[1:] + " asc"
		default:
			sqlstr += " order by " + order + " asc"
		}
	}
	if offset >= 0 && count > 0 {
		// "LIMIT n OFFSET m" is understood by MySQL, Postgres and SQLite. The
		// previous "LIMIT m,n" form is MySQL/SQLite-only and broke postgres.
		sqlstr += " limit " + strconv.Itoa(count) + " offset " + strconv.Itoa(offset)
	}
	LogDebug(sqlstr)
	return mDB.Query(sqlstr)
}
// SQLQueryByMap runs a SELECT (see SQLQueryRows) and returns the rows either
// as a slice or as a map keyed by one column's value.
// columnName string: the column whose value becomes the map key. When empty,
// a []map[string]interface{} of all rows is returned instead.
// feilds, tableName, where, order string: as for SQLQueryRows.
// offset int, count int: as for SQLQueryRows.
// The key map's Go type follows the key column's DatabaseTypeName:
// TINYINT -> int8, SMALLINT -> int16, MEDIUMINT/INT/INTEGER -> int32,
// BIGINT -> int64, FLOAT -> float32, DOUBLE -> float64. Any other type name
// gives a string-keyed map. Rows with equal keys overwrite each other.
// Returns ("", errors.New("0")) when the query matches no rows. That value is
// unchanged from earlier versions because callers may depend on it.
func SQLQueryByMap(columnName, feilds, tableName, where, order string, offset, count int) (interface{}, error) {
	if err := sqlCheckParam(columnName + feilds + tableName + where + order); err != nil {
		// Previously returned the int 0 here; nil matches the other error paths.
		return nil, err
	}
	columnsType, columnsLen, queryData, queryCount, err := sqlQueryTable(feilds, tableName, where, order, offset, count)
	if err != nil {
		return nil, err
	}
	if queryCount == 0 {
		return "", errors.New("0")
	}
	if columnName == "" {
		return sqlQuery(columnsType, columnsLen, queryData, queryCount)
	}
	switch sqlGetColumnType(columnsType, columnsLen, columnName) {
	case "TINYINT":
		return sqlQueryByTinyIntMap(columnName, columnsType, columnsLen, queryData, queryCount)
	case "SMALLINT":
		return sqlQueryBySmallIntMap(columnName, columnsType, columnsLen, queryData, queryCount)
	case "MEDIUMINT", "INT", "INTEGER":
		return sqlQueryByIntMap(columnName, columnsType, columnsLen, queryData, queryCount)
	case "BIGINT":
		return sqlQueryByBigIntMap(columnName, columnsType, columnsLen, queryData, queryCount)
	case "FLOAT":
		return sqlQueryByFloatIntMap(columnName, columnsType, columnsLen, queryData, queryCount)
	case "DOUBLE":
		return sqlQueryByDoubleMap(columnName, columnsType, columnsLen, queryData, queryCount)
	default:
		return sqlQueryByStringMap(columnName, columnsType, columnsLen, queryData, queryCount)
	}
}
// sqlCheckParam is the package's minimal guard against statement stacking.
// It rejects any caller-supplied SQL fragment that contains ';'.
// It is not a general injection defence: quotes, comments and keywords such
// as "and"/"or"/"where" all pass through unchanged.
func sqlCheckParam(param string) error {
	if !strings.ContainsRune(param, ';') {
		return nil
	}
	return errors.New("can not have ;")
}
// sqlQueryTable runs SQLQueryRows and loads every returned row into memory.
// Each row is a []interface{} of typed pointers, one per column, chosen by
// sqlNewScanTarget. Callers convert the rows further with sqlGetValues.
// Returns the column metadata, the column count, the rows, the row count
// (always len(rows)), and any query, metadata or iteration error.
func sqlQueryTable(feilds, tableName, where, order string, offset, count int) ([]*sql.ColumnType, int, [][]interface{}, int, error) {
	rows, err := SQLQueryRows(feilds, tableName, where, order, offset, count)
	if err != nil {
		return nil, 0, nil, 0, err
	}
	// Without Close the underlying connection is never returned to the pool.
	defer rows.Close()

	columnsType, err := rows.ColumnTypes()
	if err != nil {
		return nil, 0, nil, 0, err
	}
	columnsLen := len(columnsType)

	// Grow by append: the result set can exceed count (e.g. when offset < 0
	// disables LIMIT) or count can be <= 0. The old fixed-size slice panicked
	// in both cases.
	var queryData [][]interface{}
	if count > 0 {
		queryData = make([][]interface{}, 0, count)
	}
	for rows.Next() {
		record := make([]interface{}, columnsLen)
		for a, ct := range columnsType {
			record[a] = sqlNewScanTarget(ct.DatabaseTypeName())
		}
		// Deliberate best effort, kept from the original: a NULL column makes
		// Scan return an error. The row is still kept; database/sql stops at
		// the failing column, so it and later targets keep their zero values.
		_ = rows.Scan(record...)
		queryData = append(queryData, record)
	}
	if err = rows.Err(); err != nil {
		return nil, 0, nil, 0, err
	}
	return columnsType, columnsLen, queryData, len(queryData), nil
}

// sqlNewScanTarget allocates a pointer to receive one column value, chosen by
// the column's DatabaseTypeName. Unrecognised names scan into *string.
func sqlNewScanTarget(databaseType string) interface{} {
	switch databaseType {
	case "TINYINT":
		return new(int8)
	case "SMALLINT":
		return new(int16)
	case "MEDIUMINT", "INT", "INTEGER":
		return new(int32)
	case "BIGINT":
		return new(int64)
	case "FLOAT":
		return new(float32)
	case "DOUBLE":
		return new(float64)
	default:
		return new(string)
	}
}
// sqlGetValues dereferences the first columnsLen scan targets in pvs. It
// returns them in a map keyed by the matching column name from columnsType.
// Only *int8, *int16, *int32, *int64, *float32, *float64 and *string targets
// are copied; targets of any other type are left out of the map.
func sqlGetValues(pvs []interface{}, columnsType []*sql.ColumnType, columnsLen int) map[string]interface{} {
	row := make(map[string]interface{}, columnsLen)
	for idx := 0; idx < columnsLen; idx++ {
		var value interface{}
		switch p := pvs[idx].(type) {
		case *int8:
			value = *p
		case *int16:
			value = *p
		case *int32:
			value = *p
		case *int64:
			value = *p
		case *float32:
			value = *p
		case *float64:
			value = *p
		case *string:
			value = *p
		default:
			continue
		}
		row[columnsType[idx].Name()] = value
	}
	return row
}
// sqlQuery converts the first queryCount rows of queryData into name-keyed maps
// via sqlGetValues. It returns the raw rows, now labelled with column names.
// The result always has length queryCount. If queryData holds fewer rows,
// the trailing entries stay nil. The error result is always nil.
func sqlQuery(columnsType []*sql.ColumnType, columnsLen int, queryData [][]interface{}, queryCount int) ([]map[string]interface{}, error) {
	labelled := make([]map[string]interface{}, queryCount)
	limit := queryCount
	if len(queryData) < limit {
		limit = len(queryData)
	}
	for i := 0; i < limit; i++ {
		labelled[i] = sqlGetValues(queryData[i], columnsType, columnsLen)
	}
	return labelled, nil
}
// sqlEachKeyCell walks the first queryCount rows of queryData. For each row it
// calls visit with the cell of the first column named columnName, plus the
// whole row. Rows are skipped when no column has that name.
func sqlEachKeyCell(columnName string, columnsType []*sql.ColumnType, columnsLen int, queryData [][]interface{}, queryCount int, visit func(cell interface{}, row []interface{})) {
	for rowIdx, row := range queryData {
		if rowIdx >= queryCount {
			return
		}
		for col := 0; col < columnsLen; col++ {
			if columnsType[col].Name() == columnName {
				visit(row[col], row)
				break
			}
		}
	}
}

// sqlQueryByTinyIntMap indexes rows by the int8 value in columnName.
// Rows whose key cell is not *int8 are omitted; a later row overwrites an
// earlier one with the same key. The error result is always nil.
func sqlQueryByTinyIntMap(columnName string, columnsType []*sql.ColumnType, columnsLen int, queryData [][]interface{}, queryCount int) (map[int8]map[string]interface{}, error) {
	byKey := make(map[int8]map[string]interface{}, queryCount)
	sqlEachKeyCell(columnName, columnsType, columnsLen, queryData, queryCount, func(cell interface{}, row []interface{}) {
		if key, ok := cell.(*int8); ok {
			byKey[*key] = sqlGetValues(row, columnsType, columnsLen)
		}
	})
	return byKey, nil
}

// sqlQueryBySmallIntMap indexes rows by the int16 value in columnName.
// Rows whose key cell is not *int16 are omitted; a later row overwrites an
// earlier one with the same key. The error result is always nil.
func sqlQueryBySmallIntMap(columnName string, columnsType []*sql.ColumnType, columnsLen int, queryData [][]interface{}, queryCount int) (map[int16]map[string]interface{}, error) {
	byKey := make(map[int16]map[string]interface{}, queryCount)
	sqlEachKeyCell(columnName, columnsType, columnsLen, queryData, queryCount, func(cell interface{}, row []interface{}) {
		if key, ok := cell.(*int16); ok {
			byKey[*key] = sqlGetValues(row, columnsType, columnsLen)
		}
	})
	return byKey, nil
}

// sqlQueryByIntMap indexes rows by the int32 value in columnName.
// Rows whose key cell is not *int32 are omitted; a later row overwrites an
// earlier one with the same key. The error result is always nil.
func sqlQueryByIntMap(columnName string, columnsType []*sql.ColumnType, columnsLen int, queryData [][]interface{}, queryCount int) (map[int32]map[string]interface{}, error) {
	byKey := make(map[int32]map[string]interface{}, queryCount)
	sqlEachKeyCell(columnName, columnsType, columnsLen, queryData, queryCount, func(cell interface{}, row []interface{}) {
		if key, ok := cell.(*int32); ok {
			byKey[*key] = sqlGetValues(row, columnsType, columnsLen)
		}
	})
	return byKey, nil
}

// sqlQueryByBigIntMap indexes rows by the int64 value in columnName.
// Rows whose key cell is not *int64 are omitted; a later row overwrites an
// earlier one with the same key. The error result is always nil.
func sqlQueryByBigIntMap(columnName string, columnsType []*sql.ColumnType, columnsLen int, queryData [][]interface{}, queryCount int) (map[int64]map[string]interface{}, error) {
	byKey := make(map[int64]map[string]interface{}, queryCount)
	sqlEachKeyCell(columnName, columnsType, columnsLen, queryData, queryCount, func(cell interface{}, row []interface{}) {
		if key, ok := cell.(*int64); ok {
			byKey[*key] = sqlGetValues(row, columnsType, columnsLen)
		}
	})
	return byKey, nil
}

// sqlQueryByFloatIntMap indexes rows by the float32 value in columnName.
// Rows whose key cell is not *float32 are omitted; a later row overwrites an
// earlier one with the same key. The error result is always nil.
func sqlQueryByFloatIntMap(columnName string, columnsType []*sql.ColumnType, columnsLen int, queryData [][]interface{}, queryCount int) (map[float32]map[string]interface{}, error) {
	byKey := make(map[float32]map[string]interface{}, queryCount)
	sqlEachKeyCell(columnName, columnsType, columnsLen, queryData, queryCount, func(cell interface{}, row []interface{}) {
		if key, ok := cell.(*float32); ok {
			byKey[*key] = sqlGetValues(row, columnsType, columnsLen)
		}
	})
	return byKey, nil
}

// sqlQueryByDoubleMap indexes rows by the float64 value in columnName.
// Rows whose key cell is not *float64 are omitted; a later row overwrites an
// earlier one with the same key. The error result is always nil.
func sqlQueryByDoubleMap(columnName string, columnsType []*sql.ColumnType, columnsLen int, queryData [][]interface{}, queryCount int) (map[float64]map[string]interface{}, error) {
	byKey := make(map[float64]map[string]interface{}, queryCount)
	sqlEachKeyCell(columnName, columnsType, columnsLen, queryData, queryCount, func(cell interface{}, row []interface{}) {
		if key, ok := cell.(*float64); ok {
			byKey[*key] = sqlGetValues(row, columnsType, columnsLen)
		}
	})
	return byKey, nil
}

// sqlQueryByStringMap indexes rows by the string value in columnName.
// Rows whose key cell is not *string are omitted; a later row overwrites an
// earlier one with the same key. The error result is always nil.
func sqlQueryByStringMap(columnName string, columnsType []*sql.ColumnType, columnsLen int, queryData [][]interface{}, queryCount int) (map[string]map[string]interface{}, error) {
	byKey := make(map[string]map[string]interface{}, queryCount)
	sqlEachKeyCell(columnName, columnsType, columnsLen, queryData, queryCount, func(cell interface{}, row []interface{}) {
		if key, ok := cell.(*string); ok {
			byKey[*key] = sqlGetValues(row, columnsType, columnsLen)
		}
	})
	return byKey, nil
}
// sqlGetColumnType returns the DatabaseTypeName of the first of the leading
// columnsLen columns whose Name equals valueName, or "" if none matches.
func sqlGetColumnType(columnsType []*sql.ColumnType, columnsLen int, valueName string) string {
	idx := 0
	for idx < columnsLen && columnsType[idx].Name() != valueName {
		idx++
	}
	if idx >= columnsLen {
		return ""
	}
	return columnsType[idx].DatabaseTypeName()
}
Go
1
https://gitee.com/konyshe/gogo.git
git@gitee.com:konyshe/gogo.git
konyshe
gogo
gogo
v2

搜索帮助