基于spring-boot 2.x扩展WebSocket,支持细粒度控制

项目地址

项目地址,欢迎提出宝贵意见。

spring security文档

spring security文档片段
Spring Security只支持在应用启动时加载WebSocket权限信息,修改权限后必须重启应用才能生效,无法按照用户的权限动态授权。

Spring提供的HandshakeInterceptor接口可以自定义拦截器,但是只能在握手时进行一次拦截,无法细粒度控制权限。

StompSubProtocolHandler源码

package org.springframework.web.socket.messaging;

/**
 * STOMP sub-protocol handler (excerpt: only the two message-handling methods are shown;
 * fields such as {@code decoders}, {@code stats}, {@code stompAuthentications} and
 * {@code eventPublisher}, and helpers such as {@code getUser}, {@code handleError} and
 * {@code sendToClient}, are declared outside this excerpt).
 *
 * <p>{@link #handleMessageFromClient} processes messages arriving from a WebSocket client and
 * {@link #handleMessageToClient} processes messages going back out to the client.
 */
public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationEventPublisherAware {
	/**
	 * Method that receives messages from the client.
	 * Handle incoming WebSocket messages from clients.
	 *
	 * <p>Decodes the WebSocket payload into STOMP messages, adds session information to the
	 * headers of each one, sends it to {@code outputChannel} and, when the send succeeded,
	 * publishes the matching session event. Failures are logged and passed to
	 * {@code handleError}.
	 *
	 * @param session the WebSocket session the message was received on
	 * @param webSocketMessage the raw WebSocket message; only {@code TextMessage} and
	 * {@code BinaryMessage} are processed
	 * @param outputChannel the channel each decoded STOMP message is sent to
	 */
	public void handleMessageFromClient(WebSocketSession session,
			WebSocketMessage<?> webSocketMessage, MessageChannel outputChannel) {

		List<Message<byte[]>> messages;
		try {
			ByteBuffer byteBuffer;
			if (webSocketMessage instanceof TextMessage) {
				byteBuffer = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes());
			}
			else if (webSocketMessage instanceof BinaryMessage) {
				byteBuffer = ((BinaryMessage) webSocketMessage).getPayload();
			}
			else {
				// Any other WebSocket message type is silently ignored.
				return;
			}

			// Decoders are looked up per session id; a missing decoder is treated as an error.
			BufferingStompDecoder decoder = this.decoders.get(session.getId());
			if (decoder == null) {
				throw new IllegalStateException("No decoder for session id '" + session.getId() + "'");
			}

			messages = decoder.decode(byteBuffer);
			if (messages.isEmpty()) {
				// Nothing decoded yet: the STOMP frame is incomplete, so there is nothing to forward.
				if (logger.isTraceEnabled()) {
					logger.trace("Incomplete STOMP frame content received in session " +
							session + ", bufferSize=" + decoder.getBufferSize() +
							", bufferSizeLimit=" + decoder.getBufferSizeLimit() + ".");
				}
				return;
			}
		}
		catch (Throwable ex) {
			if (logger.isErrorEnabled()) {
				logger.error("Failed to parse " + webSocketMessage +
						" in session " + session.getId() + ". Sending STOMP ERROR to client.", ex);
			}
			// No decoded message exists at this point, hence the null third argument.
			handleError(session, ex, null);
			return;
		}

		// Each decoded message is handled in its own try/catch, so a failure on one message
		// does not stop the remaining messages from being processed.
		for (Message<byte[]> message : messages) {
			try {
				StompHeaderAccessor headerAccessor =
						MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
				Assert.state(headerAccessor != null, "No StompHeaderAccessor");

				// Copy session id, session attributes, user and heartbeat settings into the headers.
				headerAccessor.setSessionId(session.getId());
				headerAccessor.setSessionAttributes(session.getAttributes());
				headerAccessor.setUser(getUser(session));
				headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
				// Headers are made immutable unless detectImmutableMessageInterceptor(outputChannel)
				// returns true.
				if (!detectImmutableMessageInterceptor(outputChannel)) {
					headerAccessor.setImmutable();
				}

				if (logger.isTraceEnabled()) {
					logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
				}

				// Both the CONNECT and the STOMP command count as a connect request.
				StompCommand command = headerAccessor.getCommand();
				boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
				if (isConnect) {
					this.stats.incrementConnectCount();
				}
				else if (StompCommand.DISCONNECT.equals(command)) {
					this.stats.incrementDisconnectCount();
				}

				try {
					// Context attributes are set only for the duration of the send and the event
					// publication; they are reset in the finally block.
					SimpAttributesContextHolder.setAttributesFromMessage(message);
					boolean sent = outputChannel.send(message);

					// Authentication bookkeeping and events happen only if the send succeeded.
					if (sent) {
						if (isConnect) {
							// Remember a user found on the headers after the send when it is not
							// the session's own principal instance.
							Principal user = headerAccessor.getUser();
							if (user != null && user != session.getPrincipal()) {
								this.stompAuthentications.put(session.getId(), user);
							}
						}
						if (this.eventPublisher != null) {
							Principal user = getUser(session);
							if (isConnect) {
								publishEvent(this.eventPublisher, new SessionConnectEvent(this, message, user));
							}
							else if (StompCommand.SUBSCRIBE.equals(command)) {
								publishEvent(this.eventPublisher, new SessionSubscribeEvent(this, message, user));
							}
							else if (StompCommand.UNSUBSCRIBE.equals(command)) {
								publishEvent(this.eventPublisher, new SessionUnsubscribeEvent(this, message, user));
							}
						}
					}
				}
				finally {
					SimpAttributesContextHolder.resetAttributes();
				}
			}
			catch (Throwable ex) {
				if (logger.isErrorEnabled()) {
					logger.error("Failed to send client message to application via MessageChannel" +
							" in session " + session.getId() + ". Sending STOMP ERROR to client.", ex);
				}
				handleError(session, ex, message);
			}
		}
	}
	/**
	 * Method that sends messages to the client.
	 * Handle STOMP messages going back out to WebSocket clients.
	 *
	 * <p>Messages whose payload is not a {@code byte[]} are logged and dropped. MESSAGE,
	 * CONNECTED and ERROR frames receive command-specific processing before the frame is
	 * passed to {@code sendToClient}.
	 *
	 * @param session the WebSocket session of the receiving client
	 * @param message the outgoing message; its payload is expected to be a {@code byte[]}
	 */
	@Override
	@SuppressWarnings("unchecked")
	public void handleMessageToClient(WebSocketSession session, Message<?> message) {
		if (!(message.getPayload() instanceof byte[])) {
			if (logger.isErrorEnabled()) {
				logger.error("Expected byte[] payload. Ignoring " + message + ".");
			}
			return;
		}

		StompHeaderAccessor accessor = getStompHeaderAccessor(message);
		StompCommand command = accessor.getCommand();

		if (StompCommand.MESSAGE.equals(command)) {
			// A missing subscription id is only warned about; the message is still sent.
			if (accessor.getSubscriptionId() == null && logger.isWarnEnabled()) {
				logger.warn("No STOMP \"subscription\" header in " + message);
			}
			// When an original-destination native header is present, it replaces the
			// destination and the header itself is removed.
			String origDestination = accessor.getFirstNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION);
			if (origDestination != null) {
				accessor = toMutableAccessor(accessor, message);
				accessor.removeNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION);
				accessor.setDestination(origDestination);
			}
		}
		else if (StompCommand.CONNECTED.equals(command)) {
			this.stats.incrementConnectedCount();
			accessor = afterStompSessionConnected(message, accessor, session);
			if (this.eventPublisher != null) {
				try {
					// Context attributes are set only while the connected event is published.
					SimpAttributes simpAttributes = new SimpAttributes(session.getId(), session.getAttributes());
					SimpAttributesContextHolder.setAttributes(simpAttributes);
					Principal user = getUser(session);
					publishEvent(this.eventPublisher, new SessionConnectedEvent(this, (Message<byte[]>) message, user));
				}
				finally {
					SimpAttributesContextHolder.resetAttributes();
				}
			}
		}

		byte[] payload = (byte[]) message.getPayload();
		if (StompCommand.ERROR.equals(command) && getErrorHandler() != null) {
			// The error handler may supply a replacement message; if it returns null, the
			// original accessor and payload are sent unchanged.
			Message<byte[]> errorMessage = getErrorHandler().handleErrorMessageToClient((Message<byte[]>) message);
			if (errorMessage != null) {
				accessor = MessageHeaderAccessor.getAccessor(errorMessage, StompHeaderAccessor.class);
				Assert.state(accessor != null, "No StompHeaderAccessor");
				payload = errorMessage.getPayload();
			}
		}
		sendToClient(session, accessor, payload);
	}
}

阅读源码可以发现spring原生的消息处理类不支持自定义拦截器

interceptable-websocket

做这个项目是为了细粒度的动态控制WebSocket的权限,项目对StompSubProtocolHandler类和其他相关的类做了扩展,增加了对自定义拦截器的支持

具体实现在extension包,拦截器实现在interceptor包

使用方法

直接使用

项目已经发布到maven中央仓库,直接在pom.xml中引用即可

<dependencies>
    <!-- interceptable-websocket library, version 1.0 -->
    <dependency>
        <groupId>com.xzixi</groupId>
        <artifactId>interceptable-websocket</artifactId>
        <version>1.0</version>
    </dependency>
</dependencies>

修改后使用

  1. 下载项目

    打开git bash窗口,执行命令git clone git@gitee.com:xuelingkang/websocket.git
  2. 编译并安装到本地maven仓库

    进入工程目录,打开cmd窗口,执行命令mvn clean install
  3. 在自己的项目中引用
<dependencies>
    <!-- interceptable-websocket library, version 1.0 -->
    <dependency>
        <groupId>com.xzixi</groupId>
        <artifactId>interceptable-websocket</artifactId>
        <version>1.0</version>
    </dependency>
</dependencies>

配置类:

package com.xzixi.websocket.interceptablewebsocketdemo.config;

/**
 * Example configuration class that registers custom interceptors for messages received
 * from the client and messages sent to the client.
 */
@Configuration
@EnableInterceptableWebSocketMessageBroker // add this annotation
public class InterceptableSecurityWebSocketConfig extends AbstractInterceptableSecurityWebSocketMessageBrokerConfigurer {
    /**
     * Registers two from-client interceptors and one to-client interceptor on the registry.
     *
     * @param registry the endpoint registry the interceptors are added to
     */
    @Override
    public void registerStompEndpoints(InterceptableWebMvcStompEndpointRegistry registry) {
        // Register the interceptors.
        // NOTE(review): the original comments labelled the un-registry interceptor as "record"
        // and the registry interceptor as "remove"; the labels below follow the method names
        // instead - confirm against the interceptor implementations.
        registry.addFromClientInterceptor(accessDecisionFromClientInterceptor()) // message authorization decision
                .addFromClientInterceptor(sessionIdUnRegistryInterceptor()) // presumably removes the sessionId
                .addToClientInterceptor(sessionIdRegistryInterceptor()); // presumably records the sessionId
    }
}

具体使用方法请参考案例工程:interceptable-websocket-demo