How to handle logs and tracing in Spring WebFlux and microservices

Customizing and setting up log tracing in WebFlux for one of our projects was a bumpy journey. WebFlux is a really cool way to build non-blocking servers out of the box in Java. It is built on Reactive Streams and supports servers like Netty. The main idea is to build the whole flow, from the controller through the service, data and other layers, as one stream or chain. I hope you already have the idea and are reading this because you are trying to work with the logs :3

First of all, I started with the request body. For each API request, we need to log the request body to keep track of what was actually sent from the client side. Here is the easiest way to do this -

@PostMapping(value = "api/create/user", consumes = MediaType.APPLICATION_JSON_UTF8_VALUE)
@ResponseStatus(value = HttpStatus.OK)
public Mono<Boolean> createUser(@Valid @RequestBody UserInfo userInfo) {
    log.info("userInfo request body : {}", userInfo);
    return userInfoService.createUserInfo(userInfo);
}

This is easy, right? We are using @Slf4j in the controller, so we can just log.info anything we want. And surprisingly this works with WebFlux, even though the log call is not part of the chain! The concept here is that logging like this does not meaningfully block the thread: we are just writing a formatted line to stdout or to a file. But if there are 10 endpoints or more, we have to repeat this line in every controller method. It would be better to move the request body log to a common place.

Here comes the concept of Aspect Oriented Programming (AOP). Basically, it is a way of adding behavior to existing code without modifying that code. For a detailed introduction to AOP, there are articles on AOP pointcuts and advice. So we can use AOP annotations like @Around or @Before to log the request body. Then we decided we should also log the response body, and in that case the @Around annotation is the answer. I made a custom annotation, Loggable, to log what is passed into the controller (the request body) and what it returns (the response body).


@Aspect
@Slf4j
public class LoggerAspect {

    @Around("@annotation(Loggable)")
    public Object logAround(ProceedingJoinPoint joinPoint) throws Throwable {

        long start = System.currentTimeMillis();
        var result = joinPoint.proceed();
        if (result instanceof Mono) {
            var monoResult = (Mono) result;
            AtomicReference<String> traceId = new AtomicReference<>("");

            return monoResult
                    .doOnSuccess(o -> {
                        var response = "";
                        if (Objects.nonNull(o)) {
                            response = o.toString();
                        }
                        log.info("Enter: {}.{}() with argument[s] = {}",
                                joinPoint.getSignature().getDeclaringTypeName(), joinPoint.getSignature().getName(),
                                joinPoint.getArgs());
                        log.info("Exit: {}.{}() had arguments = {}, with result = {}, Execution time = {} ms",
                                joinPoint.getSignature().getDeclaringTypeName(), joinPoint.getSignature().getName(),
                                joinPoint.getArgs()[0],
                                response, (System.currentTimeMillis() - start));
                    });
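        }
        // Anything that is not a Mono (e.g. a Flux or a plain object) is returned untouched;
        // without this return the method would not compile.
        return result;
    }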
}

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;  
import java.lang.annotation.RetentionPolicy;  
import java.lang.annotation.Target;  

/**  
 * @author Azizul Haque Ananto  
 */  

@Retention(RetentionPolicy.RUNTIME)  
@Target(ElementType.METHOD)  
public @interface Loggable {  
}

The code is pretty straightforward: Loggable is the annotation, and LoggerAspect holds the AOP pointcut. Using the @Loggable annotation on a controller method, we get two logs, Enter and Exit. joinPoint.proceed() gives us the return value of the function (in our case the function is the controller method, so we basically get the response body), and it will be a Mono. So we need to hook into the Mono to see the response body. doOnSuccess is our choice: it runs as a side effect when the Mono completes successfully, and the successful response is what we want here. For errors we have a separate error handler. We will come to that later. And for the request body, we can simply use joinPoint.getArgs(). This looks good, right?
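To get a feel for what "adding behavior without modifying the code" means, here is a minimal plain-JDK sketch that uses a dynamic proxy instead of Spring AOP. It is only an analogy, not the aspect above: the UserService interface, the withLogging helper and the LOGS list are made up for illustration, and the Loggable annotation is redeclared so the sketch compiles on its own.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical copy of the article's annotation, redeclared to keep the sketch self-contained.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface Loggable {
}

// Hypothetical service; a JDK proxy can only wrap methods declared on an interface.
interface UserService {
    @Loggable
    boolean createUser(String name);

    String ping(); // not annotated, so the proxy leaves it unlogged
}

class LoggingProxyDemo {
    // Log lines are collected in a list so the behavior is easy to inspect.
    static final List<String> LOGS = new ArrayList<>();

    // Wraps target in a proxy that records Enter/Exit lines for @Loggable methods only.
    @SuppressWarnings("unchecked")
    static <T> T withLogging(T target, Class<T> type) {
        InvocationHandler handler = (proxy, method, args) -> {
            boolean loggable = method.isAnnotationPresent(Loggable.class);
            if (loggable) {
                LOGS.add("Enter: " + method.getName() + "() with argument[s] = " + Arrays.toString(args));
            }
            Object result;
            try {
                result = method.invoke(target, args); // call the real, unmodified method
            } catch (InvocationTargetException e) {
                throw e.getCause(); // rethrow the original exception, not the reflection wrapper
            }
            if (loggable) {
                LOGS.add("Exit: " + method.getName() + "() with result = " + result);
            }
            return result;
        };
        return (T) Proxy.newProxyInstance(type.getClassLoader(), new Class<?>[]{type}, handler);
    }

    public static void main(String[] args) {
        UserService plain = new UserService() {
            public boolean createUser(String name) { return !name.isEmpty(); }
            public String ping() { return "pong"; }
        };
        UserService logged = withLogging(plain, UserService.class);
        logged.createUser("ananto");
        logged.ping();
        LOGS.forEach(System.out::println);
    }
}
```

The service implementation never mentions logging; the wrapper decides what to log by looking at the annotation, which is the same division of labor the aspect gives us in the controller.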

Here comes the fun part. I was using Sleuth to generate the trace id and log4j for logging (the log4j MDC is a thread-local map that holds the trace id). When a request comes from another microservice, we can track it using the trace id. In Spring Boot, if one service/module calls another using a Feign client, we can see an X-B3-TraceId in the header, which tracks the call as it travels through a single request-response cycle. In the case of WebFlux, we need to handle this ourselves. We did not notice it until we faced a trace id mismatch: module A was calling module B with the trace id 123, but in module B we couldn't find any trace id 123!

We were using the reactive WebClient for inter-service/module REST calls, and there was no magic like with the Feign client. What I needed to do was pass the trace id when calling other modules, and also read the trace id from the header when receiving a REST call.

public abstract class AbstractWebClient {
    private static final String MIME_TYPE = "application/json";
    private final WebClient webClient;

    public AbstractWebClient(String clientUrl) {

        this.webClient = WebClient.builder()
                .baseUrl(clientUrl)
                .defaultHeader(HttpHeaders.CONTENT_TYPE, MIME_TYPE)
                .build();
    }

    public <T> Mono<T> get(String uri, Class<T> tClass) {

        return webClient.get()
                .uri(uri)
                .header("X-B3-TRACEID", MDC.get("X-B3-TraceId"))
                .retrieve()
                .onStatus(HttpStatus::is4xxClientError, this::get4xxError)
                .onStatus(HttpStatus::is5xxServerError, this::get5xxError)
                .bodyToMono(tClass);
    }

    public <T> Mono<T> post(String uri, Object body, Class<T> tClass) {

        return webClient.post()
                .uri(uri)
                .header("X-B3-TRACEID", MDC.get("X-B3-TraceId"))
                .syncBody(body)
                .retrieve()
                .onStatus(HttpStatus::is4xxClientError, this::get4xxError)
                .onStatus(HttpStatus::is5xxServerError, this::get5xxError)
                .bodyToMono(tClass);
    }

    private <T> Mono<T> get4xxError(ClientResponse clientResponse) {

        return Mono.error(new Client4xxException("Client Error"));
    }

    private <T> Mono<T> get5xxError(ClientResponse clientResponse) {

        return Mono.error(new Client5xxException("Server Error"));
    }

}

This is the AbstractWebClient; notice the X-B3-TRACEID in the header (HTTP header names are case-insensitive, so X-B3-TRACEID and X-B3-TraceId name the same header). We get the trace id from the log4j MDC, which holds the current logging context of the thread, generated by Sleuth. Now we need to catch the trace id.

@Component
@Slf4j
public class TraceIdFilter implements WebFilter {

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        Map<String, String> headers = exchange.getRequest().getHeaders().toSingleValueMap();

        var traceId = "";
        if (headers.containsKey("X-B3-TRACEID")) {
            traceId = headers.get("X-B3-TRACEID");
            MDC.put("X-B3-TraceId", traceId);
        } else if (!exchange.getRequest().getURI().getPath().contains("/actuator")) {
            log.warn("TRACE_ID not present in header: {}", exchange.getRequest().getURI());
        }

        return chain.filter(exchange);

    }


}

I implemented a WebFilter, which gives us access to the chain, to put the trace id into the MDC. Very simple code: I read the header, extracted the trace id and put it in the MDC, cool! Now we should get the trace id as expected. But BOOM!! No, in case of errors, when we need the trace id most, it did not match the one in the header :(

Can you guess the reason? WebFlux is non-blocking, so a single request can pass through multiple threads, and those threads are not reserved for one request but serve several requests concurrently. The trace id Sleuth generates is held in log4j's MDC, which is bound to a single thread. So when the chain hops between threads, the MDC context changes, and so does the trace id. So this is my solution to the problem -

@Component
@Slf4j
public class TraceIdFilter implements WebFilter {

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        Map<String, String> headers = exchange.getRequest().getHeaders().toSingleValueMap();

        return chain.filter(exchange)
                .subscriberContext(context -> {
                    var traceId = "";
                    if (headers.containsKey("X-B3-TRACEID")) {
                        traceId = headers.get("X-B3-TRACEID");
                        MDC.put("X-B3-TraceId", traceId);
                    } else if (!exchange.getRequest().getURI().getPath().contains("/actuator")) {
                        log.warn("TRACE_ID not present in header: {}", exchange.getRequest().getURI());
                    }

                    // simple hack to provide the context with the exchange, so the whole chain can get the same trace id
                    Context contextTmp = context.put("X-B3-TraceId", traceId);
                    exchange.getAttributes().put("X-B3-TraceId", traceId);

                    return contextTmp;
                });

    }


}

Focus on subscriberContext: it holds the context of the chain, and we can read that context from anywhere in the chain. So we put the trace id in this context, and wherever we need it in the chain we read the context back using subscriberContext, take the trace id out, set it in the MDC again, and the log will have the same trace id! TADA! (In Reactor 3.4 and later, subscriberContext is deprecated in favor of contextWrite.)
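To make the thread-hopping problem and the fix concrete without pulling in Reactor, here is a small plain-JDK analogy. FAKE_MDC stands in for the MDC and a plain Map stands in for the subscriber context; both names, and the two helper methods, are made up for illustration. It is a sketch of the idea only, not how Reactor or Sleuth implement it.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class TraceContextDemo {
    // Plays the role of the MDC: each thread sees only its own value.
    static final ThreadLocal<String> FAKE_MDC = new ThreadLocal<>();

    // Reads the thread-local on a worker thread without carrying anything along.
    static String readWithoutContext(ExecutorService worker) {
        return CompletableFuture.supplyAsync(() -> String.valueOf(FAKE_MDC.get()), worker).join();
    }

    // Carries the trace id in an explicit context and restores it on the worker thread.
    static String readWithContext(ExecutorService worker, Map<String, String> context) {
        return CompletableFuture.supplyAsync(() -> {
            FAKE_MDC.set(context.get("X-B3-TraceId")); // restore before doing any "logging"
            try {
                return String.valueOf(FAKE_MDC.get());
            } finally {
                FAKE_MDC.remove(); // do not leak the id into the next task on this thread
            }
        }, worker).join();
    }

    public static void main(String[] args) {
        ExecutorService worker = Executors.newSingleThreadExecutor();
        try {
            FAKE_MDC.set("123"); // set on the calling thread, like the first TraceIdFilter did
            System.out.println("without context: " + readWithoutContext(worker));
            System.out.println("with context:    " + readWithContext(worker, Map.of("X-B3-TraceId", "123")));
        } finally {
            worker.shutdown();
            FAKE_MDC.remove();
        }
    }
}
```

The first call prints null because the worker thread never saw the value set on the calling thread; the second prints 123 because the id travels with the work and is put back into the thread-local right before it is needed, which is the same move as reading the trace id from the context and setting it in the MDC again.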

This journey taught me a lot about fancy functional programming and the trade-offs of reactive Java. I hope this blog will help someone, someday. Lastly, I made a library for this Loggable annotation: just put the annotation on a function to log whatever arguments it takes, whatever it returns, and the execution time. Here is the GitHub repo. A clap will be much appreciated.