Shiro 自定义算法登录认证

1.实现 Shiro SimpleCredentialsMatcher 的 doCredentialsMatch 方法
package cn.steven.manager.security;

import cn.sh.ideal.manager.util.AESUtils;
import org.apache.shiro.authc.AuthenticationInfo;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.authc.credential.SimpleCredentialsMatcher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Shiro credentials matcher that compares the AES-encrypted form of the
 * submitted password with the credential string held in the
 * {@link AuthenticationInfo}.
 *
 * <p>The match fails closed: a token of an unexpected type, a missing
 * password, a non-string stored credential, or an encryption failure all
 * result in {@code false} rather than an exception or an accidental match.
 *
 * @author Steven
 */
public class CustomCredentialsMatcher extends SimpleCredentialsMatcher{

    private Logger logger = LoggerFactory.getLogger(getClass());

    /**
     * Checks whether the submitted password matches the stored credential.
     *
     * @param authenticationToken token submitted by the client; expected to be a
     *                            {@code SystemUsernamePasswordToken}
     * @param info                account data whose credentials hold the stored
     *                            (encrypted) password string
     * @return {@code true} only when the encrypted submitted password equals the
     *         stored credential
     */
    @Override
    public boolean doCredentialsMatch(AuthenticationToken authenticationToken, AuthenticationInfo info) {
        // Reject unexpected token types instead of failing with a ClassCastException.
        if (!(authenticationToken instanceof SystemUsernamePasswordToken)) {
            logger.warn("Unsupported authentication token type, rejecting credentials");
            return false;
        }
        SystemUsernamePasswordToken token = (SystemUsernamePasswordToken) authenticationToken;

        // String.valueOf on a null password would throw, so treat it as a mismatch.
        if (info == null || token.getPassword() == null) {
            return false;
        }

        // Credential as stored for the account (already encrypted).
        Object storedCredentials = info.getCredentials();
        if (!(storedCredentials instanceof String)) {
            return false;
        }

        String submittedEncrypted = AESUtils.aesEncrypt(String.valueOf(token.getPassword()));
        // AESUtils.aesEncrypt returns "" when encryption fails; without this guard an
        // account whose stored credential is empty would match any password.
        if (submittedEncrypted == null || submittedEncrypted.isEmpty()) {
            logger.warn("Password encryption failed, rejecting credentials");
            return false;
        }
        return submittedEncrypted.equals(storedCredentials);
    }
}
2.自定义算法
package cn.sh.ideal.manager.util;

import org.apache.commons.codec.binary.Base64;
import sun.misc.BASE64Decoder;

import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.spec.SecretKeySpec;
import java.math.BigInteger;

/**
 * AES encryption and decryption helpers that exchange ciphertext as Base64 text.
 *
 * <p>Base64 handling uses {@code java.util.Base64} (Java 8+) for both directions,
 * so this class no longer needs the {@code sun.misc.BASE64Decoder} or
 * commons-codec imports. All byte/String conversions use UTF-8 explicitly so
 * that encryption and decryption agree regardless of the platform charset.
 *
 * <p>NOTE(review): the key is hard-coded and the cipher runs in ECB mode, both of
 * which are weak choices. They are kept because the key and algorithm must stay
 * identical to what the front end uses.
 *
 * @author libo
 */
public class AESUtils {
    /**
     * Secret key (must be identical on the front end and the back end).
     */
    private static final String KEY = "abcdefgabcdefg12";
    /**
     * Cipher transformation.
     */
    private static final String ALGORITHMSTR = "AES/ECB/PKCS5Padding";
    /**
     * Charset used for every byte/String conversion in this class.
     */
    private static final java.nio.charset.Charset UTF_8 = java.nio.charset.StandardCharsets.UTF_8;

    /**
     * Decrypts Base64 ciphertext with the built-in key.
     *
     * @param encrypt Base64 ciphertext
     * @return the plaintext; {@code null} when {@code encrypt} is null or empty;
     *         an empty string when decryption fails
     */
    public static String aesDecrypt(String encrypt) {
        try {
            return aesDecrypt(encrypt, KEY);
        } catch (Exception e) {
            // Existing contract: report the failure and signal it with "".
            e.printStackTrace();
            return "";
        }
    }

    /**
     * Encrypts text with the built-in key.
     *
     * @param content plaintext to encrypt
     * @return Base64 ciphertext, or an empty string when encryption fails
     *         (callers must treat "" as a failure, never as a valid ciphertext)
     */
    public static String aesEncrypt(String content) {
        try {
            return aesEncrypt(content, KEY);
        } catch (Exception e) {
            // Existing contract: report the failure and signal it with "".
            e.printStackTrace();
            return "";
        }
    }

    /**
     * Renders a byte array as an unsigned number in the given radix.
     *
     * @param bytes bytes to convert
     * @param radix target radix; values outside
     *              {@code Character.MIN_RADIX..Character.MAX_RADIX} fall back to 10
     * @return the textual representation
     */
    public static String binary(byte[] bytes, int radix) {
        // Signum 1 makes BigInteger treat the bytes as a non-negative magnitude.
        return new BigInteger(1, bytes).toString(radix);
    }

    /**
     * Base64-encodes a byte array.
     *
     * @param bytes bytes to encode
     * @return the Base64 text (single line), or {@code null} when {@code bytes} is null
     */
    public static String base64Encode(byte[] bytes) {
        return bytes == null ? null : java.util.Base64.getEncoder().encodeToString(bytes);
    }

    /**
     * Decodes Base64 text.
     *
     * @param base64Code Base64 text to decode
     * @return the decoded bytes, or {@code null} when the input is null or empty
     * @throws Exception when the input is not valid Base64
     */
    public static byte[] base64Decode(String base64Code) throws Exception {
        if (base64Code == null || base64Code.isEmpty()) {
            return null;
        }
        // The MIME decoder skips line breaks, keeping the lenient handling of
        // wrapped input that the previous decoder offered.
        return java.util.Base64.getMimeDecoder().decode(base64Code);
    }

    /**
     * Encrypts text with AES.
     *
     * @param content    plaintext to encrypt
     * @param encryptKey secret key text
     * @return the raw ciphertext bytes
     * @throws Exception when the cipher cannot be created or the key is invalid
     */
    public static byte[] aesEncryptToBytes(String content, String encryptKey) throws Exception {
        Cipher cipher = newCipher(Cipher.ENCRYPT_MODE, encryptKey);
        return cipher.doFinal(content.getBytes(UTF_8));
    }

    /**
     * Encrypts text with AES and returns the ciphertext as Base64.
     *
     * @param content    plaintext to encrypt
     * @param encryptKey secret key text
     * @return Base64 ciphertext
     * @throws Exception when encryption fails
     */
    public static String aesEncrypt(String content, String encryptKey) throws Exception {
        return base64Encode(aesEncryptToBytes(content, encryptKey));
    }

    /**
     * Decrypts raw AES ciphertext.
     *
     * @param encryptBytes ciphertext bytes
     * @param decryptKey   secret key text
     * @return the plaintext, decoded as UTF-8
     * @throws Exception when the key is invalid or the ciphertext is malformed
     */
    public static String aesDecryptByBytes(byte[] encryptBytes, String decryptKey) throws Exception {
        Cipher cipher = newCipher(Cipher.DECRYPT_MODE, decryptKey);
        byte[] plainBytes = cipher.doFinal(encryptBytes);
        // Decode with the same charset that aesEncryptToBytes used to encode.
        return new String(plainBytes, UTF_8);
    }

    /**
     * Decrypts Base64 AES ciphertext.
     *
     * @param encryptStr Base64 ciphertext
     * @param decryptKey secret key text
     * @return the plaintext, or {@code null} when {@code encryptStr} is null or empty
     * @throws Exception when decoding or decryption fails
     */
    public static String aesDecrypt(String encryptStr, String decryptKey) throws Exception {
        if (encryptStr == null || encryptStr.isEmpty()) {
            return null;
        }
        return aesDecryptByBytes(base64Decode(encryptStr), decryptKey);
    }

    /**
     * Builds a cipher for {@link #ALGORITHMSTR} initialised with the given key.
     */
    private static Cipher newCipher(int mode, String key) throws Exception {
        Cipher cipher = Cipher.getInstance(ALGORITHMSTR);
        cipher.init(mode, new SecretKeySpec(key.getBytes(UTF_8), "AES"));
        return cipher;
    }

    /**
     * Manual smoke test: decrypts a sample ciphertext and prints the result.
     */
    public static void main(String[] args) throws Exception {
        String encrypt = "SiDMeIRxC/HVG149ftRayg==";
        System.out.println("加密后:" + encrypt);
        String decrypt = aesDecrypt(encrypt, KEY);
        System.out.println("解密后:" + decrypt);
    }
}
3.设置自定义算法
package cn.sh.ideal.manager.security;

import java.io.Serializable;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.TimeUnit;

import javax.annotation.PostConstruct;

import cn.sh.ideal.manager.util.AESUtils;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.lang3.StringUtils;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.AuthenticationInfo;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.authc.SimpleAuthenticationInfo;
import org.apache.shiro.authc.credential.HashedCredentialsMatcher;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.realm.AuthorizingRealm;
import org.apache.shiro.session.Session;
import org.apache.shiro.subject.PrincipalCollection;
import org.apache.shiro.util.ByteSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import cn.sh.ideal.dao.RedisDao;
import cn.sh.ideal.manager.model.ResourceMenu;
import cn.sh.ideal.manager.model.Role;
import cn.sh.ideal.manager.model.User;
import cn.sh.ideal.manager.service.SystemService;
import cn.sh.ideal.manager.util.Constant;
import cn.sh.ideal.manager.util.Encodes;
import cn.sh.ideal.manager.util.UserUtil;

/**
 * Shiro realm that authenticates users against {@link SystemService} and loads
 * their permissions and roles.
 *
 * <p>Authentication also enforces a temporary lock: failed password attempts are
 * counted in Redis and the third consecutive failure locks the account for
 * {@link #LOCK_MINUTES} minutes.
 *
 * @author Genghc
 * @date 2015/7/8
 */
@Service("systemAuthorizingRealm")
public class SystemAuthorizingRealm extends AuthorizingRealm {
    private Logger logger = LoggerFactory.getLogger(getClass());
    /* Hash settings; not referenced by the credentials matcher installed below. */
    public static final String HASH_ALGORITHM = "SHA-1";
    public static final int HASH_INTERATIONS = 1024;

    /**
     * Redis key prefix of the per-user failed-password counter.
     * NOTE(review): this is a prefix of {@link #PWD_LOCK_KEY_PREFIX}, so the counter
     * key of a user whose name starts with "LOCK" can collide with another user's
     * lock key. Left unchanged because other code may read these keys.
     */
    private static final String PWD_ERROR_KEY_PREFIX = "PWD_ERROR_";
    /** Redis key prefix of the per-user lock flag. */
    private static final String PWD_LOCK_KEY_PREFIX = "PWD_ERROR_LOCK";
    /** Minutes a failed-password counter is kept. */
    private static final int FAILURE_WINDOW_MINUTES = 2;
    /** Minutes an account stays locked. */
    private static final int LOCK_MINUTES = 3;
    /** Message shown for a locked account; the duration must agree with {@link #LOCK_MINUTES}. */
    private static final String ACCOUNT_LOCKED_MESSAGE = "密码连续输入错误超过3次,帐号锁定3分钟,请和租户管理员联系";

    @Value("#{config['multiAccountLogin']}")
    private String multiAccountLogin;
    @Autowired
    private SystemService systemService;
    @Autowired
    private RedisDao redisDao;

    /**
     * Authentication callback, invoked on login.
     *
     * @param authenticationToken the submitted token; cast to {@code SystemUsernamePasswordToken}
     * @return the authentication info holding the principal and the stored password
     * @throws AuthenticationException when the account is locked, disabled or unknown,
     *                                 or when the password is wrong
     */
    @Override
    protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken authenticationToken) throws AuthenticationException {
        SystemUsernamePasswordToken token = (SystemUsernamePasswordToken) authenticationToken;
        // Number of currently active sessions, logged for diagnostics only.
        int activeSessionSize = systemService.getActiveSessions(false).size();
        logger.info("login submit, active session size: {}, username: {}", activeSessionSize, token.getUsername());

        String username = token.getUsername();

        // Refuse immediately while the account is locked.
        if (redisDao.readValue(PWD_LOCK_KEY_PREFIX + username) != null) {
            throw new AuthenticationException(ACCOUNT_LOCKED_MESSAGE);
        }

        // Look the account up by login name.
        User user = systemService.getUserByLoginName(username);
        if (user == null) {
            throw new AuthenticationException("帐号不存在,请和租户管理员联系");
        }
        if (Constant.NO.equals(user.getStatus())) {
            throw new AuthenticationException("msg:该已帐号禁止登录.");
        }

        String submittedEncrypted = AESUtils.aesEncrypt(String.valueOf(token.getPassword()));
        if (!submittedEncrypted.equals(user.getPassword())) {
            throw registerPasswordFailure(username);
        }

        // Correct password: clear any failure counter left by earlier attempts.
        String counterKey = PWD_ERROR_KEY_PREFIX + username;
        if (redisDao.exist(counterKey)) {
            redisDao.deleteValue(counterKey);
        }

        UserUtil.getSession().setAttribute("tenantCode", user.getTenantCode());
        UserUtil.getSession().setAttribute("username", user.getUserAccount());
        SimpleAuthenticationInfo authenticationInfo =
                new SimpleAuthenticationInfo(new Principal(user), user.getPassword(), null, getName());
        // Log the account only; the stored credential must never reach the log.
        logger.info("authentication info created for username: {}", username);
        return authenticationInfo;
    }

    /**
     * Records one failed password attempt and builds the exception to throw.
     * The first two failures within the counting window are counted; the third
     * one sets the lock flag and clears the counter.
     */
    private AuthenticationException registerPasswordFailure(String username) {
        String counterKey = PWD_ERROR_KEY_PREFIX + username;
        // Read once: the counter may expire between two separate reads.
        Object storedCount = redisDao.readValue(counterKey);
        if (storedCount == null) {
            redisDao.saveValue(counterKey, 1, FAILURE_WINDOW_MINUTES, TimeUnit.MINUTES);
            return new AuthenticationException("密码错误,还能输入2次");
        }

        Integer errorNum = (Integer) storedCount;
        if (errorNum.intValue() >= 2) {
            redisDao.saveValue(PWD_LOCK_KEY_PREFIX + username, "PWD_ERROR_LOCK", LOCK_MINUTES, TimeUnit.MINUTES);
            redisDao.deleteValue(counterKey);
            return new AuthenticationException(ACCOUNT_LOCKED_MESSAGE);
        }

        redisDao.saveValue(counterKey, errorNum.intValue() + 1, FAILURE_WINDOW_MINUTES, TimeUnit.MINUTES);
        return new AuthenticationException("密码错误,还能输入1次");
    }

    /**
     * Authorization callback, invoked when an authorization check finds no cached
     * authorization info for the user.
     *
     * @param principalCollection principals of the current subject
     * @return the user's permissions and roles, or {@code null} when the user no longer exists
     */
    @Override
    protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principalCollection) {
        Principal principal = (Principal) getAvailablePrincipal(principalCollection);
        // Enforce single login unless multi-account login is enabled.
        if (!Constant.TRUE.equals(multiAccountLogin)) {
            Collection<Session> sessions = systemService.getActiveSessions(true, principal, UserUtil.getSession());
            if (!sessions.isEmpty()) {
                if (UserUtil.getSubject().isAuthenticated()) {
                    // Came in through a real login: kick out the sessions already online.
                    for (Session session : sessions) {
                        systemService.deleteSession(session);
                    }
                } else {
                    // Came in through "remember me" while the account is online elsewhere:
                    // log this subject out and ask for a fresh login.
                    UserUtil.getSubject().logout();
                    throw new AuthenticationException("msg:账号已在其它地方登录,请重新登录。");
                }
            }
        }

        User user = systemService.getUserByLoginName(principal.getLoginName());
        if (user == null) {
            return null;
        }

        SimpleAuthorizationInfo info = new SimpleAuthorizationInfo();
        List<ResourceMenu> menuList = systemService.getMenuList(user.getId());
        logger.info("验证权限:{}", JSONObject.toJSONString(menuList));
        // Guard against a missing menu list, as is done for the role list below.
        if (menuList != null) {
            for (ResourceMenu menu : menuList) {
                if (StringUtils.isNotBlank(menu.getPermission())) {
                    // A menu may carry several comma-separated permission strings.
                    for (String permission : StringUtils.split(menu.getPermission(), ",")) {
                        info.addStringPermission(permission);
                    }
                }
            }
        }
        UserUtil.getSession().setAttribute("menuList", menuList);

        List<Role> roleList = systemService.getRoleList(user.getId());
        if (roleList != null) {
            for (Role role : roleList) {
                info.addRole(role.getRoleName());
            }
        }
        return info;
    }

    /**
     * Installs the credentials matcher and the accepted token class after the
     * bean has been constructed.
     */
    @PostConstruct
    public void initCredentialsMatcher() {
        // NOTE(review): every AuthenticationToken type is accepted here, while
        // doGetAuthenticationInfo casts to SystemUsernamePasswordToken - confirm
        // that no other token type can reach this realm.
        setAuthenticationTokenClass(AuthenticationToken.class);
        setCredentialsMatcher(new CustomCredentialsMatcher());
    }

    /**
     * Identity of the authenticated user, stored as the Shiro principal.
     */
    public static class Principal implements Serializable {

        private static final long serialVersionUID = 1L;

        private String id; // user id
        private String loginName; // login name
        private String name; // display name
        private String tenantCode;

        public Principal(User user) {
            this.id = user.getId();
            this.loginName = user.getUserAccount();
            this.name = user.getName();
            this.tenantCode = user.getTenantCode();
        }

        public String getId() {
            return id;
        }

        public String getLoginName() {
            return loginName;
        }

        public String getName() {
            return name;
        }

        public String getTenantCode() {
            return tenantCode;
        }

        /**
         * Returns the session id; currently always an empty string.
         */
        public String getSessionid() {
            return "";
        }

        @Override
        public String toString() {
            return id;
        }

    }
}