DataPermissionInterceptor.java 11.4 KB
package com.tianbo.analysis.intercept;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.tianbo.analysis.thread.SessionUserContext;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectItem;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.stereotype.Component;

import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.List;

@Intercepts({
        @Signature(type= Executor.class,
                method = "query",
                args = {MappedStatement.class,Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type= Executor.class,
                method = "query",
                args = {MappedStatement.class,
                        Object.class,
                        RowBounds.class,
                        ResultHandler.class,
                        CacheKey.class,
                        BoundSql.class
                })
        })
@Component
@Slf4j
public class DataPermissionInterceptor implements Interceptor {

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        //从invocation搞到sql
        String processSql = ExecutorPluginUtils.getSqlByInvocation(invocation);
        //复制原sql,用于产生新sql
        String sql2Reset = processSql;
        Statement statement  = CCJSqlParserUtil.parse(processSql);
        //获取mybatis绑定SQL配置信息也就是XML中的具体执行相关
        MappedStatement mappedStatement =  (MappedStatement)invocation.getArgs()[0];

        /** 得到spring上下文,
         如果后端未用分页,则这步可以省略
         在项目启动类下完成该配置
        ConfigurableApplicationContext run = SpringApplication.run(Application.class, args);
        Interceptor permissionInterceptor = (Interceptor) run.getBean("dataPermissionInterceptor");
        这种方式添加mybatis拦截器保证在pageHelper前执行
        run.getBean(SqlSessionFactory.class).getConfiguration().addInterceptor(permissionInterceptor);
         **/

        if (ExecutorPluginUtils.isAreaTag(mappedStatement)) {
            /**
             * 用户数据权限范围判定,即行数据权限判定
             * 根据判定结果进行dataScope的值set
             */
            // 存储用户行权限的LIST
            /**
             *
             *  判定
             *  场景一: 用户未配置数据权限,rowPermList 为空
             *  场景二: 用户配置了数据权限,rowPermList 有值,包含 对应数据字段列 条件信息
             *  场景三: 用户配置了数据权限,rowPermList 没值, 权限判定具有关键字标识的 如: * / ALL等
             *  todo: 条件值怎么获取?
             */
            String dataScope = "els";
            JSONArray rowPermList = SessionUserContext.getSessionUser();
            //取出来后 清理
            SessionUserContext.clearSessionUser();
            if (rowPermList!=null && !rowPermList.isEmpty()) {
                // 用户配置了数据权限
                //获取该用户所具有的角色的数据权限dataScope
                dataScope = "usr";
            }else {
                //用户未配置数据权限
                dataScope = "*";
            }

            String deptsUser= "admin";
            //因数据敏感省略
            //获取该用户的所在公司或部门下的所有人 in 条件 包含()封装
            //例如 StringBuffer orgBuffer = new StringBuffer();
            // orgBuffer.append("(");
            //String collect = allUserByOrgs.stream().map(String::valueOf).collect(Collectors.joining(","));
            //orgBuffer.append(collect).append(")");
            //String orgsUser = orgBuffer.toString();
            String orgsUser = "45,46,47";
            try {
                if (statement instanceof Select) {
                    Select selectStatement = (Select) statement;
                    //其中的PlainSelect 可以拿到sql语句的全部节点信息,具体各位可以看源码
                    PlainSelect plain = (PlainSelect) selectStatement.getSelectBody();
                    //获取所有外连接
                    List<Join> joins = plain.getJoins();
                    //获取到原始sql语句
                    String sql = processSql;
                    StringBuffer whereSql = new StringBuffer();
                    switch (dataScope) {
                        /**
                         * 这里dataScope  范围 *: 所有数据权限  ,usr: 本人  ,dep(department):部门及分部门(递归)  ,com(company):公司及分公司(递归)
                         * els : 未配置
                         * 所有数据权限作用在人上,因此sql用 in
                         */
                        case "*":
                            whereSql.append("1=1");
                            break;
                        case "usr":
                            if(joins==null || joins.isEmpty()){
                                String and = " and ";
                                for (int i = 0; i < rowPermList.size(); i++) {
                                    JSONObject o = (JSONObject)rowPermList.get(i);
                                    if (i==0){
                                        whereSql
                                                //条件字段
                                                .append(o.get("colName"))
                                                .append(" = ")
                                                // 条件值
                                                .append(getSqlValue(o.get("colValue")));
                                    }else {
                                        whereSql.append(and)
                                                //条件字段
                                                .append(o.get("colName"))
                                                .append(" = ")
                                                // 条件值
                                                .append(getSqlValue(o.get("colValue")));
                                    }

                                }
                            }else{
                                for (Join join : joins) {
                                    Table rightItem = (Table) join.getRightItem();
                                    //匹配表名
                                    if(rightItem.getName().equals("sec_user")){
                                        //获取别名
                                        if(rightItem.getAlias()!=null){
                                            //适配用户ID 样例
                                            whereSql.append(rightItem.getAlias().getName()).append(".USERNAME = ").append(45);
                                        }else {
                                            whereSql.append("id = ").append(deptsUser);
                                        }

                                    }
                                }
                            }
                            break;
                        case "dep":
                            for (Join join : joins) {
                                Table rightItem = (Table) join.getRightItem();
                                if(rightItem.getName().equals("sec_user")){
                                    if(rightItem.getAlias()!=null){
                                        whereSql.append(rightItem.getAlias().getName()).append(".id in ").append(deptsUser);
                                    }else {
                                        whereSql.append("id in ").append(deptsUser);
                                    }

                                }
                            }
                            break;
                        case "com":
                            for (Join join : joins) {
                                Table rightItem = (Table) join.getRightItem();
                                if(rightItem.getName().equals("sec_user")){
                                    if(rightItem.getAlias()!=null){
                                        whereSql.append(rightItem.getAlias().getName()).append(".id in ").append(orgsUser);
                                    }else {
                                        whereSql.append("id in ").append(deptsUser);
                                    }

                                }
                            }
                            break;
                        default:
                            whereSql.append("1=2");
                            break;
                    }

                    /**
                     * 获取where节点
                     * 行访问权限相关
                     * 重新拼接where 语句 需要注意不要破坏where 查询条件的顺序,为了配合索引提高查询效率
                     */
                    Expression where = plain.getWhere();
                    if (where == null) {
                        if (whereSql.length() > 0) {
                            Expression expression = CCJSqlParserUtil
                                    .parseCondExpression(whereSql.toString());
                            Expression whereExpression = (Expression) expression;
                            plain.setWhere(whereExpression);
                        }
                    } else {
                        if (whereSql.length() > 0) {
                            //where条件之前存在,需要重新进行拼接
                            whereSql.append(" and ( " + where.toString() + " )");
                        } else {
                            //新增片段不存在,使用之前的sql
                            whereSql.append(where.toString());
                        }
                        Expression expression = CCJSqlParserUtil
                                .parseCondExpression(whereSql.toString());
                        plain.setWhere(expression);
                    }
                    sql2Reset = selectStatement.toString();
                }

            } catch (Exception e) {
                log.error("[SQl-Interceptor-ERR]-",e);
                e.printStackTrace();
            }
        }
        // 替换sql
        ExecutorPluginUtils.resetSql2Invocation(invocation, sql2Reset);
        //放行
        Object proceed = invocation.proceed();
        return proceed;
    }

    private String getSqlValue(Object value) {
        if (value instanceof String) {
            return "'" + value + "'";
        }else if (value instanceof Date) {
            SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd");
            return "'" + sdf.format((Date) value) + "'";
        } else {
            return value.toString();
        }
    }
}