手写一个RPC框架,理解更透彻(附源码)

 一、手写前言

前段时间看到一篇不错的个R更透文章《看了这篇你就会手写RPC框架了》,于是框架便来了兴趣对着实现了一遍,后面觉得还有很多优化的理解地方便对其进行了改进。

主要改动点如下:

 除了Java序列化协议,彻附增加了protobuf和kryo序列化协议,源码配置即用。手写  增加多种负载均衡算法(随机、个R更透轮询、框架加权轮询、理解平滑加权轮询),彻附配置即用。源码  客户端增加本地服务列表缓存,手写提高性能。个R更透  修复高并发情况下,框架netty导致的内存泄漏问题  由原来的每个请求建立一次连接,改为建立TCP长连接,并多次复用。  服务端增加线程池提高消息处理能力

二、介绍

RPC,即 Remote Procedure Call(远程过程调用),调用远程计算机上的服务,就像调用本地服务一样。RPC可以很好的解耦系统,如WebService就是一种基于Http协议的RPC。亿华云

总的来说,就如下几个步骤:

 客户端(ServerA)执行远程方法时就调用client stub传递类名、方法名和参数等信息。  client stub会将参数等信息序列化为二进制流的形式,然后通过Sockect发送给服务端(ServerB)  服务端收到数据包后,server stub 需要进行解析反序列化为类名、方法名和参数等信息。  server stub调用对应的本地方法,并把执行结果返回给客户端

所以一个RPC框架有如下角色:

 服务消费者远程方法的调用方,即客户端。一个服务既可以是消费者也可以是提供者。  服务提供者远程服务的提供方,即服务端。一个服务既可以是消费者也可以是提供者。  注册中心保存服务提供者的服务地址等信息,一般由zookeeper、redis等实现。  监控运维(可选)监控接口的响应时间、统计请求数量等,及时发现系统问题并发出告警通知。

三、实现

本RPC框架rpc-spring-boot-starter涉及技术栈如下:

 使用zookeeper作为注册中心  使用netty作为通信框架  消息编解码:protostuff、亿华云计算kryo、java  spring  使用SPI来根据配置动态选择负载均衡算法等

由于代码过多,这里只讲几处改动点。

3.1动态负载均衡算法

1.编写LoadBalance的实现类

2.自定义注解 @LoadBalanceAno 

/**   * 负载均衡注解   */  @Target(ElementType.TYPE)  @Retention(RetentionPolicy.RUNTIME)  @Documented  public @interface LoadBalanceAno {       String value() default "";  }  /**   * 轮询算法   */  @LoadBalanceAno(RpcConstant.BALANCE_ROUND)  public class FullRoundBalance implements LoadBalance {       private static Logger logger = LoggerFactory.getLogger(FullRoundBalance.class);      private volatile int index;      @Override      public synchronized Service chooseOne(List<Service> services) {           // 加锁防止多线程情况下,index超出services.size()          if (index == services.size()) {               index = 0;         }          return services.get(index++);      }  } 

3.新建在resource目录下META-INF/servers文件夹并创建文件

4.RpcConfig增加配置项loadBalance 

/**   * @author 2YSP   * @date 2020/7/26 15:13   */  @ConfigurationProperties(prefix = "sp.rpc")  public class RpcConfig {       /**       * 服务注册中心地址       */      private String registerAddress = "127.0.0.1:2181";      /**       * 服务暴露端口       */      private Integer serverPort = 9999;      /**       * 服务协议       */      private String protocol = "java";      /**       * 负载均衡算法       */      private String loadBalance = "random";      /**       * 权重,默认为1       */      private Integer weight = 1;     // 省略getter setter  } 

5.在自动配置类RpcAutoConfiguration根据配置选择对应的算法实现类 

/**       * 使用spi匹配符合配置的负载均衡算法      *       * @param name       * @return       */      private LoadBalance getLoadBalance(String name) {           ServiceLoader<LoadBalance> loader = ServiceLoader.load(LoadBalance.class);          Iterator<LoadBalance> iterator = loader.iterator();          while (iterator.hasNext()) {               LoadBalance loadBalance = iterator.next();              LoadBalanceAno ano = loadBalance.getClass().getAnnotation(LoadBalanceAno.class);              Assert.notNull(ano, "load balance name can not be empty!");             if (name.equals(ano.value())) {                   return loadBalance;              }          }          throw new RpcException("invalid load balance config");      }   @Bean      public ClientProxyFactory proxyFactory(@Autowired RpcConfig rpcConfig) {           ClientProxyFactory clientProxyFactory = new ClientProxyFactory();          // 设置服务发现着          clientProxyFactory.setServerDiscovery(new           ZookeeperServerDiscovery(rpcConfig.getRegisterAddress()));          // 设置支持的协议          Map<String, MessageProtocol> supportMessageProtocols = buildSupportMessageProtocols();          clientProxyFactory.setSupportMessageProtocols(supportMessageProtocols);          // 设置负载均衡算法          LoadBalance loadBalance = getLoadBalance(rpcConfig.getLoadBalance());          clientProxyFactory.setLoadBalance(loadBalance);          // 设置网络层实现          clientProxyFactory.setNetClient(new NettyNetClient());          return clientProxyFactory;      } 

3.2本地服务列表缓存

使用Map来缓存数据 

/**   * 服务发现本地缓存   */  public class ServerDiscoveryCache {       /**       * key: serviceName       */      private static final Map<String, List<Service>> SERVER_MAP = new ConcurrentHashMap<>();      /**       * 客户端注入的远程服务service class       */      public static final List<String> SERVICE_CLASS_NAMES = new ArrayList<>();      public static void put(String serviceName, List<Service> serviceList) {           SERVER_MAP.put(serviceName, serviceList);      }      /**       * 去除指定的值       * @param serviceName       * @param service       */      public static void remove(String serviceName, Service service) {           SERVER_MAP.computeIfPresent(serviceName, (key, value) ->                  value.stream().filter(o -> !o.toString().equals(service.toString())).collect(Collectors.toList())          );      }      public static void removeAll(String serviceName) {           SERVER_MAP.remove(serviceName);      }      public static boolean isEmpty(String serviceName) {           return SERVER_MAP.get(serviceName) == null || SERVER_MAP.get(serviceName).size() == 0;      }      public static List<Service> get(String serviceName) {           return SERVER_MAP.get(serviceName);      }  } 

ClientProxyFactory,先查本地缓存,缓存没有再查询zookeeper。 

/**       * 根据服务名获取可用的服务地址列表       * @param serviceName       * @return       */      private List<Service> getServiceList(String serviceName) {           List<Service> services;          synchronized (serviceName){               if (ServerDiscoveryCache.isEmpty(serviceName)) {                   services = serverDiscovery.findServiceList(serviceName);                  if (services == null || services.size() == 0) {                       throw new RpcException("No provider available!");                  }                  ServerDiscoveryCache.put(serviceName, services);              } else {                   services = ServerDiscoveryCache.get(serviceName);              }          }          return services;      } 

问题:如果服务端因为宕机或网络问题下线了,缓存却还在就会导致客户端请求已经不可用的服务端,增加请求失败率。解决方案:由于服务端注册的是临时节点,所以如果服务端下线节点会被移除。只要监听zookeeper的子节点,如果新增或删除子节点就直接清空本地缓存即可。

推荐:100道Java中高级面试题汇总+详细拆解 

DefaultRpcProcessor  /**   * Rpc处理者,支持服务启动暴露,自动注入Service   * @author 2YSP   * @date 2020/7/26 14:46   */  public class DefaultRpcProcessor implements ApplicationListener<ContextRefreshedEvent> {      @Override      public void onApplicationEvent(ContextRefreshedEvent event) {           // Spring启动完毕过后会收到一个事件通知          if (Objects.isNull(event.getApplicationContext().getParent())){               ApplicationContext context = event.getApplicationContext();              // 开启服务              startServer(context);              // 注入Service              injectService(context);          }      }      private void injectService(ApplicationContext context) {           String[] names = context.getBeanDefinitionNames();          for(String name : names){               Class<?> clazz = context.getType(name);              if (Objects.isNull(clazz)){                   continue;              }              Field[] declaredFields = clazz.getDeclaredFields();              for(Field field : declaredFields){                   // 找出标记了InjectService注解的属性                  InjectService injectService = field.getAnnotation(InjectService.class);                  if (injectService == null){                       continue;                  }                    Class<?> fieldfieldClass = field.getType();                  Object object = context.getBean(name);                  field.setAccessible(true);                  try {                       field.set(object,clientProxyFactory.getProxy(fieldClass));                  } catch (IllegalAccessException e) {                       e.printStackTrace();                  }      // 添加本地服务缓存                  ServerDiscoveryCache.SERVICE_CLASS_NAMES.add(fieldClass.getName());              }          }          // 注册子节点监听          if (clientProxyFactory.getServerDiscovery() instanceof ZookeeperServerDiscovery){               ZookeeperServerDiscovery serverDiscovery = (ZookeeperServerDiscovery) clientProxyFactory.getServerDiscovery();              ZkClient zkClient = serverDiscovery.getZkClient();              ServerDiscoveryCache.SERVICE_CLASS_NAMES.forEach(name ->{                   String servicePath = RpcConstant.ZK_SERVICE_PATH + RpcConstant.PATH_DELIMITER + name + "/service";                  zkClient.subscribeChildChanges(servicePath, new ZkChildListenerImpl());              });              logger.info("subscribe service zk node successfully");          }      }      private void startServer(ApplicationContext context) {          ...      }  } 

ZkChildListenerImpl 

/**   * 子节点事件监听处理类   */  public class ZkChildListenerImpl implements IZkChildListener {       private static Logger logger = LoggerFactory.getLogger(ZkChildListenerImpl.class);      /**       * 监听子节点的删除和新增事件       * @param parentPath /rpc/serviceName/service       * @param childList       * @throws Exception       */      @Override      public void handleChildChange(String parentPath, List<String> childList) throws Exception {           logger.debug("Child change parentPath:[{ }] -- childList:[{ }]", parentPath, childList);          // 只要子节点有改动就清空缓存         String[] arr = parentPath.split("/");          ServerDiscoveryCache.removeAll(arr[2]);      }  } 

3.3nettyClient支持TCP长连接

这部分的改动最多,先增加新的源码库sendRequest接口。

添加接口

实现类NettyNetClient 

/**   * @author 2YSP   * @date 2020/7/25 20:12   */  public class NettyNetClient implements NetClient {      private static Logger logger = LoggerFactory.getLogger(NettyNetClient.class);      private static ExecutorService threadPool = new ThreadPoolExecutor(4, 10, 200,              TimeUnit.SECONDS, new LinkedBlockingQueue<>(1000), new ThreadFactoryBuilder()              .setNameFormat("rpcClient-%d")              .build());      private EventLoopGroup loopGroup = new NioEventLoopGroup(4);      /**       * 已连接的服务缓存       * key: 服务地址,格式:ip:port       */      public static Map<String, SendHandlerV2> connectedServerNodes = new ConcurrentHashMap<>();      @Override      public byte[] sendRequest(byte[] data, Service service) throws InterruptedException {     ....          return respData;      }      @Override      public RpcResponse sendRequest(RpcRequest rpcRequest, Service service, MessageProtocol messageProtocol) {          String address = service.getAddress();          synchronized (address) {               if (connectedServerNodes.containsKey(address)) {                   SendHandlerV2 handler = connectedServerNodes.get(address);                  logger.info("使用现有的连接");                  return handler.sendRequest(rpcRequest);              }              String[] addrInfo = address.split(":");              final String serverAddress = addrInfo[0];              final String serverPort = addrInfo[1];              final SendHandlerV2 handler = new SendHandlerV2(messageProtocol, address);              threadPool.submit(() -> {                           // 配置客户端                          Bootstrap b = new Bootstrap();                          b.group(loopGroup).channel(NioSocketChannel.class)                                  .option(ChannelOption.TCP_NODELAY, true)                                  .handler(new ChannelInitializer<SocketChannel>() {                                       @Override                                      protected void initChannel(SocketChannel socketChannel) throws Exception {                                           ChannelPipeline pipeline = socketChannel.pipeline();                                          pipeline                                                  .addLast(handler);                                      }                                  });                          // 启用客户端连接                          ChannelFuture channelFuture = b.connect(serverAddress, Integer.parseInt(serverPort));                          channelFuture.addListener(new ChannelFutureListener() {                               @Override                              public void operationComplete(ChannelFuture channelFuture) throws Exception {                                   connectedServerNodes.put(address, handler);                              }                          });                      }              );              logger.info("使用新的连接。。。");              return handler.sendRequest(rpcRequest);          }      }  } 

每次请求都会调用sendRequest()方法,用线程池异步和服务端创建TCP长连接,连接成功后将SendHandlerV2缓存到ConcurrentHashMap中方便复用,后续请求的请求地址(ip+port)如果在connectedServerNodes中存在则使用connectedServerNodes中的handler处理不再重新建立连接。

SendHandlerV2 

/**   * @author 2YSP   * @date 2020/8/19 20:06   */  public class SendHandlerV2 extends ChannelInboundHandlerAdapter {      private static Logger logger = LoggerFactory.getLogger(SendHandlerV2.class);      /**       * 等待通道建立最大时间       */      static final int CHANNEL_WAIT_TIME = 4;      /**       * 等待响应最大时间       */      static final int RESPONSE_WAIT_TIME = 8;      private volatile Channel channel;      private String remoteAddress;      private static Map<String, RpcFuture<RpcResponse>> requestMap = new ConcurrentHashMap<>();      private MessageProtocol messageProtocol;     private CountDownLatch latch = new CountDownLatch(1);      public SendHandlerV2(MessageProtocol messageProtocol,String remoteAddress) {           this.messageProtocol = messageProtocol;          this.remoteAddress = remoteAddress;      }      @Override      public void channelRegistered(ChannelHandlerContext ctx) throws Exception {           this.channel = ctx.channel();          latch.countDown();      }      @Override      public void channelActive(ChannelHandlerContext ctx) throws Exception {           logger.debug("Connect to server successfully:{ }", ctx);      }      @Override      public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {           logger.debug("Client reads message:{ }", msg);          ByteBuf byteBuf = (ByteBuf) msg;          byte[] resp = new byte[byteBuf.readableBytes()];          byteBuf.readBytes(resp);          // 手动回收          ReferenceCountUtil.release(byteBuf);          RpcResponse response = messageProtocol.unmarshallingResponse(resp);         RpcFuture<RpcResponse> future = requestMap.get(response.getRequestId());          future.setResponse(response);      }      @Override      public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {           cause.printStackTrace();          logger.error("Exception occurred:{ }", cause.getMessage());          ctx.close();      }      @Override      public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {           ctx.flush();      }      @Override      public void channelInactive(ChannelHandlerContext ctx) throws Exception {           super.channelInactive(ctx);          logger.error("channel inactive with remoteAddress:[{ }]",remoteAddress);          NettyNetClient.connectedServerNodes.remove(remoteAddress);      }      @Override      public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {           super.userEventTriggered(ctx, evt);      }      public RpcResponse sendRequest(RpcRequest request) {           RpcResponse response;          RpcFuture<RpcResponse> future = new RpcFuture<>();          requestMap.put(request.getRequestId(), future);          try {               byte[] data = messageProtocol.marshallingRequest(request);              ByteBuf reqBuf = Unpooled.buffer(data.length);              reqBuf.writeBytes(data);              if (latch.await(CHANNEL_WAIT_TIME,TimeUnit.SECONDS)){                   channel.writeAndFlush(reqBuf);                  // 等待响应                  response = future.get(RESPONSE_WAIT_TIME, TimeUnit.SECONDS);              }else {                   throw new RpcException("establish channel time out");              }          } catch (Exception e) {               throw new RpcException(e.getMessage());          } finally {               requestMap.remove(request.getRequestId());          }          return response;      }  } 

RpcFuture 

package cn.sp.rpc.client.net;  import java.util.concurrent.*;  /**   * @author 2YSP   * @date 2020/8/19 22:31   */  public class RpcFuture<T> implements Future<T> {       private T response;      /**       * 因为请求和响应是一一对应的,所以这里是1       */      private CountDownLatch countDownLatch = new CountDownLatch(1);      /**       * Future的请求时间,用于计算Future是否超时       */      private long beginTime = System.currentTimeMillis();      @Override      public boolean cancel(boolean mayInterruptIfRunning) {           return false;      }     @Override      public boolean isCancelled() {           return false;      }      @Override      public boolean isDone() {           if (response != null) {               return true;          }          return false;      }      /**       * 获取响应,直到有结果才返回       * @return       * @throws InterruptedException       * @throws ExecutionException      */      @Override      public T get() throws InterruptedException, ExecutionException {           countDownLatch.await();          return response;      }      @Override      public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {           if (countDownLatch.await(timeout,unit)){               return response;          }          return null;      }      public void setResponse(T response) {           this.response = response;          countDownLatch.countDown();      }      public long getBeginTime() {           return beginTime;      }  } 

此处逻辑,第一次执行 SendHandlerV2#sendRequest() 时channel需要等待通道建立好之后才能发送请求,所以用CountDownLatch来控制,等待通道建立。

自定义Future+requestMap缓存来实现netty的请求和阻塞等待响应,RpcRequest对象在创建时会生成一个请求的唯一标识requestId,发送请求前先将RpcFuture缓存到requestMap中,key为requestId,读取到服务端的响应信息后(channelRead方法),将响应结果放入对应的RpcFuture中。

SendHandlerV2#channelInactive() 方法中,如果连接的服务端异常断开连接了,则及时清理缓存中对应的serverNode。

四、压力测试

测试环境:

 (英特尔)Intel(R) Core(TM) i5-6300HQ CPU @ 2.30GHz 4核  windows10家庭版(64位)  16G内存

1.本地启动zookeeper

2.本地启动一个消费者,两个服务端,轮询算法

3.使用ab进行压力测试,4个线程发送10000个请求

ab -c 4 -n 10000 http://localhost:8080/test/user?id=1 

测试结果:

从图片可以看出,10000个请求只用了11s,比之前的130+秒耗时减少了10倍以上。

代码地址:

 https://github.com/2YSP/rpc-spring-boot-starter   https://github.com/2YSP/rpc-example

参考

 https://www.cnblogs.com/itoak/p/13370031.html 
滇ICP备2023000592号-31