Netty客户端同步获取结果

上一篇中服务间的通信是异步的，现在希望客户端能够同步拿到服务端的响应结果。实现如下：
在 NettyClientHandler 类中增加一个结果缓存（按请求 id 保存服务端的响应）：

 // Response cache keyed by request id: channelRead0 puts responses in, sendMsg removes them.
 Map<Long, Protocol<ResponseMsg>> resultMap = new ConcurrentHashMap<Long, Protocol<ResponseMsg>>();

修改 channelRead0 和 sendMsg 方法：

/**
 * Handles a response from the server: logs it and stores it in {@code resultMap}
 * under its id, where the polling loop in {@code sendMsg} picks it up.
 */
@Override
protected void channelRead0(ChannelHandlerContext channelHandlerContext, Protocol<ResponseMsg> o) throws Exception {
    // Parameterized logging instead of string concatenation; the logged text is unchanged.
    logger.info("channelRead0--------------{}", Thread.currentThread().getName());
    logger.info("消费者接收到的消息为{}", JSONObject.toJSONString(o));
    resultMap.put(o.getId(), o);
}
/**
 * Sends a request and blocks until a response with the same id appears in
 * {@code resultMap}; that response is removed from the map and returned.
 *
 * <p>NOTE(review): there is no timeout, so this waits forever if no matching
 * response ever arrives.
 *
 * @param message the request to send; its id is used to match the response
 * @return the matching response
 */
public Protocol<ResponseMsg> sendMsg(Protocol<RequestMsg> message){
    channel.writeAndFlush(message);
    while (true) {
        Protocol<ResponseMsg> remove = resultMap.remove(message.getId());
        if (remove != null) {
            return remove;
        }
        // Give other runnable threads a chance between polls instead of spinning
        // flat out on one core while waiting for the I/O thread to deliver the response.
        Thread.yield();
    }
}

测试类

public class NettyTest {

    /**
     * Starts the Netty server on one thread. On a second thread, sends a request
     * every 2 seconds and prints each response as soon as it is returned.
     */
    public static void main(String[] args) {
        Thread serverThread = new Thread(() -> NettyServer.startNettyServer());
        serverThread.start();

        Thread clientThread = new Thread(() -> {
            NettyClient client = NettyClient.getInstance();
            try {
                for (;;) {
                    Thread.sleep(2000);
                    Protocol<ResponseMsg> response = client.sendMsg(newRequest());
                    System.out.println("同步获取到结果:" + response);
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        });
        clientThread.start();
    }

    /** Builds a request of message type 1 whose body carries a timestamped greeting. */
    private static Protocol<RequestMsg> newRequest() {
        RequestMsg body = new RequestMsg();
        body.setMsg("hello:" + System.currentTimeMillis());
        body.setOther("你好啊");
        Protocol<RequestMsg> request = new Protocol<>();
        request.setMsgType((short) 1);
        request.setBody(body);
        return request;
    }
}

（运行结果截图略）

以上实现了同步获取结果的功能，但 sendMsg 在 while 循环中不停地轮询（忙等），会持续占用 CPU，也不够优雅。下面改为自定义一个 Future 类：

/**
 * A future for a single request/response exchange. Callers block in {@link #get()}
 * until {@link #done(Protocol)} supplies the matching response.
 *
 * <p>NOTE(review): although this extends CompletableFuture, done() never calls
 * complete(), so inherited methods such as join() or thenApply() will not observe
 * completion. Only the overridden get/isDone methods are meaningful.
 */
public class MyFuture extends CompletableFuture<Object> {

    private static final Logger logger = LoggerFactory.getLogger(MyFuture.class);

    private final Sync sync;
    // Request protocol object this future was created for.
    private final Protocol<RequestMsg> requestMsgProtocol;
    // Response protocol object, written by done(). It is written before sync is released
    // and read after sync is acquired, so waiting threads see the value.
    private Protocol<ResponseMsg> responseMsgProtocol;
    // Creation time in milliseconds, used to log the response time in done().
    private final long startTime;

    /**
     * Creates a pending future for the given request.
     *
     * @param requestRpcProtocol the request being sent; its id is used in the timeout message
     */
    public MyFuture(Protocol<RequestMsg> requestRpcProtocol){
        this.sync = new Sync();
        this.requestMsgProtocol = requestRpcProtocol;
        this.startTime = System.currentTimeMillis();
    }

    /** Returns whether {@link #done(Protocol)} has been called. */
    @Override
    public boolean isDone() {
        return sync.isDone();
    }

    /**
     * Blocks until the response arrives and returns its body.
     *
     * <p>NOTE(review): this returns {@code getBody()} while the timed {@code get} returns
     * {@code getBody().getResult()}. The two variants disagree; confirm which one is intended.
     *
     * @return the response body, or {@code null} if done() was given a null response
     * @throws InterruptedException if the waiting thread is interrupted
     */
    @Override
    public Object get() throws InterruptedException, ExecutionException {
        // Interruptible wait, matching the declared InterruptedException and the timed variant.
        sync.acquireSharedInterruptibly(1);
        // Null-check the field that is actually dereferenced (the response, not the request).
        if (this.responseMsgProtocol != null) {
            return this.responseMsgProtocol.getBody();
        }
        return null;
    }

    /**
     * Waits up to the given time for the response and returns {@code getBody().getResult()}.
     *
     * @throws RuntimeException if the response does not arrive within the timeout
     * @throws InterruptedException if the waiting thread is interrupted
     */
    @Override
    public Object get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
        boolean success = sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
        if (!success) {
            // NOTE(review): the Future contract specifies TimeoutException here. RuntimeException
            // is kept because existing callers may depend on it.
            throw new RuntimeException("Timeout exception. Request id: " + this.requestMsgProtocol.getId());
        }
        if (this.responseMsgProtocol != null) {
            return responseMsgProtocol.getBody().getResult();
        }
        return null;
    }

    /** Cancellation is not supported. */
    @Override
    public boolean isCancelled() {
        throw new UnsupportedOperationException();
    }

    /** Cancellation is not supported. */
    @Override
    public boolean cancel(boolean mayInterruptIfRunning) {
        throw new UnsupportedOperationException();
    }

    /**
     * Completes this future with the server's response, wakes every thread waiting
     * in get(), and logs the elapsed time since construction.
     *
     * @param responseRpcProtocol the response matching this future's request
     */
    public void done(Protocol<ResponseMsg> responseRpcProtocol){
        // Publish the response before releasing, so that acquirers observe it.
        this.responseMsgProtocol = responseRpcProtocol;
        sync.releaseShared(1);
        long responseTime = System.currentTimeMillis() - startTime;
        logger.info("{},responseTime:{}",responseRpcProtocol.getId(),responseTime);
    }

    /**
     * One-shot latch on AQS: state PENDING (0) until released, then DONE (1) forever.
     * It uses shared mode (like CountDownLatch), so a single release wakes all
     * queued waiters instead of only the first one.
     */
    static class Sync extends AbstractQueuedSynchronizer{
        private static final long serialVersionUID = 1L;
        private static final int DONE = 1;
        private static final int PENDING = 0;

        @Override
        protected int tryAcquireShared(int arg) {
            return getState() == DONE ? 1 : -1;
        }

        @Override
        protected boolean tryReleaseShared(int arg) {
            // Only the first release flips the state; later calls are no-ops.
            return getState() == PENDING && compareAndSetState(PENDING, DONE);
        }

        public boolean isDone(){
            return getState() == DONE;
        }
    }
}

在 NettyClientHandler 类中增加一个变量，按请求 id 保存尚未收到响应的 MyFuture：

 // Pending futures keyed by request id: sendMsg registers them, channelRead0 removes and completes them.
 Map<Long, MyFuture> myFutureMap = new ConcurrentHashMap<Long, MyFuture>();

修改 channelRead0 方法：

/**
 * Handles a response from the server: logs it, then removes the MyFuture that
 * sendMsg registered under the same id from myFutureMap and completes it.
 * A response with no pending future is logged as a warning instead of being
 * dropped silently.
 */
@Override
protected void channelRead0(ChannelHandlerContext channelHandlerContext, Protocol<ResponseMsg> responseMsgProtocol) throws Exception {
    // Parameterized logging instead of string concatenation; the logged text is unchanged.
    logger.info("channelRead0--------------{}", Thread.currentThread().getName());
    logger.info("消费者接收到的消息为{}", JSONObject.toJSONString(responseMsgProtocol));
    MyFuture future = myFutureMap.remove(responseMsgProtocol.getId());
    if (future != null) {
        future.done(responseMsgProtocol);
    } else {
        // For example, an id that was never registered or was already completed.
        logger.warn("No pending request for response id {}", responseMsgProtocol.getId());
    }
}

修改 sendMsg 方法：

/**
 * Sends a request without blocking and returns a MyFuture. The future is
 * completed by channelRead0 when the response with the same id arrives.
 *
 * @param requestMsgProtocol the request to send; its id keys the pending future
 * @return a future whose get() blocks until the response arrives
 */
public MyFuture sendMsg(Protocol<RequestMsg> requestMsgProtocol){
    Long requestId = requestMsgProtocol.getId();
    MyFuture pending = new MyFuture(requestMsgProtocol);
    // Register before writing, so a fast response cannot arrive before its future is in the map.
    myFutureMap.put(requestId, pending);
    channel.writeAndFlush(requestMsgProtocol);
    return pending;
}

测试类修改如下：

public class NettyTest {

    /**
     * Starts the Netty server on one thread. On a second thread, sends a request
     * every 2 seconds and blocks on the returned MyFuture to print each result.
     */
    public static void main(String[] args) {
        Thread serverThread = new Thread(() -> NettyServer.startNettyServer());
        serverThread.start();

        Thread clientThread = new Thread(() -> {
            NettyClient client = NettyClient.getInstance();
            try {
                for (;;) {
                    Thread.sleep(2000);
                    MyFuture future = client.sendMsg(newRequest());
                    System.out.println("同步获取到结果:" + future.get());
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        });
        clientThread.start();
    }

    /** Builds a request of message type 1 whose body carries a timestamped greeting. */
    private static Protocol<RequestMsg> newRequest() {
        RequestMsg body = new RequestMsg();
        body.setMsg("hello:" + System.currentTimeMillis());
        body.setOther("你好啊");
        Protocol<RequestMsg> request = new Protocol<>();
        request.setMsgType((short) 1);
        request.setBody(body);
        return request;
    }
}

（运行结果截图略）