https://www.huangdf.xyz/categories/study-notes
南风
南风
发布于 2025-01-23 / 9 阅读
1
0

Java拦截器修改sql

前言

项目中有这样一个需求,需要根据用户的选择动态的执行sql,用户能够更加自由的选择参数的组合,并非像过去一样sql是固定的且条件之间的关系也是固定的,用户只有选择是否需要这个条件的权力。现在用户可以自定义这个sql,可以自己定义各个条件之间的关系,以及各个条件和值的关系。

原理

自定义sql的目标是将定义sql的权限由后端交到前端页面,原理是将前端页面输入的信息通过特定的规则进行解析,并将其编译成数据库能够识别的语言,拼接到预先写好的sql中,最终实现由用户来定义查询的条件。

适用范围及局限

适用范围:表结构比较复杂,或者多表联查等有多个查询条件,比如,工厂生产的商品,查询特定的商品时,可以查询的条件有:原材料、销量、成本、利润、利率、sku、spu等等,如果统一写在页面上,一是会导致查询栏内容过多,其次是条件都是写好的,例如原材料=a、销量>100不能反过来,所以通过自定义查询,能够更加灵活的查询内容。

局限性:这个自定义只是定义查询条件,查询的结果是不会变的,因为这个涉及到字段名称映射,没办法做到自定义映射,除非把相关的所有字段都罗列出来,否则参数对应的时候会抛出异常。但如果把参数都罗列出来,那就等同于select *了,效率低下。

实践

controller

@PostMapping("/selfQuery")
@ApiOperation("自定义查询测试方法")
public void selfQuery(@RequestBody SelfQueryViewDTO dto){
    selfQueryService.selfQuery(dto);
}

entidy

SelfQueryViewDTO

@AllArgsConstructor
@NoArgsConstructor
@Data
@ApiModel("自定义查询页面入参")
public class SelfQueryViewDTO implements Serializable {
    private  String sqlSegment;
    @ApiModelProperty("查询列及运算条件")
    private List<QueryColumnDTO> queryColumns;
    @ApiModelProperty("表对应id")
    private Integer tableId;
    public List<QueryColumnDTO> getQueryColumns() {
        return null == queryColumns ? Collections.emptyList() : queryColumns;
    }
}

QueryColumnDTO

@AllArgsConstructor
@NoArgsConstructor
@Data
@Builder
@ApiModel
public class QueryColumnDTO implements Serializable {
    /**
     * 当前SQL中表别名
     */
    private  String tableAlias;
    @ApiModelProperty("数据库表名")
    private String dbTableName;
    @ApiModelProperty("数据库字段名")
    private String dbColumnName;
    @ApiModelProperty("操作符:例如大于、小于、之间……")
    private Integer operator;
    @ApiModelProperty("值:若运算关系要求有多个值,以英文逗号分割;时间字符串格式如'1900-01-01'")
    private String value;
    @ApiModelProperty("条件逻辑关系:0-->and;1-->or")
    private Integer relationType;
    @ApiModelProperty("列标题")
    private String title;
}

配置项

注解

@Target({ElementType.METHOD,ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface SelfQuery {
}

@Target({ElementType.METHOD,ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface SqlSegment {
}

切面

@Aspect
@Component
@RequiredArgsConstructor
public class SelfQueryAOP {
    private final SelfConcatService selfConcatService;

    @Before("@annotation(com.xxh.customsql.aop.self.SelfQuery)")
    public void vw(JoinPoint joinPoint) {
        Object[] args = joinPoint.getArgs();
        for (Object arg : args) {
            if (!(arg instanceof SelfQueryViewDTO)) continue;
            SelfQueryViewDTO vwDto = (SelfQueryViewDTO) arg;
            // 来自参数中的过滤条件:即页面上视图之外选中的条件
            List<QueryColumnDTO> inParameterQcs = vwDto.getQueryColumns();
            // 取注解与方法信息
            MethodSignature ms = (MethodSignature) joinPoint.getSignature();
            //得到注解方法的全限定名
            String mapperId = ms.getDeclaringTypeName() + StringPool.DOT + ms.getName();
            // 3、填充别名
            selfConcatService.fillTableAlias(inParameterQcs, mapperId);
            // 4、拼接WHERE片段
            String sqlSegment = selfConcatService.concatSqlByMh(inParameterQcs);
            // 填充参数
            vwDto.setSqlSegment(sqlSegment);
        }
    }
}

常量接口

public interface StrPool {
    String SPACE=" ";
    String PK="PRIMARY KEY";
    String AUTO_INCR="AUTO_INCREMENT";
    String COMMENT="COMMENT";
    String NULL="NULL";
    String NOT_NULL ="NOT NULL";
    String DEFAULT="DEFAULT";
    String USE_HASH="USING HASH";
    String USE_BTREE="USING BTREE";
    String UNIQUE="UNIQUE INDEX";
    String SPATIAL="SPATIAL INDEX";
    String FULLTEXT="FULLTEXT INDEX";
    String NORMAL="INDEX";
    String ITALIC_DOT="`";
    String ITALIC_DOUBLE="\"";
    String ITALIC_SINGLE="'";
    String LEFT_BRACKET="(";
    String RIGHT_BRACKET=")";
    String SEMICOLON=";";
    String COMMA=",";
    String CREATE_TABLE="CREATE TABLE";
}

枚举类

@Getter
@AllArgsConstructor
public enum OperatorEnum {

    LIKE(1, "包含", "like"),

    NOT_LIKE(2, "不包含", "notLike"),

    IS_NULL(3, "为空", "isNull"),

    IS_NOT_NULL(4, "不为空", "isNotNull"),

    EQ(5, "等于", "eq"),

    GT(6, "大于", "gt"),

    LT(7, "小于", "lt"),

    GE(8, "大于等于", "ge"),

    LE(9, "小于等于", "le"),

    NE(10, "不等于", "ne"),

    BETWEEN(11, "之间", "between"),

    IN(12, "属于", "in"),

    NOT_IN(13, "不属于", "notIn"),

//    D_DATE(14, "动态时间", "dDateBetween"),
//
//    C_DATE(15, "日历时间", "cDateBetween"),

    DATE(16, "普通时间", "dateBetween"),

    ;


    private Integer state;


    private String stateName;

    /**
     * 对应到QueryWrapper中的方法名,若在此枚举类中添加枚举,请务必正确写对方法名
     */
    private String wrapperMethodName;

    /**
     * 根据state数字获取对应的枚举
     *
     * @param state 数字
     * @return 枚举
     */
    public static OperatorEnum getOperator(Integer state) {
        for (OperatorEnum operator : OperatorEnum.values()) {
            if (operator.state.equals(state)) return operator;
        }
        return null;
    }

    /**
     * 根据state获取对应的QueryWrapper方法名
     *
     * @param state 数字
     * @return 枚举
     */
    public static String getWrapperMethodName(Integer state) {
        for (OperatorEnum operator : OperatorEnum.values()) {
            if (operator.state.equals(state)) return operator.wrapperMethodName;
        }
        return null;
    }

}

工具类

public class SqlParserTableHelper {

    public static Table getTableByName(PlainSelect plainSelect, String dbTableName) {
        Table table = new Table();
        List<Table> tables=new ArrayList<>();
        Table fromItem = (Table)plainSelect.getFromItem();
        tables.add(fromItem);
        List<Join> joins = plainSelect.getJoins();
        if(joins!=null){
            for (Join join : joins) {
                tables.add((Table) join.getFromItem());
            }
        }
        for (Table table1 : tables) {
            if (table1.getName().equals(dbTableName)) {
                table=table1;
                break;
            }
        }
        return table;
    }
}

mapper

@SelfQuery
Integer selfQuery(@Param("info") SelfQueryViewDTO dto);

<select id="selfQuery" resultType="java.lang.Integer">
    select count(*)
    from user user
    <where>
        1=1
        <if test="info !=null and info !=''">
            ${info.sqlSegment}
        </if>
    </where>
</select>

service

@Service
@RequiredArgsConstructor
public class SelfQueryServiceImpl implements SelfQueryService {

    @Autowired
    private  SelfQueryMapper selfQueryMapper;

    @Override
    public void selfQuery(SelfQueryViewDTO dto) {
        System.out.println(selfQueryMapper.selfQuery(dto));
    }
}

public interface SelfQueryService {
    void selfQuery(SelfQueryViewDTO dto);
}

@Service
@RequiredArgsConstructor
public class SelfConcatServiceImpl extends JsqlParserSupport implements SelfConcatService {
    private final HashMap<Integer, MethodHandle> opMh = new HashMap<>(OperatorEnum.values().length << 2);

    private final SqlSessionFactory sqlSessionFactory;

    {
        // 以方法参数和返回值定义方法类型
        MethodType mt = MethodType.methodType(void.class, String.class, String.class);
        MethodHandles.Lookup lookup = ConcatWrapper.lookup();
        for (OperatorEnum value : OperatorEnum.values()) {
            String wrapperMethodName = value.getWrapperMethodName();
            // 以方法名与方法类型查找并持有方法句柄
            try {
                MethodHandle mh = lookup.findVirtual(ConcatWrapper.class, wrapperMethodName, mt);
                opMh.put(value.getState(), mh);
            } catch (NoSuchMethodException | IllegalAccessException e) {
                e.printStackTrace();
                throw new RuntimeException("有操作符没有对应的SQL拼接方法,请检查该操作符!" + value.getStateName());
            }
        }
    }

    @Override
    protected void processSelect(Select select, int index, String sql, Object obj) {
         List<QueryColumnDTO> queryColumns = (List<QueryColumnDTO>) obj;
        PlainSelect plainSelect = (PlainSelect) select.getSelectBody();
        for (QueryColumnDTO queryColumn : queryColumns) {
            Table table = SqlParserTableHelper.getTableByName(plainSelect, queryColumn.getDbTableName());
            if (ObjectUtil.isNotNull(table) && ObjectUtil.isNotNull(table.getAlias())) {
                queryColumn.setTableAlias(table.getAlias().getName());
            } else {
                queryColumn.setTableAlias(queryColumn.getDbTableName());
            }
        }
    }

    /**
     * 在指定Mapper文件中指定SQL中查找数据库表名在当前SQL中对应的表别名,并填充进集合中
     */
    @Override
    public void fillTableAlias(List<QueryColumnDTO> inParameterQcs, String mapperId) {
        if (CollectionUtil.isEmpty(inParameterQcs)) return;
        MappedStatement statement = sqlSessionFactory.getConfiguration().getMappedStatement(mapperId);
        BoundSql boundSql = statement.getBoundSql(null);
        super.parserSingle(boundSql.getSql(), inParameterQcs);
    }

    @Override
    public String concatSqlByMh(List<QueryColumnDTO> inParameterQcs) {
        if (CollectionUtil.isEmpty(inParameterQcs)) return "";
        ConcatWrapper wrapper = new ConcatWrapper();
        for (QueryColumnDTO queryColumn : inParameterQcs) {
            String column;
            column = StrUtil.isBlank(queryColumn.getTableAlias()) ? queryColumn.getDbTableName() : queryColumn.getTableAlias()
                    + StringPool.DOT + queryColumn.getDbColumnName();
            try {
                // 如果是OR关系,拼接OR
                if (NumsConstant.INTEGER_ONE.equals(queryColumn.getRelationType())) {
                    wrapper.or();
                }
                // 取方法句柄并调用,拼接SQL
                MethodHandle mh = opMh.get(queryColumn.getOperator());
                mh.invoke(wrapper, column, queryColumn.getValue());
            } catch (Throwable e) {
                e.printStackTrace();
                throw new RuntimeException("字段【" + queryColumn.getDbTableName() + StringPool.DOT + queryColumn.getDbColumnName() + "】参数错误!");
            }
        }
        return wrapper.getSqlSeg();
    }
}

public interface SelfConcatService {
    void fillTableAlias(List<QueryColumnDTO> inParameterQcs, String mapperId);

    String concatSqlByMh(List<QueryColumnDTO> inParameterQcs);

}

wapper


public class ConcatWrapper implements Serializable {

    QueryWrapper<Object> wrapper;

    public ConcatWrapper() {
        this.wrapper = new QueryWrapper<>();
    }

    public void or() {
        wrapper.or();
    }

    public void like(String column, String val) {
        wrapper.like(column, val);
    }

    public void notLike(String column, String val) {
        wrapper.notLike(column, val);
    }

    public void isNull(String column, String val) {
        wrapper.isNull(column);
    }

    public void isNotNull(String column, String val) {
        wrapper.isNotNull(column);
    }

    public void eq(String column, String val) {
        wrapper.eq(column, val);
    }

    public void gt(String column, String val) {
        wrapper.gt(column, val);
    }

    public void lt(String column, String val) {
        wrapper.lt(column, val);
    }

    public void ge(String column, String val) {
        wrapper.ge(column, val);
    }

    public void le(String column, String val) {
        wrapper.le(column, val);
    }

    public void ne(String column, String val) {
        wrapper.ne(column, val);
    }

    public void in(String column, String val) {
        wrapper.in(column, this.splitByComma(val));
    }

    public void notIn(String column, String val) {
        wrapper.notIn(column, this.splitByComma(val));
    }

    public void between(String column, String val) {
        List<String> values = this.safeBetWeen(val);
        wrapper.between(column, values.get(0), values.get(1));
    }

    public void dateBetween(String column, String val) {
        // 前一个时间取当天开始零点,后一个时间取当天结束零点
        Calendar calendar=Calendar.getInstance();
        try {
            List<String> dateStr = this.safeBetWeen(val);
            Date startTime = DateUtil.parseDate(dateStr.get(0));
            calendar.setTime(DateUtil.parseDate(dateStr.get(1)));
            calendar.add(Calendar.DAY_OF_MONTH,1);
            Date endTime = calendar.getTime();
            wrapper.between(column, Timestamp.from(startTime.toInstant()), Timestamp.from(endTime.toInstant()));
        } catch (Throwable e) {
            e.printStackTrace();
            throw new RuntimeException("请指定正确的时间格式!");
        }
    }

    private List<String> safeBetWeen(String val) {
        List<String> values = this.splitByComma(val);
        if (values.size() != 2)
            throw new RuntimeException("检查值【" + val + "】,条件列运算符为之间时,值只能传入两个参数!");
        return values;
    }

    private List<String> splitByComma(String val) {
        if (null == val) throw new RuntimeException("查询条件不要传入空值!");
        String[] tarArr = val.trim().split(StrPool.COMMA);
        return Arrays.asList(tarArr);
    }

    public static MethodHandles.Lookup lookup() {
        return MethodHandles.lookup();
    }

    /**
     * 替换拼接后wrapper预编译形式SQL中的’?‘为真正的参数
     *
     * @return 完整的SQL片段
     */
    public String getSqlSeg() {
        // 预编译SQL
        String targetSql = wrapper.getTargetSql();
        StringBuilder sqlSeg = new StringBuilder(StringPool.AND.length() + targetSql.length() + 16);
        sqlSeg.append(StringPool.AND);
        // 参数转换,转换后键为参数位置,值为参数值
        Map<String, Object> sqlParamMap = wrapper.getParamNameValuePairs();
        Object[] paramArr = sqlParamMap.entrySet().stream()
                .collect(Collectors.groupingBy(entry -> entry.getKey().length()))
                .entrySet().stream()
                .collect(Collectors.toMap(
                        Map.Entry::getKey,
                        entry -> entry.getValue().stream()
                                .sorted(Map.Entry.comparingByKey())
                                .map(Map.Entry::getValue)
                                .collect(Collectors.toList())))
                .entrySet().stream()
                .sorted(Map.Entry.comparingByKey())
                .flatMap(entry -> entry.getValue().stream())
                .toArray();
        char zw = '?';
        int paramIdx = 0;
        for (int i = 0; i < targetSql.length(); i++) {
            char c = targetSql.charAt(i);
            // 预编译形式的SQL中不会有问号出现,若遇到问号,则替换为加引号的参数
            if (c == zw) {
                sqlSeg.append(StringPool.SINGLE_QUOTE);
                sqlSeg.append(paramArr[paramIdx]);
                paramIdx++;
                sqlSeg.append(StringPool.SINGLE_QUOTE);
            } else {
                sqlSeg.append(c);
            }
        }
        return sqlSeg.toString();
    }

}

过程粗解

在方法走到mapper层时,会触发切面的内容,对入参进行处理,最终处理成编译好的sql语言。

1、通过 ms.getDeclaringTypeName() + StringPool.DOT + ms.getName()获取到方法的全限定名称,这一步主要是获取sql中表的别名,以便后续填充到条件中

2、concatSqlByMh:通过这个方法将入参的list对象转换成sql的查询语句。具体是:利用mybatis-plus的预编译机制,创建一个QueryWrapper来拼接查询条件,通过遍历list中的参数,来构建条件列以及条件

3、将构建好的东西转换为数据库查询语言。

4、xml文件中使用,直接获取sqlSegment这个参数就成

/******************************25-1-23******************************/

后面再补充完整

补充

补充1:

秒懂Java之方法句柄(MethodHandle)_java methodhandles-CSDN博客

句柄

作用与反射类似,可以在运行时访问类型信息,执行效率比反射更高。

相关概念:

Lookup:MethodHandle 的创建工厂,通过它可以创建MethodHandle,值得注意的是检查工作是在创建时处理的,而不是在调用时处理。

MethodType:顾名思义,就是代表方法的签名。一个方法的返回值类型是什么,有几个参数,每个参数的类型什么?

MethodHandle:方法句柄,通过它我们就可以动态访问类型信息了。

如何使用:

  1. 创建Lookup

  2. 创建MethodType

  3. 基于Lookup与MethodType获得MethodHandle

  4. 调用MethodHandle

再上面代码中则是通过操作句柄来调用对应的方法,从而实现将参数和条件进行拼接的操作。

最后将所有的参数和条件都进行拼接之后,再通过getsql这个自定义的方法,将预编译中的?替换成真实的sql,最后获取最终的sql


评论