The Best Way to Authenticate WebSockets in NestJS

WebSockets, at least at this point, are not first-class citizens in NestJS. A few things don't work as well as they do over HTTP.

This month, I had to build a real-time chat with WebSockets. This was the first time we were introducing a new transport layer in our backend. The first thing that made sense to build was authentication.

To kick things off, I went through the official documentation. It was pretty short, and it said that authentication works no differently from how it works over HTTP.

However, once I had implemented it, the behaviour didn't make sense. For starters, the Guard didn't validate the initial connection request, so anybody could connect and hold on to the connection. NestJS only invoked the Guard when the client sent an event. And even then, NestJS did not disconnect the Socket when authentication failed.

Here's why that's problematic:

  • Any unauthenticated user could connect to our server and unnecessarily drain our resources
  • Since the Guard never disconnected the Socket connection, a broadcast could potentially end up with an unauthenticated user (unless handled)
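To make the second point concrete, here is a small self-contained model of the behaviour described above. It is only a sketch, not NestJS or Socket.IO code: `GuardOnlyServer`, `Client`, and `VALID_TOKEN` are hypothetical names, and the token comparison stands in for real JWT verification.

```typescript
// Simplified model (not NestJS internals): a guard that runs per event
// but never closes the connection.
type Client = { id: string; token?: string; inbox: string[] };

const VALID_TOKEN = "valid-token"; // hypothetical stand-in for JWT verification

class GuardOnlyServer {
  private readonly clients: Client[] = [];

  // No check at connection time: everyone gets in.
  connect(client: Client): void {
    this.clients.push(client);
  }

  // The "guard" runs only when a client sends an event.
  handleEvent(client: Client, message: string): boolean {
    if (client.token !== VALID_TOKEN) {
      return false; // event rejected, but the client stays connected
    }
    this.broadcast(message);
    return true;
  }

  // Every connected client receives the message, authenticated or not.
  broadcast(message: string): void {
    for (const c of this.clients) {
      c.inbox.push(message);
    }
  }
}

const guardOnly = new GuardOnlyServer();
const alice: Client = { id: "alice", token: VALID_TOKEN, inbox: [] };
const mallory: Client = { id: "mallory", inbox: [] };
guardOnly.connect(alice);
guardOnly.connect(mallory);

// Mallory's own event is blocked by the guard...
const malloryAllowed = guardOnly.handleEvent(mallory, "spam");
// ...but she still receives what Alice broadcasts.
guardOnly.handleEvent(alice, "hello");
console.log(malloryAllowed, mallory.inbox);
```

In this model the unauthenticated client can never send anything, yet it keeps its slot on the server and stays on the receiving end of every broadcast.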

After some Google searches, I was happy to see that this was a common problem. I stumbled upon a related, long-standing issue on GitHub where people had the same concern.

After going through all the comments, I found two options for preventing the connection from being established at all when authentication fails:

  • Write a custom IO Adapter
  • Write a middleware

Both work. I went with a middleware, which rejects any invalid connection during the handshake, before it is ever established.
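The idea behind the middleware approach can be sketched without any framework. The model below is a simplification, not Socket.IO's actual implementation: `runHandshake`, `Handshake`, and `requireToken` are hypothetical names, and the string comparison stands in for JWT verification. Each middleware receives the handshake and a `next` callback; calling `next` with an error stops the chain and the connection is refused.

```typescript
// Simplified model of a handshake middleware chain (not Socket.IO's source).
type Handshake = { auth: { token?: string } };
type Next = (err?: Error) => void;
type Middleware = (handshake: Handshake, next: Next) => void;

// Runs the middlewares in order. `done` receives null when every middleware
// passed, or the first error that any of them handed to next().
function runHandshake(
  middlewares: Middleware[],
  handshake: Handshake,
  done: (err: Error | null) => void,
): void {
  const step = (index: number): void => {
    const middleware = middlewares[index];
    if (!middleware) {
      done(null); // end of the chain: the connection is accepted
      return;
    }
    middleware(handshake, (err) => (err ? done(err) : step(index + 1)));
  };
  step(0);
}

// Hypothetical check standing in for JWT verification.
const requireToken: Middleware = (handshake, next) => {
  if (handshake.auth.token === "valid-token") {
    next();
  } else {
    next(new Error("Unauthorized"));
  }
};

// Collect the outcome of two connection attempts.
const outcomes: (Error | null)[] = [];
runHandshake([requireToken], { auth: { token: "valid-token" } }, (err) => {
  outcomes.push(err);
});
runHandshake([requireToken], { auth: {} }, (err) => {
  outcomes.push(err);
});
console.log(outcomes.map((o) => (o ? o.message : "connected")));
```

The first attempt reaches the end of the chain and is accepted; the second is refused before any connection exists, which is the property the Guard-based approach lacked.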

Code

Middleware

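import { ConfigService } from "@nestjs/config";
import { JwtService } from "@nestjs/jwt";
import { Socket } from "socket.io";
// JwtStrategy, UserService and JwtTokenPayload are application-specific;
// import them from wherever they live in your project.

type SocketMiddleware = (
  socket: Socket,
  next: (err?: Error) => void,
) => void;

export const AuthWsMiddleware = (
  jwtService: JwtService,
  configService: ConfigService,
  userService: UserService,
): SocketMiddleware => {
  return async (socket: Socket, next) => {
    try {
      // The client sends the JWT in the `auth` payload of the handshake.
      const token = socket.handshake?.auth?.token;

      if (!token) {
        throw new Error("Authorization token is missing");
      }

      let payload: JwtTokenPayload;

      try {
        payload = await jwtService.verifyAsync<JwtTokenPayload>(token);
      } catch (error) {
        throw new Error("Authorization token is invalid");
      }

      // Reuse the HTTP JWT strategy so both transports validate users the same way.
      const strategy = new JwtStrategy(configService, userService);
      const user = await strategy.validate(payload);

      if (!user) {
        throw new Error("User does not exist");
      }

      // Attach the user so that event handlers can read it from the socket.
      Object.assign(socket, { user });
      next();
    } catch (error) {
      // Passing an error to next() refuses the connection. The specific
      // reason is deliberately not sent to the client.
      next(new Error("Unauthorized"));
    }
  };
};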

Gateway

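import { ConfigService } from "@nestjs/config";
import { JwtService } from "@nestjs/jwt";
import { OnGatewayInit, WebSocketGateway } from "@nestjs/websockets";
import { Server } from "socket.io";
// AuthWsMiddleware and UserService come from your own modules.

@WebSocketGateway()
export class Gateway implements OnGatewayInit {
  constructor(
    private readonly configService: ConfigService,
    private readonly jwtService: JwtService,
    private readonly userService: UserService,
  ) {}

  // afterInit receives the Socket.IO server (not a client socket), so the
  // middleware is registered once and runs for every incoming connection.
  afterInit(server: Server) {
    server.use(
      AuthWsMiddleware(
        this.jwtService,
        this.configService,
        this.userService,
      ),
    );
  }
}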
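Because the middleware attaches the user with `Object.assign`, the plain `Socket` type knows nothing about the `user` property. One way to keep handlers type-safe is a small intersection type plus a type guard. The sketch below is self-contained and makes assumptions: `SocketLike` stands in for `Socket` from "socket.io", `User` is a hypothetical shape for the application's user, and `AuthenticatedSocket`, `attachUser`, and `isAuthenticated` are illustrative names that do not come from NestJS or Socket.IO.

```typescript
// Minimal stand-ins so the sketch runs on its own. In a real app these would
// be Socket from "socket.io" and the application's own user entity.
interface SocketLike {
  id: string;
}
interface User {
  id: string;
  name: string;
}

// A socket that has been through the auth middleware.
type AuthenticatedSocket<S extends SocketLike = SocketLike> = S & { user: User };

// Mirrors what the middleware does: mutate the socket and return it with a richer type.
function attachUser<S extends SocketLike>(
  socket: S,
  user: User,
): AuthenticatedSocket<S> {
  return Object.assign(socket, { user });
}

// Type guard for code paths that receive a plain socket.
function isAuthenticated<S extends SocketLike>(
  socket: S,
): socket is AuthenticatedSocket<S> {
  return "user" in socket;
}

const rawSocket: SocketLike = { id: "abc" };
const authenticatedBefore = isAuthenticated(rawSocket);
const authedSocket = attachUser(rawSocket, { id: "1", name: "Ada" });
const authenticatedAfter = isAuthenticated(rawSocket);
console.log(authenticatedBefore, authenticatedAfter, authedSocket.user.name);
```

In a gateway, an event handler could then declare its socket parameter as `AuthenticatedSocket<Socket>` and read `socket.user` without casting.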

Resources