// DataPermissionInterceptor.java (14.2 KB)
package com.tianbo.analysis.intercept;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.tianbo.analysis.annotation.DataPermission;
import com.tianbo.analysis.thread.SessionUserContext;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.*;
import org.apache.commons.lang.StringUtils;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashSet;
import java.util.List;


/**
 * MyBatis plugin that rewrites the SQL of {@code Executor.query} calls so that
 * the current session user's data permissions are enforced.
 *
 * <p>Only statements whose mapper interface or mapper method carries
 * {@link DataPermission} are touched. For those statements the interceptor can
 * <ul>
 *   <li>replace the select list with the columns listed in the user's
 *       {@code cols_list} permission entries (column permission), and</li>
 *   <li>prepend an extra condition to the WHERE clause that depends on the
 *       user's {@code perm_type} (row permission).</li>
 * </ul>
 *
 * <p>Registration note carried over from the original author: when the backend
 * uses PageHelper, this interceptor should be added to the MyBatis
 * {@code Configuration} from the application's startup class, e.g.
 * <pre>
 * ConfigurableApplicationContext run = SpringApplication.run(Application.class, args);
 * Interceptor permissionInterceptor = (Interceptor) run.getBean("dataPermissionInterceptor");
 * run.getBean(SqlSessionFactory.class).getConfiguration().addInterceptor(permissionInterceptor);
 * </pre>
 * so that it is executed before PageHelper. Without pagination that step can be skipped.
 */
@Intercepts({
        @Signature(type = Executor.class,
                method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class,
                method = "query",
                args = {MappedStatement.class,
                        Object.class,
                        RowBounds.class,
                        ResultHandler.class,
                        CacheKey.class,
                        BoundSql.class
                })
        })
@Component
@Slf4j
public class DataPermissionInterceptor implements Interceptor {

    /** Name of the joined user table that the join-based scopes filter on. */
    private static final String USER_TABLE = "sec_user";

    /** Scope used when no permission is configured; it ends up in the deny-all branch. */
    private static final String SCOPE_UNCONFIGURED = "els";

    /*
     * NOTE(review): the two values below are placeholders left by the original author
     * ("omitted because the data is sensitive"). The intent described there is an
     * IN-list, wrapped in parentheses, of all users in the current user's department /
     * company. As written they do not form valid "in (...)" operands, so the "dep" and
     * "com" scopes cannot work until real values are supplied.
     */
    private static final String DEPTS_USER = "admin";
    private static final String ORGS_USER = "45,46,47";

    /**
     * Intercepts {@code Executor.query}, rewrites the SQL when the mapped statement is
     * subject to data permissions, and then lets the query proceed.
     *
     * @param invocation the intercepted call; its first argument is the {@link MappedStatement}
     * @return whatever the underlying query returns
     * @throws IllegalStateException if the permission filter cannot be applied; the query is
     *                               deliberately not executed unfiltered in that case
     * @throws Throwable             anything thrown by SQL parsing or by the underlying query
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Both intercepted signatures carry the MappedStatement as their first argument.
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];

        // invocation.getMethod() is Executor.query itself, never the mapper method, so the
        // annotation has to be looked up through the statement id instead.
        if (!isDataPermissionEnabled(mappedStatement)) {
            return invocation.proceed();
        }

        String originalSql = ExecutorPluginUtils.getSqlByInvocation(invocation);
        String sql2Reset = originalSql;

        if (ExecutorPluginUtils.isAreaTag(mappedStatement)) {
            // Parse only when a rewrite is actually going to happen.
            Statement statement = CCJSqlParserUtil.parse(originalSql);
            if (statement instanceof Select) {
                sql2Reset = applyDataPermission((Select) statement);
            }
        }

        // Put the (possibly rewritten) SQL back and continue with the query.
        ExecutorPluginUtils.resetSql2Invocation(invocation, sql2Reset);
        return invocation.proceed();
    }

    /**
     * Tells whether the mapper behind a mapped statement is marked with {@link DataPermission},
     * either on the mapper type or on a mapper method with the statement's name.
     *
     * <p>The statement id is split at its last dot into a class name and a method name
     * (for mapper interfaces MyBatis uses {@code <mapper FQCN>.<method>} as the id). If the
     * prefix is not a loadable class the statement is treated as not annotated.
     */
    private boolean isDataPermissionEnabled(MappedStatement mappedStatement) {
        String statementId = mappedStatement.getId();
        if (statementId == null) {
            return false;
        }
        int lastDot = statementId.lastIndexOf('.');
        if (lastDot <= 0) {
            return false;
        }

        Class<?> mapperClass;
        try {
            mapperClass = Class.forName(statementId.substring(0, lastDot));
        } catch (ClassNotFoundException e) {
            // Namespace is not a Java type (e.g. XML-only mapper): nothing can be annotated.
            return false;
        }
        if (mapperClass.isAnnotationPresent(DataPermission.class)) {
            return true;
        }

        String methodName = statementId.substring(lastDot + 1);
        for (Method candidate : mapperClass.getMethods()) {
            if (candidate.getName().equals(methodName)
                    && candidate.isAnnotationPresent(DataPermission.class)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Applies the session user's column and row permissions to a parsed SELECT.
     *
     * <p>Scenarios described by the original author:
     * <ol>
     *   <li>the user has no data permission configured: the permission list is empty;</li>
     *   <li>the user has data permissions with concrete column conditions;</li>
     *   <li>the user has data permissions identified by a keyword such as {@code *}.</li>
     * </ol>
     *
     * @param selectStatement the parsed statement; it is modified in place
     * @return the rewritten SQL text
     * @throws IllegalStateException if the statement shape is unsupported or the filter
     *                               cannot be built
     */
    private String applyDataPermission(Select selectStatement) {
        JSONObject user = SessionUserContext.getSessionUser();
        // No session user is handled like "nothing configured", which denies all rows below.
        JSONArray dataPermissions = user == null ? null : user.getJSONArray("dataPermissions");

        String dataScope = SCOPE_UNCONFIGURED;
        // Permission entries that contribute row conditions.
        List<JSONObject> rowConditions = new ArrayList<>();
        // A set, so that a column listed by several entries is selected only once.
        HashSet<String> colConditions = new HashSet<>();

        if (dataPermissions != null) {
            for (Object item : dataPermissions) {
                JSONObject dataPermission = (JSONObject) JSON.toJSON(item);
                // When one organisation binds several permissions to the same endpoint, the
                // scope of the LAST entry wins; the permission configuration has to be
                // consistent for this to be meaningful.
                dataScope = dataPermission.getString("perm_type");
                collectColumns(dataPermission.getString("cols_list"), colConditions);
                rowConditions.add(dataPermission);
            }
        }
        if (dataScope == null) {
            // switch on a null String would throw; a missing perm_type means "not configured".
            dataScope = SCOPE_UNCONFIGURED;
        }

        SelectBody selectBody = selectStatement.getSelectBody();
        if (!(selectBody instanceof PlainSelect)) {
            // UNION / set operations etc. cannot be filtered by this implementation; do not
            // let them run without the permission filter.
            throw new IllegalStateException(
                    "Data permission rewriting only supports plain SELECT statements");
        }
        PlainSelect plain = (PlainSelect) selectBody;

        try {
            String rowFilter = buildRowFilter(dataScope, plain.getJoins(), rowConditions, user);

            // Column permission: replace the select list.
            if (!colConditions.isEmpty()) {
                plain.setSelectItems(resetColumn(colConditions));
            }

            // Row permission: put the filter in front of the existing WHERE clause, keeping
            // the original conditions together (and in their order) inside parentheses.
            if (!rowFilter.isEmpty()) {
                StringBuilder whereSql = new StringBuilder(rowFilter);
                Expression existingWhere = plain.getWhere();
                if (existingWhere != null) {
                    whereSql.append(" and ( ").append(existingWhere).append(" )");
                }
                plain.setWhere(CCJSqlParserUtil.parseCondExpression(whereSql.toString()));
            }
            return selectStatement.toString();
        } catch (Exception e) {
            log.error("[SQl-Interceptor-ERR]-", e);
            // Fail closed: running the original SQL here would bypass the row filter.
            throw new IllegalStateException("Failed to apply the data permission filter", e);
        }
    }

    /**
     * Adds the column names of one {@code cols_list} value to {@code target}.
     *
     * <p>An empty value or {@code "*"} means "all columns" and leaves {@code target} untouched.
     * Otherwise the value is parsed as a JSON array of column names, and {@code CREATTIME} is
     * always added as well.
     */
    private void collectColumns(String colListStr, HashSet<String> target) {
        if (StringUtils.isEmpty(colListStr) || "*".equals(colListStr)) {
            return;
        }
        JSONArray columnNames = JSONArray.parseArray(colListStr);
        if (columnNames == null) {
            return;
        }
        for (Object columnName : columnNames) {
            if (columnName != null) {
                target.add(columnName.toString());
            }
        }
        // NOTE(review): spelling kept as in the original; presumably matches the DB column.
        target.add("CREATTIME");
    }

    /**
     * Builds the additional WHERE condition for a permission scope.
     *
     * <p>Scopes: {@code *} = all data, {@code usr} = the user's own data, {@code dep} =
     * department and sub-departments, {@code com} = company and subsidiaries; anything else
     * (including the unconfigured default) yields {@code 1=2}, i.e. no rows.
     *
     * @return the condition text, or an empty string when the scope adds no condition
     *         (e.g. a join-based scope on a query that does not join {@code sec_user})
     */
    private String buildRowFilter(String dataScope, List<Join> joins,
                                  List<JSONObject> rowConditions, JSONObject user) {
        StringBuilder whereSql = new StringBuilder();
        boolean hasJoins = joins != null && !joins.isEmpty();

        switch (dataScope) {
            case "*":
                whereSql.append("1=1");
                break;
            case "usr":
                if (hasJoins) {
                    // NOTE(review): sample values from the original author ("45", DEPTS_USER),
                    // not real user data.
                    appendUserTableConditions(whereSql, joins, ".USERNAME = " + 45, "id = " + DEPTS_USER);
                } else {
                    for (JSONObject dataPermission : rowConditions) {
                        appendAnd(whereSql);
                        whereSql
                                // condition column
                                .append(dataPermission.get("row_condition"))
                                .append(" = ")
                                // condition value, read from the session user
                                .append(getSqlValue(user.get(dataPermission.getString("row_condition_property"))));
                    }
                }
                break;
            case "dep":
                appendUserTableConditions(whereSql, joins, ".id in " + DEPTS_USER, "id in " + DEPTS_USER);
                break;
            case "com":
                // The un-aliased variant used the department list in the original; both
                // variants of the company scope now use the company list.
                appendUserTableConditions(whereSql, joins, ".id in " + ORGS_USER, "id in " + ORGS_USER);
                break;
            default:
                whereSql.append("1=2");
                break;
        }
        return whereSql.toString();
    }

    /**
     * Appends one condition per join whose right side is the {@code sec_user} table.
     *
     * @param aliasedSuffix      text appended after the table alias when the join has one
     * @param unaliasedCondition complete condition used when the join has no alias
     */
    private void appendUserTableConditions(StringBuilder whereSql, List<Join> joins,
                                           String aliasedSuffix, String unaliasedCondition) {
        if (joins == null) {
            return;
        }
        for (Join join : joins) {
            FromItem rightItem = join.getRightItem();
            // Sub-selects and other non-table join targets cannot match the user table.
            if (!(rightItem instanceof Table)) {
                continue;
            }
            Table table = (Table) rightItem;
            if (!USER_TABLE.equals(table.getName())) {
                continue;
            }
            // Several matching joins must be connected, otherwise the text is not valid SQL.
            appendAnd(whereSql);
            if (table.getAlias() != null) {
                whereSql.append(table.getAlias().getName()).append(aliasedSuffix);
            } else {
                whereSql.append(unaliasedCondition);
            }
        }
    }

    /** Appends {@code " and "} unless the buffer is still empty. */
    private void appendAnd(StringBuilder whereSql) {
        if (whereSql.length() > 0) {
            whereSql.append(" and ");
        }
    }

    /**
     * Renders a value as a SQL literal.
     *
     * <p>{@code null} becomes {@code NULL} (so an {@code =} comparison matches no row),
     * numbers and booleans are rendered as-is, dates as {@code 'yyyy-MM-dd'}, and everything
     * else as a quoted string with embedded single quotes doubled.
     *
     * <p>NOTE(review): backslashes are not escaped; confirm whether the target database
     * treats backslash as an escape character inside string literals.
     */
    private String getSqlValue(Object value) {
        if (value == null) {
            return "NULL";
        }
        if (value instanceof Number || value instanceof Boolean) {
            return value.toString();
        }
        if (value instanceof Date) {
            // SimpleDateFormat is not thread-safe, hence a fresh instance per call.
            SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd");
            return "'" + sdf.format((Date) value) + "'";
        }
        return "'" + value.toString().replace("'", "''") + "'";
    }

    /**
     * Builds a new select list from column names.
     *
     * @param newSelectItems column names to select
     * @return one select item per column name
     */
    private List<SelectItem> resetColumn(HashSet<String> newSelectItems) {
        List<SelectItem> newSelectExpressionItems = new ArrayList<>();
        for (String columnName : newSelectItems) {
            newSelectExpressionItems.add(new SelectExpressionItem(new Column(columnName)));
        }
        return newSelectExpressionItems;
    }
}