package utils import ( "encoding/json" "fmt" "reflect" "regexp" "strconv" "strings" "time" "runtime" "github.com/google/uuid" ) // SQL注入检测关键词 var sqlInjectionKeywords = []string{ "select ", "insert ", "update ", "delete ", "drop ", "truncate ", "union ", "create ", "alter ", "exec ", "execute ", "script", "'", "\"", "--", "#", "/*", "*/", "sleep(", "waitfor", } // 后缀常量 const ( suffixBegin = "_begin" suffixEnd = "_end" suffixLike = "_like" suffixNot = "_not" suffixIn = "_in" suffixNotIn = "_not_in" suffixNull = "_null" suffixNotNull = "_not_null" suffixApplySQL = "apply_sql" ) // 标签常量 const ( tagTableKey = "table_key" tagIdType = "id_type" tagDb = "db" ) // IdType ID类型常量 type IdType string const ( IdTypeAuto IdType = "AUTO" // 自增 IdTypeUUID IdType = "UUID" // UUID ) // SQLBuilder SQL构建器结构体 type SQLBuilder struct { conditions []string orderBy string groupBy string limit string } // NewSQLBuilder 创建新的SQL构建器 func NewSQLBuilder() *SQLBuilder { return &SQLBuilder{ conditions: make([]string, 0), } } // BuildWhereCondition 构建WHERE条件 func BuildWhereCondition(params map[string]interface{}) (string, error) { builder := NewSQLBuilder() // 首先处理分页、排序、分组参数 builder.extractPageParams(params) // 遍历参数构建条件 for key, value := range params { if err := builder.processCondition(key, value); err != nil { return "", err } } // 构建完整的SQL sql := builder.buildSQL() return sql, nil } // FormatSelectByIdSql 根据对象主键生成SELECT SQL语句 func FormatSelectByIdSql(obj interface{}) (string, error) { // 获取对象类型和值 objType := reflect.TypeOf(obj) objValue := reflect.ValueOf(obj) // 如果是指针,获取指向的值 if objType.Kind() == reflect.Ptr { objType = objType.Elem() objValue = objValue.Elem() } // 检查是否是结构体 if objType.Kind() != reflect.Struct { return "", fmt.Errorf("参数必须是结构体或结构体指针") } // 获取表名(结构体名转换为下划线) tableName := objType.Name() tableName = humpToUnderline(tableName) // 查找主键字段 primaryKeyField := "" primaryKeyValue := interface{}(nil) columnName := "" // 遍历结构体字段 for i := 0; i < objType.NumField(); i++ { field := objType.Field(i) // 检查 TableKey 标签 if tableKey := field.Tag.Get(tagTableKey); tableKey == "true" { // 获取字段值 fieldValue := getRealValue(objValue.Field(i).Interface()) primaryKeyField = field.Name primaryKeyValue = fieldValue if tableKey := field.Tag.Get(tagDb); tableKey != "-" { columnName = tableKey } break } } // 如果没有找到带标签的主键,尝试查找名为"id"的字段 if primaryKeyField == "" { // 先尝试查找导出字段"Id" if field, ok := objType.FieldByName("Id"); ok { fieldValue := getRealValue(objValue.FieldByName("Id").Interface()) primaryKeyField = field.Name primaryKeyValue = fieldValue if tableKey := field.Tag.Get(tagDb); tableKey != "-" { columnName = tableKey } } else if field, ok := objType.FieldByName("ID"); ok { // 尝试查找"ID" fieldValue := getRealValue(objValue.FieldByName("ID").Interface()) primaryKeyField = field.Name primaryKeyValue = fieldValue if tableKey := field.Tag.Get(tagDb); tableKey != "-" { columnName = tableKey } } else { // 尝试查找小写"id"(非导出字段,通过CanInterface检查) for i := 0; i < objType.NumField(); i++ { field := objType.Field(i) if strings.ToLower(field.Name) == "id" { fieldValue := objValue.Field(i) if fieldValue.CanInterface() { primaryKeyField = field.Name primaryKeyValue = getRealValue(fieldValue.Interface()) if tableKey := field.Tag.Get(tagDb); tableKey != "-" { columnName = tableKey } break } } } } } // 如果仍然没有找到主键字段,返回错误 if primaryKeyField == "" { return "", fmt.Errorf("未找到主键字段") } // 如果主键值为空,返回错误 if primaryKeyValue == nil { return "", fmt.Errorf("主键值为空") } if columnName == "" { columnName = humpToUnderline(primaryKeyField) } // 根据值类型决定是否添加引号 sql := fmt.Sprintf("select * from %s where %s = ", tableName, columnName) // 根据值类型添加引号 sql += formatWhereCondition(primaryKeyValue) return sql + " limit 1", nil } // FormatSelectCountSql 根据对象非空字段生成COUNT SQL语句 func FormatSelectCountSql(obj interface{}) (string, error) { // 获取对象类型和值 objType := reflect.TypeOf(obj) objValue := reflect.ValueOf(obj) // 如果是指针,获取指向的值 if objType.Kind() == reflect.Ptr { objType = objType.Elem() objValue = objValue.Elem() } // 检查是否是结构体 if objType.Kind() != reflect.Struct { return "", fmt.Errorf("参数必须是结构体或结构体指针") } // 获取表名(结构体名转换为下划线) tableName := objType.Name() tableName = humpToUnderline(tableName) // 构建WHERE条件 var whereConditions []string // 遍历结构体字段 for i := 0; i < objType.NumField(); i++ { field := objType.Field(i) fieldValue := objValue.Field(i) if db := field.Tag.Get(tagDb); db != "" && db == "-" { // 跳过非表字段 continue } // 检查字段值是否为空 if isEmptyValue(fieldValue) { continue } // 获取字段值(处理指针) actualValue := getRealValue(fieldValue.Interface()) // 获取字段名并转换为下划线 column := field.Tag.Get(tagDb) // 构建条件 condition := fmt.Sprintf("`%s` = ", column) // 根据字段类型决定是否添加引号 condition += formatWhereCondition(actualValue) whereConditions = append(whereConditions, condition) } // 构建完整的SQL sql := fmt.Sprintf("select count(1) from %s", tableName) if len(whereConditions) > 0 { sql += " where 1=1" for _, condition := range whereConditions { sql += " and " + condition } } return sql, nil } // FormatSelectCountSqlByMap 根据类和参数map生成COUNT SQL语句 func FormatSelectCountSqlByMap(clazz interface{}, params map[string]interface{}) (string, error) { // 获取类型 var objType reflect.Type switch v := clazz.(type) { case reflect.Type: objType = v default: objType = reflect.TypeOf(clazz) } // 如果是指针,获取指向的类型 if objType.Kind() == reflect.Ptr { objType = objType.Elem() } // 获取表名(类型名转换为下划线) tableName := objType.Name() tableName = humpToUnderline(tableName) // 构建WHERE条件 whereSQL, err := BuildWhereCondition(params) if err != nil { return "", err } // 构建完整的SQL sql := fmt.Sprintf("select count(1) from %s%s", tableName, whereSQL) return sql, nil } // FormatSelectSql 根据对象非空字段生成SELECT SQL语句(限制1条) func FormatSelectSql(obj interface{}) (string, error) { // 获取对象类型和值 objType := reflect.TypeOf(obj) objValue := reflect.ValueOf(obj) // 如果是指针,获取指向的值 if objType.Kind() == reflect.Ptr { objType = objType.Elem() objValue = objValue.Elem() } // 检查是否是结构体 if objType.Kind() != reflect.Struct { return "", fmt.Errorf("参数必须是结构体或结构体指针") } // 获取表名(结构体名转换为下划线) tableName := objType.Name() tableName = humpToUnderline(tableName) // 构建WHERE条件 var whereConditions []string // 遍历结构体字段 for i := 0; i < objType.NumField(); i++ { field := objType.Field(i) fieldValue := objValue.Field(i) if db := field.Tag.Get(tagDb); db != "" && db == "-" { continue } // 检查字段值是否为空 if isEmptyValue(fieldValue) { continue } // 获取字段值(处理指针) actualValue := getRealValue(fieldValue.Interface()) // 获取字段名并转换为下划线 column := field.Tag.Get(tagDb) // 构建条件 condition := fmt.Sprintf("`%s` = ", column) // 根据字段类型决定是否添加引号 condition += formatWhereCondition(actualValue) whereConditions = append(whereConditions, condition) } // 检查是否有条件 if len(whereConditions) == 0 { return "", fmt.Errorf("关键参数缺失") } // 构建完整的SQL sql := fmt.Sprintf("select * from %s where 1=1", tableName) for _, condition := range whereConditions { sql += " and " + condition } return sql + " limit 1", nil } // FormatSelectSqlByMap 根据参数map和类生成SELECT SQL语句(限制1条) func FormatSelectSqlByMap(params map[string]interface{}, clazz interface{}) (string, error) { // 获取类型 var objType reflect.Type switch v := clazz.(type) { case reflect.Type: objType = v default: objType = reflect.TypeOf(clazz) } // 如果是指针,获取指向的类型 if objType.Kind() == reflect.Ptr { objType = objType.Elem() } // 获取表名(类型名转换为下划线) tableName := objType.Name() tableName = humpToUnderline(tableName) // 构建WHERE条件 whereSQL, err := BuildWhereCondition(params) if err != nil { return "", err } // 检查是否有条件(排除默认的1=1) if whereSQL == " WHERE 1=1" { return "", fmt.Errorf("关键参数缺失") } // 构建完整的SQL sql := fmt.Sprintf("select * from %s%s limit 1", tableName, whereSQL) return sql, nil } // FormatSelectSqlByMapAndSql 根据参数map和基础SQL生成SELECT SQL语句 func FormatSelectSqlByMapAndSql(params map[string]interface{}, baseSql string) (string, error) { // 构建WHERE条件 whereSQL, err := BuildWhereCondition(params) if err != nil { return "", err } // 组合SQL sql := baseSql + whereSQL return sql, nil } // FormatInsertSql 根据对象生成INSERT SQL语句(修复指针和自增ID问题) func FormatInsertSql(obj interface{}, uuid string) (string, error) { // 获取对象类型和值 objType := reflect.TypeOf(obj) objValue := reflect.ValueOf(obj) // 如果是指针,获取指向的值 if objType.Kind() == reflect.Ptr { objType = objType.Elem() objValue = objValue.Elem() } // 检查是否是结构体 if objType.Kind() != reflect.Struct { return "", fmt.Errorf("参数必须是结构体或结构体指针") } // 获取表名(结构体名转换为下划线) tableName := objType.Name() tableName = humpToUnderline(tableName) // 构建列名和值的列表 var columns []string var values []string // 遍历结构体字段 for i := 0; i < objType.NumField(); i++ { field := objType.Field(i) fieldValue := objValue.Field(i) if db := field.Tag.Get(tagDb); db != "" && db == "-" { continue } // 处理特殊字段 fieldName := field.Name // 处理id字段 - 特别处理自增ID if strings.ToLower(fieldName) == "id" { // 获取实际值 actualValue := getRealValue(fieldValue.Interface()) // 检查是否为自增ID idType := field.Tag.Get(tagIdType) if idType == string(IdTypeAuto) { // 自增ID,如果值为0或nil则跳过 if actualValue == nil { continue } if intVal, ok := actualValue.(int64); ok && intVal == 0 { continue } if intVal, ok := actualValue.(int32); ok && intVal == 0 { continue } if intVal, ok := actualValue.(int); ok && intVal == 0 { continue } } else if idType == string(IdTypeUUID) || idType == "" { // UUID或未指定,检查值是否为空 if actualValue == nil { // 生成UUID uuidStr := uuid columns = append(columns, fmt.Sprintf("`%s`", humpToUnderline(fieldName))) values = append(values, fmt.Sprintf("'%s'", escapeSQLString(uuidStr))) continue } } } // 处理时间字段 if strings.ToLower(fieldName) == "createtime" || strings.ToLower(fieldName) == "updatetime" { // 获取实际值 actualValue := getRealValue(fieldValue.Interface()) if actualValue == nil { // 设置为当前时间 now := time.Now() create_time := now.Format("2006-01-02 15:04:05") columns = append(columns, fmt.Sprintf("`%s`", humpToUnderline(fieldName))) if field.Type.String() == "time.Time" { values = append(values, fmt.Sprintf("'%s'", create_time)) } else { values = append(values, fmt.Sprintf("'%s'", create_time)) } continue } } if strings.HasSuffix(fieldName, "JsonList") && fieldValue.Type() == reflect.TypeOf((*[]string)(nil)) { actualValue := FormatToString(fieldValue.Interface()) if len(actualValue) > 0 { jsonBytes, err := json.Marshal(actualValue) if err != nil { build_time_str := time.Now().Format(time.DateTime) _, file, line, _ := runtime.Caller(0) fmt.Printf("%s JSON编码失败: %s %d======> %v\n", build_time_str, file, line, err) continue } jsonStr := string(jsonBytes) column := field.Tag.Get(tagDb) columns = append(columns, fmt.Sprintf("`%s`", column)) values = append(values, fmt.Sprintf("'%s'", jsonStr)) continue } } // 获取字段值(处理指针) actualValue := getRealValue(fieldValue.Interface()) // if actualValue == nil { // 指针为nil,跳过或插入NULL,根据需求决定 // 这里选择跳过,如果你需要插入NULL,可以取消下面的注释 // columns = append(columns, fmt.Sprintf("`%s`", humpToUnderline(fieldName))) // values = append(values, "NULL") // continue // } // SQL注入检测 // if containsSQLInjection(actualValue) { // return "", fmt.Errorf("非法参数: 字段 %s 包含SQL注入关键词", fieldName) // } // 获取列名(驼峰转下划线) column := field.Tag.Get(tagDb) columns = append(columns, fmt.Sprintf("`%s`", column)) // 根据字段类型构建值 valueStr := formatValue(actualValue) values = append(values, valueStr) } // 检查是否有列 if len(columns) == 0 { return "", fmt.Errorf("没有有效的字段可以插入") } // 构建SQL sql := fmt.Sprintf("insert into %s (%s) values (%s)", tableName, strings.Join(columns, ", "), strings.Join(values, ", ")) return sql, nil } // FormatUpdateByIdSql 根据对象主键生成UPDATE SQL语句 func FormatUpdateByIdSql(obj interface{}) (string, error) { // 获取对象类型和值 objType := reflect.TypeOf(obj) objValue := reflect.ValueOf(obj) // 如果是指针,获取指向的值 if objType.Kind() == reflect.Ptr { objType = objType.Elem() objValue = objValue.Elem() } // 检查是否是结构体 if objType.Kind() != reflect.Struct { return "", fmt.Errorf("参数必须是结构体或结构体指针") } // 查找主键字段和值 primaryKeyField := "" primaryKeyValue := interface{}(nil) columnName := "" // 遍历结构体字段查找主键 for i := 0; i < objType.NumField(); i++ { field := objType.Field(i) // 检查 TableKey 标签 if tableKey := field.Tag.Get(tagTableKey); tableKey == "true" { fieldValue := getRealValue(objValue.Field(i).Interface()) primaryKeyField = field.Name primaryKeyValue = fieldValue if tableKey := field.Tag.Get(tagDb); tableKey != "-" { columnName = tableKey } break } } // 如果没有找到带标签的主键,尝试查找名为"id"的字段 if primaryKeyField == "" { // 先尝试查找导出字段"Id" if field, ok := objType.FieldByName("Id"); ok { fieldValue := getRealValue(objValue.FieldByName("Id").Interface()) primaryKeyField = field.Name primaryKeyValue = fieldValue if tableKey := field.Tag.Get(tagDb); tableKey != "-" { columnName = tableKey } } else if field, ok := objType.FieldByName("ID"); ok { // 尝试查找"ID" fieldValue := getRealValue(objValue.FieldByName("ID").Interface()) primaryKeyField = field.Name primaryKeyValue = fieldValue if tableKey := field.Tag.Get(tagDb); tableKey != "-" { columnName = tableKey } } else { // 尝试查找小写"id"(非导出字段,通过CanInterface检查) for i := 0; i < objType.NumField(); i++ { field := objType.Field(i) if strings.ToLower(field.Name) == "id" { fieldValue := objValue.Field(i) if fieldValue.CanInterface() { primaryKeyField = field.Name primaryKeyValue = getRealValue(fieldValue.Interface()) if tableKey := field.Tag.Get(tagDb); tableKey != "-" { columnName = tableKey } break } } } } } // 如果仍然没有找到主键字段,返回错误 if primaryKeyField == "" { return "", fmt.Errorf("未找到主键字段") } // 如果主键值为空,返回错误 if primaryKeyValue == nil { return "", fmt.Errorf("主键值为空") } if columnName == "" { columnName = humpToUnderline(primaryKeyField) } // 创建WHERE条件map whereMap := map[string]interface{}{ columnName: primaryKeyValue, } // 调用FormatUpdateSql return FormatUpdateSql(obj, whereMap) } // FormatUpdateSql 根据对象和WHERE条件map生成UPDATE SQL语句 func FormatUpdateSql(obj interface{}, whereMap map[string]interface{}) (string, error) { // 获取对象类型和值 objType := reflect.TypeOf(obj) objValue := reflect.ValueOf(obj) // 如果是指针,获取指向的值 if objType.Kind() == reflect.Ptr { objType = objType.Elem() objValue = objValue.Elem() } // 检查是否是结构体 if objType.Kind() != reflect.Struct { return "", fmt.Errorf("参数必须是结构体或结构体指针") } // 获取表名(结构体名转换为下划线) tableName := objType.Name() tableName = humpToUnderline(tableName) // 构建SET子句 var setClauses []string // 遍历结构体字段 for i := 0; i < objType.NumField(); i++ { field := objType.Field(i) fieldValue := objValue.Field(i) fieldName := field.Name if db := field.Tag.Get(tagDb); db == "-" { // 检查是否为emptyField字段(类型为[]string或*[]string) if strings.ToLower(fieldName) == "emptyfield" { // 获取字段值 actualValue := getRealValue(fieldValue.Interface()) if actualValue != nil { // 尝试将actualValue转换为[]string var strSlice []string // 判断actualValue的类型 switch v := actualValue.(type) { case []string: strSlice = v case *[]string: if v != nil { strSlice = *v } else { continue } default: // 如果不是这两种类型,跳过 continue } if len(strSlice) == 0 { continue } else { // 遍历切片中的每个字段名,生成条件并添加到setClauses for _, fieldNameToSet := range strSlice { // 用反引号括起字段名,避免SQL关键字冲突 columnName := humpToUnderline(fieldNameToSet) condition := fmt.Sprintf("`%s` = NULL", columnName) setClauses = append(setClauses, condition) } } // 处理完emptyField后继续下一个字段 continue } } continue } // 跳过主键字段 if strings.ToLower(fieldName) == "id" || field.Tag.Get(tagTableKey) == "true" { continue } // 处理updateTime字段 if strings.ToLower(fieldName) == "updatetime" { // 获取实际值 actualValue := getRealValue(fieldValue.Interface()) if actualValue == nil { // 设置为当前时间 now := time.Now() // 获取列名(驼峰转下划线) column := field.Tag.Get(tagDb) // 构建SET子句 timeStr := now.Format("2006-01-02 15:04:05") setClause := fmt.Sprintf("`%s` = '%s'", column, timeStr) setClauses = append(setClauses, setClause) continue } } if strings.HasSuffix(fieldName, "JsonList") && fieldValue.Type() == reflect.TypeOf((*[]string)(nil)) { actualValue := FormatToString(fieldValue.Interface()) if len(actualValue) > 0 { jsonBytes, err := json.Marshal(actualValue) if err != nil { build_time_str := time.Now().Format(time.DateTime) _, file, line, _ := runtime.Caller(0) fmt.Printf("%s JSON编码失败: %s %d======> %v\n", build_time_str, file, line, err) continue } jsonStr := string(jsonBytes) column := field.Tag.Get(tagDb) setClause := fmt.Sprintf("`%s` = '%s'", column, jsonStr) setClauses = append(setClauses, setClause) continue } } // 检查字段值是否为空 if isEmptyValue(fieldValue) { continue } // 获取字段值(处理指针) actualValue := getRealValue(fieldValue.Interface()) // SQL注入检测 // if containsSQLInjection(actualValue) { // return "", fmt.Errorf("非法参数: 字段 %s 包含SQL注入关键词", fieldName) // } // 获取列名(驼峰转下划线) column := field.Tag.Get(tagDb) // 构建SET子句 valueStr := formatValue(actualValue) setClause := fmt.Sprintf("`%s` = %s", column, valueStr) setClauses = append(setClauses, setClause) } // 检查是否有SET子句 if len(setClauses) == 0 { return "", fmt.Errorf("没有有效的字段可以更新") } // 检查WHERE条件 if len(whereMap) == 0 { return "", fmt.Errorf("关键参数缺失: WHERE条件不能为空") } // 构建WHERE条件 whereSQL, err := BuildWhereCondition(whereMap) if err != nil { return "", err } // 检查WHERE条件是否有效(排除默认的1=1) if whereSQL == " WHERE 1=1" { return "", fmt.Errorf("关键参数缺失: WHERE条件无效") } // 构建SQL sql := fmt.Sprintf("update %s set %s%s", tableName, strings.Join(setClauses, ", "), whereSQL) return sql, nil } // FormatRemoveByIdSql 根据对象主键生成DELETE SQL语句 func FormatRemoveByIdSql(obj interface{}) (string, error) { // 获取对象类型和值 objType := reflect.TypeOf(obj) objValue := reflect.ValueOf(obj) // 如果是指针,获取指向的值 if objType.Kind() == reflect.Ptr { objType = objType.Elem() objValue = objValue.Elem() } // 检查是否是结构体 if objType.Kind() != reflect.Struct { return "", fmt.Errorf("参数必须是结构体或结构体指针") } // 获取表名(结构体名转换为下划线) tableName := objType.Name() tableName = humpToUnderline(tableName) // 查找主键字段 primaryKeyField := "" primaryKeyValue := interface{}(nil) columnName := "" // 遍历结构体字段 for i := 0; i < objType.NumField(); i++ { field := objType.Field(i) // 检查 TableKey 标签 if tableKey := field.Tag.Get(tagTableKey); tableKey == "true" { // 获取字段值 fieldValue := getRealValue(objValue.Field(i).Interface()) primaryKeyField = field.Name primaryKeyValue = fieldValue if tableKey := field.Tag.Get(tagDb); tableKey != "-" { columnName = tableKey } break } } // 如果没有找到带标签的主键,尝试查找名为"id"的字段 if primaryKeyField == "" { // 先尝试查找导出字段"Id" if field, ok := objType.FieldByName("Id"); ok { fieldValue := getRealValue(objValue.FieldByName("Id").Interface()) primaryKeyField = field.Name primaryKeyValue = fieldValue if tableKey := field.Tag.Get(tagDb); tableKey != "-" { columnName = tableKey } } else if field, ok := objType.FieldByName("ID"); ok { // 尝试查找"ID" fieldValue := getRealValue(objValue.FieldByName("ID").Interface()) primaryKeyField = field.Name primaryKeyValue = fieldValue if tableKey := field.Tag.Get(tagDb); tableKey != "-" { columnName = tableKey } } else { // 尝试查找小写"id"(非导出字段,通过CanInterface检查) for i := 0; i < objType.NumField(); i++ { field := objType.Field(i) if strings.ToLower(field.Name) == "id" { fieldValue := objValue.Field(i) if fieldValue.CanInterface() { primaryKeyField = field.Name primaryKeyValue = getRealValue(fieldValue.Interface()) if tableKey := field.Tag.Get(tagDb); tableKey != "-" { columnName = tableKey } break } } } } } // 如果仍然没有找到主键字段,返回错误 if primaryKeyField == "" { return "", fmt.Errorf("未找到主键字段") } // 如果主键值为空,返回错误 if primaryKeyValue == nil { return "", fmt.Errorf("主键值为空") } if columnName == "" { columnName = humpToUnderline(primaryKeyField) } // 根据值类型决定是否添加引号 sql := fmt.Sprintf("delete from %s where %s = ", tableName, columnName) // 根据值类型添加引号 sql += formatWhereCondition(primaryKeyValue) return sql, nil } // FormatRemoveSql 根据对象非空字段生成DELETE SQL语句 func FormatRemoveSql(obj interface{}) (string, error) { // 获取对象类型和值 objType := reflect.TypeOf(obj) objValue := reflect.ValueOf(obj) // 如果是指针,获取指向的值 if objType.Kind() == reflect.Ptr { objType = objType.Elem() objValue = objValue.Elem() } // 检查是否是结构体 if objType.Kind() != reflect.Struct { return "", fmt.Errorf("参数必须是结构体或结构体指针") } // 获取表名(结构体名转换为下划线) tableName := objType.Name() tableName = humpToUnderline(tableName) // 将对象转换为map paramMap := objectToMap(obj) // 构建WHERE条件 whereSQL, err := BuildWhereCondition(paramMap) if err != nil { return "", err } // 检查是否有条件(排除默认的1=1) if whereSQL == " WHERE 1=1" { return "", fmt.Errorf("关键参数缺失") } // 构建完整的SQL sql := fmt.Sprintf("delete from %s%s", tableName, whereSQL) return sql, nil } // getRealValue 获取参数的实际值(处理指针) func getRealValue(value interface{}) interface{} { if value == nil { return nil } v := reflect.ValueOf(value) // 如果是指针类型,获取指针指向的值 if v.Kind() == reflect.Ptr { if v.IsNil() { return nil } // 递归调用,获取指针指向的实际值 return getRealValue(v.Elem().Interface()) } return value } // processCondition 处理单个条件 func (b *SQLBuilder) processCondition(key string, value interface{}) error { if key == "" || value == nil { return nil } // 获取实际值(处理指针) realValue := getRealValue(value) // SQL注入检测 if containsSQLInjection(realValue) && key != suffixApplySQL { println("非法参数", key, toString(realValue)) return fmt.Errorf("非法参数") } column := humpToUnderline(key) // 根据后缀处理不同的条件类型 switch { case strings.HasSuffix(column, suffixBegin): b.addRangeCondition(column, realValue, ">=", suffixBegin) case strings.HasSuffix(column, suffixEnd): b.addRangeCondition(column, realValue, "<=", suffixEnd) case strings.HasSuffix(column, suffixLike): b.addLikeCondition(column, realValue) case strings.HasSuffix(column, suffixNot): b.addNotEqualCondition(column, realValue) case strings.HasSuffix(column, suffixIn): b.addInCondition(column, realValue, false) case strings.HasSuffix(column, suffixNotIn): b.addInCondition(column, realValue, true) case strings.HasSuffix(column, suffixNull): b.addNullCondition(column, false) case strings.HasSuffix(column, suffixNotNull): b.addNullCondition(column, true) case column == suffixApplySQL: b.addCustomSQLCondition(realValue) default: b.addEqualCondition(column, realValue) } return nil } // addRangeCondition 添加范围条件 func (b *SQLBuilder) addRangeCondition(column string, value interface{}, operator, suffix string) { col := strings.TrimSuffix(column, suffix) condition := fmt.Sprintf("%s %s %s", quoteColumn(col), operator, formatValue(value)) b.conditions = append(b.conditions, condition) } // addLikeCondition 添加LIKE条件 func (b *SQLBuilder) addLikeCondition(column string, value interface{}) { col := strings.TrimSuffix(column, suffixLike) condition := fmt.Sprintf("%s LIKE '%%%s%%'", quoteColumn(col), escapeSQLString(toString(value))) b.conditions = append(b.conditions, condition) } // addNotEqualCondition 添加不等于条件 func (b *SQLBuilder) addNotEqualCondition(column string, value interface{}) { col := strings.TrimSuffix(column, suffixNot) condition := fmt.Sprintf("%s != %s", quoteColumn(col), formatValue(value)) b.conditions = append(b.conditions, condition) } // addInCondition 添加IN/NOT IN条件 func (b *SQLBuilder) addInCondition(column string, value interface{}, isNotIn bool) { // 获取实际值(处理指针) realValue := getRealValue(value) // 尝试将值转换为切片 var list []interface{} switch v := realValue.(type) { case []interface{}: list = v case []string: list = make([]interface{}, len(v)) for i, item := range v { list[i] = item } case []int: list = make([]interface{}, len(v)) for i, item := range v { list[i] = item } case []int64: list = make([]interface{}, len(v)) for i, item := range v { list[i] = item } case []float64: list = make([]interface{}, len(v)) for i, item := range v { list[i] = item } default: // 如果不是切片类型,直接返回 return } // 处理切片中的指针元素 for i, item := range list { list[i] = getRealValue(item) } if len(list) == 0 { list = []interface{}{"null"} } col := strings.TrimSuffix(column, func() string { if isNotIn { return suffixNotIn } return suffixIn }()) // 格式化每个值 formattedValues := make([]string, len(list)) for i, item := range list { formattedValues[i] = formatValue(item) } operator := "IN" if isNotIn { operator = "NOT IN" } condition := fmt.Sprintf("%s %s (%s)", quoteColumn(col), operator, strings.Join(formattedValues, ",")) b.conditions = append(b.conditions, condition) } // addNullCondition 添加NULL条件 func (b *SQLBuilder) addNullCondition(column string, isNotNull bool) { col := strings.TrimSuffix(column, func() string { if isNotNull { return suffixNotNull } return suffixNull }()) operator := "IS NULL" if isNotNull { operator = "IS NOT NULL" } condition := fmt.Sprintf("%s %s", quoteColumn(col), operator) b.conditions = append(b.conditions, condition) } // addCustomSQLCondition 添加自定义SQL条件 func (b *SQLBuilder) addCustomSQLCondition(value interface{}) { // 注意:自定义SQL条件不使用参数化,需要调用者确保安全 condition := fmt.Sprintf("(%s)", toString(value)) b.conditions = append(b.conditions, condition) } // addEqualCondition 添加等于条件 func (b *SQLBuilder) addEqualCondition(column string, value interface{}) { condition := fmt.Sprintf("%s = %s", quoteColumn(column), formatValue(value)) b.conditions = append(b.conditions, condition) } // extractPageParams 提取分页、排序、分组参数 func (b *SQLBuilder) extractPageParams(params map[string]interface{}) { // 处理分页 if current, ok := params["current"]; ok { currentInt := 0 switch v := current.(type) { case int: currentInt = v case float64: currentInt = int(v) case string: if val, err := strconv.Atoi(v); err == nil { currentInt = val } } if size, ok := params["size"]; ok && currentInt > 0 { sizeInt := 0 switch v := size.(type) { case int: sizeInt = v case float64: sizeInt = int(v) case string: if val, err := strconv.Atoi(v); err == nil { sizeInt = val } } if sizeInt > 0 { offset := (currentInt - 1) * sizeInt b.limit = fmt.Sprintf(" LIMIT %d, %d", offset, sizeInt) delete(params, "current") delete(params, "size") } } } // 处理排序 if orderBy, ok := params["orderBy"]; ok { orderByStr := toString(orderBy) orders := strings.Split(orderByStr, ",") orderClauses := make([]string, 0, len(orders)) for _, order := range orders { parts := strings.Split(order, "_") if len(parts) == 2 { col := humpToUnderline(parts[0]) orderClauses = append(orderClauses, fmt.Sprintf("%s %s", quoteColumn(col), parts[1])) } } if len(orderClauses) > 0 { b.orderBy = " ORDER BY " + strings.Join(orderClauses, ", ") } delete(params, "orderBy") } // 处理分组 if groupBy, ok := params["groupBy"]; ok { groupByStr := toString(groupBy) groups := strings.Split(groupByStr, "-") groupClauses := make([]string, 0, len(groups)) for _, group := range groups { col := humpToUnderline(group) groupClauses = append(groupClauses, quoteColumn(col)) } if len(groupClauses) > 0 { b.groupBy = " GROUP BY " + strings.Join(groupClauses, ", ") } delete(params, "groupBy") } } // buildSQL 构建完整的SQL语句 func (b *SQLBuilder) buildSQL() string { var sql strings.Builder if len(b.conditions) > 0 { sql.WriteString(" WHERE 1=1") for _, condition := range b.conditions { sql.WriteString(" AND ") sql.WriteString(condition) } } sql.WriteString(b.orderBy) sql.WriteString(b.groupBy) sql.WriteString(b.limit) return sql.String() } // objectToMap 将对象转换为map[string]interface{} func objectToMap(obj interface{}) map[string]interface{} { result := make(map[string]interface{}) // 获取对象类型和值 objType := reflect.TypeOf(obj) objValue := reflect.ValueOf(obj) // 如果是指针,获取指向的值 if objType.Kind() == reflect.Ptr { objType = objType.Elem() objValue = objValue.Elem() } // 遍历结构体字段 for i := 0; i < objType.NumField(); i++ { field := objType.Field(i) fieldValue := objValue.Field(i) if db := field.Tag.Get(tagDb); db != "" && db == "-" { continue } // 检查字段值是否为空 if isEmptyValue(fieldValue) { continue } // 添加字段到map(处理指针) result[field.Name] = getRealValue(fieldValue.Interface()) } return result } func formatWhereCondition(v interface{}) string { sql := "" actualValue := getRealValue(v) // 根据值类型添加引号 switch val := actualValue.(type) { case string: sql = fmt.Sprintf("'%s'", escapeSQLString(val)) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: sql = fmt.Sprintf("%v", val) case bool: if val { sql = "true" } else { sql = "false" } default: // 对于其他类型,使用字符串表示 sql = fmt.Sprintf("'%v'", escapeSQLString(fmt.Sprintf("%v", val))) } return sql } // isEmptyValue 检查值是否为空(零值) // 核心调整:数值类型(int/uint/float)的 0 不再视为空值,会保留原值 func isEmptyValue(v reflect.Value) bool { switch v.Kind() { case reflect.String: // 空字符串视为空 return v.String() == "" case reflect.Bool: // 布尔类型无"空"概念,始终返回 false return false case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: // 整数 0 不视为空,返回 false return false case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: // 无符号整数 0 不视为空,返回 false return false case reflect.Float32, reflect.Float64: // 浮点数 0 不视为空,返回 false return false case reflect.Ptr, reflect.Interface: // nil 指针/接口视为空 return v.IsNil() case reflect.Slice, reflect.Map, reflect.Array: // 长度为 0 的切片/Map/数组视为空 return v.Len() == 0 default: // 其他类型(如结构体):Go 1.13+ 用 IsZero 判断,但排除数值 0 场景 if v.CanInterface() { // 先判断是否是数值类型(防止结构体字段中的 0 被误判) switch v.Type().Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: return false // 数值类型无论值是多少,都不视为空 default: // 非数值类型,用 Zero 判断是否为零值 return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) } } return false } } // formatValue 格式化值 func formatValue(v interface{}) string { if v == nil { return "NULL" } // 获取实际值(处理指针) actualValue := getRealValue(v) switch val := actualValue.(type) { case string: return fmt.Sprintf("'%s'", escapeSQLString(val)) case *string: if val == nil { return "NULL" } return fmt.Sprintf("'%s'", escapeSQLString(*val)) case bool: if val { return "1" } return "0" case *bool: if val == nil { return "NULL" } if *val { return "1" } return "0" case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: return fmt.Sprintf("%v", val) case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64: // 解引用指针 rv := reflect.ValueOf(actualValue) if rv.IsNil() { return "NULL" } return fmt.Sprintf("%v", rv.Elem().Interface()) case float32, float64: return fmt.Sprintf("%v", val) case *float32, *float64: rv := reflect.ValueOf(actualValue) if rv.IsNil() { return "NULL" } return fmt.Sprintf("%v", rv.Elem().Interface()) case time.Time: return fmt.Sprintf("'%s'", val.Format("2006-01-02 15:04:05")) case *time.Time: if val == nil { return "NULL" } return fmt.Sprintf("'%s'", val.Format("2006-01-02 15:04:05")) default: // 对于其他类型,尝试转换为字符串 str := fmt.Sprintf("%v", actualValue) return fmt.Sprintf("'%s'", escapeSQLString(str)) } } // escapeSQLString 转义SQL字符串中的特殊字符 func escapeSQLString(str string) string { // 转义单引号 str = strings.ReplaceAll(str, "'", "''") // 转义反斜杠 str = strings.ReplaceAll(str, "\\", "\\\\") return str } // generateUUID 生成UUID func GenerateUUID() string { uuid := uuid.New() return uuid.String() } // containsSQLInjection 检测SQL注入 func containsSQLInjection(value interface{}) bool { // 获取实际值(处理指针) realValue := getRealValue(value) if realValue == nil { return false } str := strings.ToLower(toString(realValue)) for _, keyword := range sqlInjectionKeywords { if strings.Contains(str, keyword) { return true } } // 检测特殊字符 if strings.Contains(str, ";") || strings.Contains(str, "\\") { return true } return false } // humpToUnderline 驼峰转下划线 func humpToUnderline(str string) string { // 如果已经是下划线格式,直接返回 if strings.Contains(str, "_") { return strings.ToLower(str) } // 处理连续大写字母的情况 re := regexp.MustCompile(`([A-Z]+)([A-Z][a-z])`) str = re.ReplaceAllString(str, "${1}_${2}") // 处理单个大写字母的情况 re = regexp.MustCompile(`([a-z])([A-Z])`) return strings.ToLower(re.ReplaceAllString(str, "${1}_${2}")) } // quoteColumn 给列名添加引号 func quoteColumn(column string) string { // 如果已经是引号包裹的,直接返回 if strings.HasPrefix(column, "`") && strings.HasSuffix(column, "`") { return column } if strings.HasPrefix(column, "\"") && strings.HasSuffix(column, "\"") { return column } if strings.HasPrefix(column, "[") && strings.HasSuffix(column, "]") { return column } return fmt.Sprintf("`%s`", column) } // toString 将任意类型转为字符串(处理指针) func toString(value interface{}) string { // 获取实际值(处理指针) realValue := getRealValue(value) switch v := realValue.(type) { case string: return v case int: return strconv.Itoa(v) case int8: return strconv.FormatInt(int64(v), 10) case int16: return strconv.FormatInt(int64(v), 10) case int32: return strconv.FormatInt(int64(v), 10) case int64: return strconv.FormatInt(v, 10) case uint: return strconv.FormatUint(uint64(v), 10) case uint8: return strconv.FormatUint(uint64(v), 10) case uint16: return strconv.FormatUint(uint64(v), 10) case uint32: return strconv.FormatUint(uint64(v), 10) case uint64: return strconv.FormatUint(v, 10) case float32: return strconv.FormatFloat(float64(v), 'f', -1, 32) case float64: return strconv.FormatFloat(v, 'f', -1, 64) case bool: if v { return "true" } return "false" default: return fmt.Sprintf("%v", v) } } func FormatToString(value interface{}) string { return toString(value) } // GetTableName 获取表名 func GetTableName(obj interface{}) string { objType := reflect.TypeOf(obj) // 如果是指针,获取指向的类型 if objType.Kind() == reflect.Ptr { objType = objType.Elem() } // 获取结构体名称并转换为下划线格式 tableName := objType.Name() return humpToUnderline(tableName) } func GetIdValue(obj interface{}) (string, error) { // 获取对象类型和值 objType := reflect.TypeOf(obj) objValue := reflect.ValueOf(obj) // 如果是指针,获取指向的值 if objType.Kind() == reflect.Ptr { objType = objType.Elem() objValue = objValue.Elem() } // 检查是否是结构体 if objType.Kind() != reflect.Struct { return "", fmt.Errorf("参数必须是结构体或结构体指针") } // 查找主键字段 primaryKeyField := "" primaryKeyValue := interface{}(nil) // 遍历结构体字段 for i := 0; i < objType.NumField(); i++ { field := objType.Field(i) // 检查 TableKey 标签 if tableKey := field.Tag.Get(tagTableKey); tableKey != "" && tableKey != "false" { // 获取字段值 fieldValue := getRealValue(objValue.Field(i).Interface()) primaryKeyField = field.Name primaryKeyValue = fieldValue break } } // 如果没有找到带标签的主键,尝试查找名为"id"的字段 if primaryKeyField == "" { // 先尝试查找导出字段"Id" if field, ok := objType.FieldByName("Id"); ok { fieldValue := getRealValue(objValue.FieldByName("Id").Interface()) primaryKeyField = field.Name primaryKeyValue = fieldValue } else if field, ok := objType.FieldByName("ID"); ok { // 尝试查找"ID" fieldValue := getRealValue(objValue.FieldByName("ID").Interface()) primaryKeyField = field.Name primaryKeyValue = fieldValue } else { // 尝试查找小写"id"(非导出字段,通过CanInterface检查) for i := 0; i < objType.NumField(); i++ { field := objType.Field(i) if strings.ToLower(field.Name) == "id" { fieldValue := objValue.Field(i) if fieldValue.CanInterface() { primaryKeyField = field.Name primaryKeyValue = getRealValue(fieldValue.Interface()) break } } } } } // 如果仍然没有找到主键字段,返回错误 if primaryKeyField == "" { return "", fmt.Errorf("未找到主键字段") } // 如果主键值为空,返回错误 if primaryKeyValue == nil { return "", fmt.Errorf("主键值为空") } // 根据值类型添加引号 var id_value string switch v := primaryKeyValue.(type) { case string: id_value = escapeSQLString(v) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: id_value = fmt.Sprintf("%v", v) case float32, float64: id_value = fmt.Sprintf("%v", v) case bool: if v { id_value = "true" } else { id_value = "false" } default: // 对于其他类型,使用字符串表示 id_value = fmt.Sprintf("'%v'", escapeSQLString(fmt.Sprintf("%v", v))) } return id_value, nil }