From 6a7a4b37e3321eba6537faea6df2690e05535731 Mon Sep 17 00:00:00 2001
From: Manish R Jain <manishrjain@gmail.com>
Date: Mon, 2 May 2016 11:36:58 +1000
Subject: [PATCH] Add logic for query timeout

---
 query/query.go      | 40 ++++++++++++++++++++++++++++------------
 server/main.go      |  4 ++--
 server/main_test.go |  5 +++--
 3 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/query/query.go b/query/query.go
index 35f88601..e06a5876 100644
--- a/query/query.go
+++ b/query/query.go
@@ -473,7 +473,9 @@ func sortedUniqueUids(r *task.Result) (sorted []uint64, rerr error) {
 	return sorted, nil
 }
 
-func ProcessGraph(sg *SubGraph, rch chan error) {
+func ProcessGraph(sg *SubGraph, rch chan error, td time.Duration) {
+	timeout := time.Now().Add(td)
+
 	var err error
 	if len(sg.query) > 0 && sg.Attr != "_root_" {
 		sg.result, err = worker.ProcessTaskOverNetwork(sg.query)
@@ -515,6 +517,13 @@ func ProcessGraph(sg *SubGraph, rch chan error) {
 		return
 	}
 
+	timeleft := timeout.Sub(time.Now())
+	if timeleft < 0 {
+		glog.WithField("attr", sg.Attr).Error("Query timeout before children")
+		rch <- fmt.Errorf("Query timeout before children")
+		return
+	}
+
 	// Let's execute it in a tree fashion. Each SubGraph would break off
 	// as many goroutines as it's children; which would then recursively
 	// do the same thing.
@@ -523,21 +532,28 @@ func ProcessGraph(sg *SubGraph, rch chan error) {
 	for i := 0; i < len(sg.Children); i++ {
 		child := sg.Children[i]
 		child.query = createTaskQuery(child.Attr, sorted)
-		go ProcessGraph(child, childchan)
+		go ProcessGraph(child, childchan, timeleft)
 	}
 
+	tchan := time.After(timeleft)
 	// Now get all the results back.
 	for i := 0; i < len(sg.Children); i++ {
-		err = <-childchan
-		glog.WithFields(logrus.Fields{
-			"num_children": len(sg.Children),
-			"index":        i,
-			"attr":         sg.Children[i].Attr,
-			"err":          err,
-		}).Debug("Reply from child")
-		if err != nil {
-			x.Err(glog, err).Error("While processing child task.")
-			rch <- err
+		select {
+		case err = <-childchan:
+			glog.WithFields(logrus.Fields{
+				"num_children": len(sg.Children),
+				"index":        i,
+				"attr":         sg.Children[i].Attr,
+				"err":          err,
+			}).Debug("Reply from child")
+			if err != nil {
+				x.Err(glog, err).Error("While processing child task.")
+				rch <- err
+				return
+			}
+		case <-tchan:
+			glog.WithField("attr", sg.Attr).Error("Query timeout after children")
+			rch <- fmt.Errorf("Query timeout after children")
 			return
 		}
 	}
diff --git a/server/main.go b/server/main.go
index 887677d1..6ab16363 100644
--- a/server/main.go
+++ b/server/main.go
@@ -177,7 +177,7 @@ func queryHandler(w http.ResponseWriter, r *http.Request) {
 	glog.WithField("q", string(q)).Debug("Query parsed.")
 
 	rch := make(chan error)
-	go query.ProcessGraph(sg, rch)
+	go query.ProcessGraph(sg, rch, time.Second)
 	err = <-rch
 	if err != nil {
 		x.Err(glog, err).Error("While executing query")
@@ -233,7 +233,7 @@ func (s *server) Query(ctx context.Context,
 	glog.WithField("q", req.Query).Debug("Query parsed.")
 
 	rch := make(chan error)
-	go query.ProcessGraph(sg, rch)
+	go query.ProcessGraph(sg, rch, time.Minute)
 	err = <-rch
 	if err != nil {
 		x.Err(glog, err).Error("While executing query")
diff --git a/server/main_test.go b/server/main_test.go
index 1cbaf043..6a9fc921 100644
--- a/server/main_test.go
+++ b/server/main_test.go
@@ -21,6 +21,7 @@ import (
 	"io/ioutil"
 	"os"
 	"testing"
+	"time"
 
 	"github.com/dgraph-io/dgraph/commit"
 	"github.com/dgraph-io/dgraph/gql"
@@ -158,7 +159,7 @@ func TestQuery(t *testing.T) {
 	}
 
 	ch := make(chan error)
-	go query.ProcessGraph(g, ch)
+	go query.ProcessGraph(g, ch, time.Minute)
 	if err := <-ch; err != nil {
 		t.Error(err)
 		return
@@ -217,7 +218,7 @@ func BenchmarkQuery(b *testing.B) {
 		}
 
 		ch := make(chan error)
-		go query.ProcessGraph(g, ch)
+		go query.ProcessGraph(g, ch, time.Minute)
 		if err := <-ch; err != nil {
 			b.Error(err)
 			return
-- 
GitLab