Custom Wrapper Function in Golang with Context

Intro #

I'm currently building an app service named Skrr. It was originally written in Python (Flask) with MongoDB; this time I'm rewriting it in Golang and migrating to PostgreSQL. There were several reasons for this, but since they are not the subject of this post, I may cover them in another one.

In this post I will show you how to create a custom wrapper function that injects a database object and the user ID from a JWT, so both can be used in each API handler. This helps reduce code duplication and makes your code more readable.

There are several cool web frameworks in Golang, for example 1) Gin, 2) Fiber, and 3) Echo. These awesome tools provide middleware features that you can use easily. However, I didn't want to use a framework at all for this code. Because Golang itself has rich support for network programming, I wanted to rely on it as much as possible. So I decided to use the net/http package with gorilla/mux as the only dependency. That meant I had to write the wrapper for the database object and the user ID myself.

Wrapper Function #

Wrapper functions make writing computer programs easier by abstracting away the details of a subroutine's underlying implementation. In this post, however, the term has a narrower meaning than that general one: a wrapper function is simply a function that wraps another function. In other words, it takes another function as an argument and returns a function. This is also called a higher-order function.
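To make the definition concrete, here is a minimal sketch of a higher-order function in Go. The names Greeter, withUpper, and hello are made up for illustration and are not part of Skrr.

```go
package main

import (
	"fmt"
	"strings"
)

// Greeter is the function type being wrapped (hypothetical, for illustration).
type Greeter func(name string) string

// withUpper takes a Greeter and returns a new Greeter that upper-cases
// whatever the wrapped function returns.
func withUpper(next Greeter) Greeter {
	return func(name string) string {
		return strings.ToUpper(next(name))
	}
}

// hello is the plain function we want to wrap.
func hello(name string) string {
	return "hello, " + name
}

func main() {
	shout := withUpper(hello)
	fmt.Println(shout("skrr")) // prints "HELLO, SKRR"
}
```

The HTTP wrappers later in this post have the same shape: they accept a handler and return a handler.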

Example Code #

The code below is a simple example of what I want the function to do. It is Python (FastAPI) code that I used in a previous project. Depends(get_db) passes in a function that returns a database connection, and Depends(valid_request) passes in a function that returns the user ID extracted from the JWT. This kind of design makes it much more comfortable to use the database connection and the user ID in each API handler.


```python
@router.get("/{uuid}")
async def find_one_diary(
    uuid: str, user_uuid=Depends(valid_request), db=Depends(get_db)
):
    diary = await db.execute(select(Diary).where(Diary.uuid == uuid))
    diary = diary.scalar()
    if diary is None:
        return None
    if diary.user_uuid != user_uuid:
        raise HTTPException(status_code=403, detail="Not your diary")

    return diary
```

Implementation #

Basic Golang Code #

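```go
package main

import (
	"log"
	"net/http"
)

func main() {
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("Hello World"))
	})
	// A nil handler makes ListenAndServe use http.DefaultServeMux,
	// which is where http.HandleFunc registered the route.
	log.Fatal(http.ListenAndServe(":8080", nil))
}
```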

This is basic Golang code for a simple web server. The part to focus on is the http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) call: the handler passed to it is what we will wrap with our custom wrapper function. With the wrapper in place, and gorilla/mux used for routing, the registration looks like this.

```go
	db := config.InitDatabase(os.Getenv("POSTGRES_URL"))
	defer db.Close()

	// ...

	schoolRouter := router.PathPrefix("/school").Subrouter()
	schoolRouter.Handle("/{id}", D(db, api.FindOneSchool)).Methods(http.MethodGet)
	schoolRouter.Handle("/search/name", D(db, api.SearchSchoolByName)).Methods(http.MethodGet)

	friendRouter := router.PathPrefix("/friend").Subrouter()
	friendRouter.Handle("", ValidateRequest(D(db, api.FindAllFriend))).Methods(http.MethodGet)

	// ...
```

Injecting Database Object #

The function D is the wrapper function that I implemented. The first argument is the database object. The second argument is the function that we want to wrap. The return value is the function that we will use as a handler.

```go
// D closes over db and adapts a three-argument handler to http.Handler.
func D(db *sql.DB, fn func(http.ResponseWriter, *http.Request, *sql.DB)) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fn(w, r, db)
	})
}
```

Because the second argument of Handle must be an http.Handler, D has to return that type. http.HandlerFunc is an adapter type that lets an ordinary function with the right signature satisfy http.Handler, which is why the closure is converted with http.HandlerFunc.

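```go
func FindOneSchool(w http.ResponseWriter, r *http.Request, db *sql.DB) {
	// Path variable captured by the "/{id}" route.
	id := mux.Vars(r)["id"]

	// Treat a lookup error the same as "not found" instead of discarding it.
	school, err := storage.FindOneSchool(db, id)
	if err != nil || school.ID == "" {
		utils.JsonResp(w, "no school found", http.StatusBadRequest)
		return
	}

	utils.JsonResp(w, school, http.StatusOK)
}
```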

Now we can use the database object easily in the handler function. In this case, it is used to find the school whose id is passed as a path parameter.

Middleware extracting User from JWT #

The next step is to inject the user ID from the JWT. This is one step more difficult than injecting the database, because we need to extract the user ID from the JWT and pass it on to the handler function. Also, when the JWT is not valid, we should stop early and return 401 Unauthorized instead of calling the handler. So we need a middleware function that extracts the user ID from the JWT and passes it to the handler function.

```go
func ValidateRequest(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tokenString := r.Header.Get("Authorization")
		if tokenString == "" {
			JsonResp(w, errors.New("authorization header missing"), http.StatusUnauthorized)
			return
		}

		// Split the "Bearer " prefix from the token string
		tokenParts := strings.Split(tokenString, " ")
		if len(tokenParts) != 2 || strings.ToLower(tokenParts[0]) != "bearer" {
			JsonResp(w, errors.New("invalid Authorization header format"), http.StatusUnauthorized)
			return
		}
		tokenString = tokenParts[1]

		claims, err := ValidateJwt(tokenString)
		if err != nil {
			JsonResp(w, err, http.StatusUnauthorized)
			return
		}

		// Assumption: the claims returned by ValidateJwt expose the user ID
		// in a field such as UserID; adjust this to match your claims type.
		userID := claims.UserID

		// Store the user ID in the request context for the next handler.
		ctx := context.WithValue(r.Context(), "user_id", userID)
		r = r.WithContext(ctx)

		next.ServeHTTP(w, r)
	})
}
```

The key point of this function is ctx := context.WithValue(r.Context(), "user_id", userID) followed by r = r.WithContext(ctx). This is how we pass the user ID to the handler function. The first argument of context.WithValue is the request's current context, the second is the key, and the third is the value. context.WithValue does not modify the existing context; it returns a new one that carries the value. That is why we then call r = r.WithContext(ctx): it gives us a copy of the request that uses the new context.

If the goal were only to validate the JWT, a plain middleware would be enough. Because the middleware also stores the user ID in the request context, the handler function can read it like below.

```go
func FindAllFriend(w http.ResponseWriter, r *http.Request, db *sql.DB) {
	userID := r.Context().Value("user_id").(string)

	usecase.FindAllFriend(w, db, userID)
}
```

This is how we use the user ID in the handler function. The userID := r.Context().Value("user_id").(string) line reads the value stored under the "user_id" key and asserts that it is a string. Now the user ID is as easy to use in the handler as the database object.

Conclusion #

In this post, I showed you how to create a custom wrapper function that injects a database object and the user ID from a JWT so they can be used in each API handler. This helps reduce code duplication and makes your code more readable. It was also a good opportunity to use Golang's context feature, which is one of the most important features in Golang. I hope this post was helpful for you.