Golang中的高阶函数

2020/12/16    Go Golang 高阶函数

维基百科:在数学和计算机科学中,高阶函数(Higher-order function)是至少满足下列一个条件的函数:

  • 接受一个或多个函数作为输入
  • 输出一个函数

符合上述条件的函数我们经常能在编程中碰到,像Python中的map(function, iterable, ...)函数,它的第一个参数接收一个函数名,意思是将function应用于iterable中的每一个元素,这就是一个典型的Higher-order function。

在Golang中我们经常听到这样一句话:“函数是一等公民”,与此对应的有个专业术语叫“First class function”。如果某个编程语言拥有“First class function”特性,就可以把函数作为变量对待。也就说,函数与变量没有差别,它们是一样的,函数可以当成变量使用,也可以赋值给变量。也正是这一特性,可以让我们轻松地在Golang中实现高阶函数。下面看一个Golang版的Map函数:

package main

import "fmt"

type MapCallback func(int, int) int

func main() {
	newItems := Map(square, []int{1, 2, 3, 4})

	fmt.Print(newItems)
}

func square(index, item int) int {
	return item * item
}

func Map(fn MapCallback, items []int) []int {
	result := make([]int, len(items))
	for i, item := range items {
		result[i] = fn(i, item)
	}

	return result
}

那么什么场景可以使用高阶函数呢?先来看一个案例,下述代码描述了一个简单的登录场景,由于各个服务采用了不同的密码加密规则,所以需要在Login函数中写一段判断逻辑:

type Credentials struct {
	Service   string
	Username  string
	Pwd       string
	HashedPwd string
}

func Login(c *Credentials) {

	// do something...

	switch c.Service {
	case "admin":
		c.HashedPwd = fmt.Sprintf("%x", md5.Sum([]byte(c.Pwd)))
	case "op":
		c.HashedPwd = fmt.Sprintf("%x", sha1.Sum([]byte(c.Pwd)))
	default:
		c.HashedPwd = c.Pwd
	}

	// do something...

}

我们在写代码的过程中经常遇到写一大堆代码然后发现有一部分逻辑重复了,于是就提取一个函数以便复用(比如上述的Login函数)。有时候这个函数里的一部分逻辑需要在不同的情况下执行不同的逻辑,一般我们会使用if或者switch来处理。可是这种做法不具备可拓展性,随着if或者case的增加,函数体的大小也会随之线性增长,最终导致代码可维护行越来越差。如果我们使用高阶函数来处理呢:

type HashPwd func(string) string

func Login(c *Credentials, hash HashPwd) {

	// do something...

	c.HashedPwd = hash(c.Pwd)

	// do something...

}

func plainPwd(pwd string) string {
	return pwd
}

func md5Pwd(pwd string) string {
	return fmt.Sprintf("%x", md5.Sum([]byte(pwd)))
}

func sha1Pwd(pwd string) string {
	return fmt.Sprintf("%x", sha1.Sum([]byte(pwd)))
}

func main() {
	Login(c1, plainPwd)
	Login(c2, md5Pwd)
	Login(c3, sha1Pwd)
}

可以看到使用高阶函数封装后代码变得更加优雅了,当我们需要拓展新的hash算法时,只需新增一个HashPwd类型的函数即可,以免Login函数变得臃肿不堪。高阶函数不仅能使程序变得易于拓展与维护,还大大提升了程序的可读性,因为函数名就能清晰地描述它的功能,让开发人员阅读起来更加轻松。

装饰器

动态(组合)地给一个对象增加一些额外的职责。就增加功能而言,Decorator模式比生成子类(继承)更为灵活(消除重复代码 & 减少子类个数)。 ——《设计模式》GoF

在Golang里面我们可以利用高阶函数来实现装饰器,直接看一个案例,统计函数的运行时间:

package main

import (
	"fmt"
	"time"
)

type DoALotOfThings func([]int) int

func doALotOfThings(data []int) int {
	// do a lot of things...
	return 0
}

func timeSpent(fn DoALotOfThings) DoALotOfThings {
	return func(data []int) int {
		start := time.Now()

		result := fn(data)

		fmt.Printf("Time spent: %dms\n", time.Since(start).Milliseconds())

		return result
	}
}

func main() {
	doALotOfThingsWithTimeSpent := timeSpent(doALotOfThings)
	doALotOfThingsWithTimeSpent([]int{1, 2, 3, 4, 5})
}

timeSpent函数就是一个利用高阶函数实现的装饰器,它可以对doALotOfThings函数进行包装,在不改变原函数的情况下便可以统计其运行时间。当然这个案例并不实用,我们只是借助其学习一下装饰器的实现。

下面再看一个实用点的案例,利用装饰器实现HTTP Server的中间件:

package main

import (
	"log"
	"net/http"
)

type Middleware func(http.HandlerFunc) http.HandlerFunc

func main() {
	http.HandleFunc("/users", Pipeline(GetUsers, LogRequest, AllowCORS))

	log.Print("HTTP Server started: http://127.0.0.1:8080")
	log.Fatal(http.ListenAndServe(":8080", nil))
}

func GetUsers(w http.ResponseWriter, r *http.Request) {
	log.Print("GetUsers")
	w.Write([]byte("[1, 2, 3, 4, 5]"))
}

func Pipeline(handler http.HandlerFunc, middleware ...Middleware) http.HandlerFunc {
	for i := len(middleware) - 1; i >= 0; i-- {
		handler = middleware[i](handler)
	}
	return handler
}

// LogRequest 记录请求中间件
func LogRequest(handler http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		log.Print("LogRequest", r)

		handler(w, r)
	}
}

// AllowCORS 允许跨域中间件
func AllowCORS(handler http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		log.Print("AllowCORS")
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Methods", "GET,POST,PUT,DELETE")
		w.Header().Set("Access-Control-Allow-Credentials", "true")
		w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With,Content-Type")

		handler(w, r)
	}
}

在上述代码中,LogRequestAllowCORS都是GetUsers的装饰器,然后通过Pipeline来串行地执行这些装饰器,从而达到中间件的效果。这个特性对于HTTP服务程序非常有利,我们可以将一些认证、校验等公用的逻辑提取做为中间件,而每个Handler则只需要关注自己的业务逻辑,这样我们的代码就会变得十分优雅。

依赖注入

这里并不是来讨论如何使用facebookgo/inject之类的工具来编写依赖注入代码,因为我们的主题是高阶函数,我们还是以一个案例来讲解如何利用高阶函数实现依赖注入。假设现在leader给了你一个需求,让你实现一个拼接欢迎用户注册的消息方法,这个方法接收一个userID,你需要从数据库查询出用户信息,然后返回拼接好的消息内容,先来看一下非依赖注入的版本实现:

type User struct {
	ID       uint64
	nickname string
}

func GetWelcomeMsg(userID uint64) string {
	user := RetrieveUser(userID)
	return "欢迎" + user.nickname + "加入社区,致以最诚挚的问候!"
}

// RetrieveUser 获取用户信息(DAO层)
func RetrieveUser(userID uint64) *User {
	user := &User{}
	db.First(user, userID)
	return user
}

功能已经实现了,但是我们还需要编写单元测试来保证代码的可靠性,思考一下如何为GetWelcomeMsg编写单元测试呢?可以看到GetWelcomeMsg里面调用了RetrieveUser方法来获取用户信息,而RetrieveUser是属于DAO层的服务,相当于一个外部依赖,所以说上述代码是将RetrieveUser耦合到了GetWelcomeMsg中,导致我们没法去mock数据来编写单元测试。现在我们需要将依赖解耦,以便更好地编写单元测试:

type UserRetriever func(uint64) *User
type WelcomeMsgGetter func(uint64) string

func NewWelcomeMsgGetter(retrieverUser UserRetriever) WelcomeMsgGetter {
	return func(uerID uint64) string {
		user := retrieverUser(uerID)
		return "欢迎" + user.nickname + "加入社区,致以最诚挚的问候!"
	}
}

func TestGetWelcomeMsg(t *testing.T) {
	retrieveUser := func(userID uint64) *User {
		return &User{
			ID: userID,
			nickname: "小明",
		}
	}

	getWelcomeMsg := NewWelcomeMsgGetter(retrieveUser)
	msg := getWelcomeMsg(1)

	if msg != "欢迎小明加入社区,致以最诚挚的问候!" {
		t.Error("获取欢迎消息出错")
	}
}

可以看到,UserRetriever可以通过NewWelcomeMsgGetter注入到服务内部,这样就实现了解耦,这样我们很容易就能够mock数据进行单元测试了。

总结

此次分享就到这儿了,最后做一个总结(当然高阶函数远不止这4种用法):

  1. 高阶函数并不是什么高深的名词,在平时编程中十分常见
  2. 高阶函数可以在提升复用代码性的同时,保证代码的可维护性
  3. 高阶函数可以实现装饰器,拓展原有函数的功能
  4. 高阶函数可以实现依赖注入,帮助我们编写可测试的代码