reg/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/handler_test.go
Jess Frazelle 67bc3ef6c3
add v3 api
Signed-off-by: Jess Frazelle <acidburn@microsoft.com>
2018-06-11 12:48:47 -04:00

229 lines
6.4 KiB
Go

package runtime_test
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"context"
"github.com/golang/protobuf/proto"
pb "github.com/grpc-ecosystem/grpc-gateway/examples/proto/examplepb"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/grpc-ecosystem/grpc-gateway/runtime/internal"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TestForwardResponseStream checks that runtime.ForwardResponseStream, given a
// Delimited marshaler (runtime.JSONPb), writes every received message as a
// {"result": ...} chunk followed by marshaler.Delimiter(), sets the
// "Transfer-Encoding: chunked" header, and, when the stream fails after at
// least one message, appends an {"error": ...} chunk after the results.
func TestForwardResponseStream(t *testing.T) {
	type msg struct {
		pb  proto.Message
		err error
	}
	tests := []struct {
		name       string
		msgs       []msg
		statusCode int
	}{{
		name: "encoding",
		msgs: []msg{
			{&pb.SimpleMessage{Id: "One"}, nil},
			{&pb.SimpleMessage{Id: "Two"}, nil},
		},
		statusCode: http.StatusOK,
	}, {
		name:       "empty",
		statusCode: http.StatusOK,
	}, {
		name:       "error",
		msgs:       []msg{{nil, grpc.Errorf(codes.OutOfRange, "400")}},
		statusCode: http.StatusBadRequest,
	}, {
		name: "stream_error",
		msgs: []msg{
			{&pb.SimpleMessage{Id: "One"}, nil},
			{nil, grpc.Errorf(codes.OutOfRange, "400")},
		},
		statusCode: http.StatusOK,
	}}

	// newTestRecv returns a recv func that yields msgs in order and then
	// io.EOF. Every call made after io.EOF has already been returned is
	// reported as a test error, since the caller should stop reading at EOF.
	newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
		var calls int
		return func() (proto.Message, error) {
			calls++
			switch {
			case calls == len(msgs)+1:
				return nil, io.EOF
			case calls > len(msgs)+1:
				t.Errorf("recv() called %d times for %d messages", calls, len(msgs))
				return nil, io.EOF
			}
			m := msgs[calls-1]
			return m.pb, m.err
		}
	}

	ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
	marshaler := &runtime.JSONPb{}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			recv := newTestRecv(t, tt.msgs)
			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
			resp := httptest.NewRecorder()

			runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)

			w := resp.Result()
			if w.StatusCode != tt.statusCode {
				t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode)
			}
			if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
				t.Errorf("ForwardResponseStream Transfer-Encoding = %q, want %q", h, "chunked")
			}
			body, err := ioutil.ReadAll(w.Body)
			w.Body.Close()
			if err != nil {
				t.Fatalf("Failed to read response body with %v", err)
			}

			var want []byte
			for i, m := range tt.msgs {
				if m.err != nil {
					if i == 0 {
						// An error before any message is only checked through
						// the status code above; its body is not compared.
						t.Skip("checking error encodings")
					}
					// The stream had already produced results, so the body must
					// be exactly those result chunks followed by one error chunk.
					if len(body) < len(want) || string(body[:len(want)]) != string(want) {
						t.Fatalf("ForwardResponseStream() = %q, want prefix %q", body, want)
					}
					// The table builds every error with grpc.Errorf, so it
					// always carries a gRPC status; ok can be ignored.
					st, _ := status.FromError(m.err)
					httpCode := runtime.HTTPStatusFromCode(st.Code())
					b, err := marshaler.Marshal(map[string]proto.Message{
						"error": &internal.StreamError{
							GrpcCode:   int32(st.Code()),
							HttpCode:   int32(httpCode),
							Message:    st.Message(),
							HttpStatus: http.StatusText(httpCode),
							Details:    st.Proto().GetDetails(),
						},
					})
					if err != nil {
						t.Fatalf("marshaler.Marshal() failed %v", err)
					}
					if errBytes := body[len(want):]; string(errBytes) != string(b) {
						t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", errBytes, b)
					}
					return
				}
				b, err := marshaler.Marshal(map[string]proto.Message{"result": m.pb})
				if err != nil {
					t.Fatalf("marshaler.Marshal() failed %v", err)
				}
				want = append(want, b...)
				want = append(want, marshaler.Delimiter()...)
			}
			if string(body) != string(want) {
				t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want)
			}
		})
	}
}
// CustomMarshaler is a runtime.Marshaler that delegates every call to a
// wrapped runtime.JSONPb but has no Delimiter method. Tests use it to check
// how ForwardResponseStream separates chunks for marshalers that do not
// provide their own delimiter.
//
// The JSONPb is held in a named field rather than embedded: embedding
// *runtime.JSONPb would promote its Delimiter method and defeat the purpose
// of this type.
type CustomMarshaler struct {
	m *runtime.JSONPb
}

// Compile-time check that CustomMarshaler satisfies runtime.Marshaler.
var _ runtime.Marshaler = (*CustomMarshaler)(nil)

// Marshal delegates to the wrapped JSONPb.
func (c *CustomMarshaler) Marshal(v interface{}) ([]byte, error) { return c.m.Marshal(v) }

// Unmarshal delegates to the wrapped JSONPb.
func (c *CustomMarshaler) Unmarshal(data []byte, v interface{}) error { return c.m.Unmarshal(data, v) }

// NewDecoder delegates to the wrapped JSONPb.
func (c *CustomMarshaler) NewDecoder(r io.Reader) runtime.Decoder { return c.m.NewDecoder(r) }

// NewEncoder delegates to the wrapped JSONPb.
func (c *CustomMarshaler) NewEncoder(w io.Writer) runtime.Encoder { return c.m.NewEncoder(w) }

// ContentType delegates to the wrapped JSONPb.
func (c *CustomMarshaler) ContentType() string { return c.m.ContentType() }
// TestForwardResponseStreamCustomMarshaler runs the same scenarios as
// TestForwardResponseStream, but with CustomMarshaler, which has no
// Delimiter method. It expects chunks to be separated by "\n" and, after a
// mid-stream error, an {"error": ...} chunk following the results.
func TestForwardResponseStreamCustomMarshaler(t *testing.T) {
	type msg struct {
		pb  proto.Message
		err error
	}
	tests := []struct {
		name       string
		msgs       []msg
		statusCode int
	}{{
		name: "encoding",
		msgs: []msg{
			{&pb.SimpleMessage{Id: "One"}, nil},
			{&pb.SimpleMessage{Id: "Two"}, nil},
		},
		statusCode: http.StatusOK,
	}, {
		name:       "empty",
		statusCode: http.StatusOK,
	}, {
		name:       "error",
		msgs:       []msg{{nil, grpc.Errorf(codes.OutOfRange, "400")}},
		statusCode: http.StatusBadRequest,
	}, {
		name: "stream_error",
		msgs: []msg{
			{&pb.SimpleMessage{Id: "One"}, nil},
			{nil, grpc.Errorf(codes.OutOfRange, "400")},
		},
		statusCode: http.StatusOK,
	}}

	// newTestRecv returns a recv func that yields msgs in order and then
	// io.EOF. Every call made after io.EOF has already been returned is
	// reported as a test error, since the caller should stop reading at EOF.
	newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
		var calls int
		return func() (proto.Message, error) {
			calls++
			switch {
			case calls == len(msgs)+1:
				return nil, io.EOF
			case calls > len(msgs)+1:
				t.Errorf("recv() called %d times for %d messages", calls, len(msgs))
				return nil, io.EOF
			}
			m := msgs[calls-1]
			return m.pb, m.err
		}
	}

	ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
	marshaler := &CustomMarshaler{&runtime.JSONPb{}}
	// This test is only meaningful if the marshaler lacks a Delimiter method;
	// fail loudly if that premise ever stops holding.
	if _, ok := interface{}(marshaler).(interface{ Delimiter() []byte }); ok {
		t.Fatalf("%T must not implement Delimiter() for this test to be meaningful", marshaler)
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			recv := newTestRecv(t, tt.msgs)
			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
			resp := httptest.NewRecorder()

			runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)

			w := resp.Result()
			if w.StatusCode != tt.statusCode {
				t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode)
			}
			if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
				t.Errorf("ForwardResponseStream Transfer-Encoding = %q, want %q", h, "chunked")
			}
			body, err := ioutil.ReadAll(w.Body)
			w.Body.Close()
			if err != nil {
				t.Fatalf("Failed to read response body with %v", err)
			}

			var want []byte
			for i, m := range tt.msgs {
				if m.err != nil {
					if i == 0 {
						// An error before any message is only checked through
						// the status code above; its body is not compared.
						t.Skip("checking error encodings")
					}
					// The stream had already produced results, so the body must
					// be exactly those result chunks followed by one error chunk.
					if len(body) < len(want) || string(body[:len(want)]) != string(want) {
						t.Fatalf("ForwardResponseStream() = %q, want prefix %q", body, want)
					}
					// The table builds every error with grpc.Errorf, so it
					// always carries a gRPC status; ok can be ignored.
					st, _ := status.FromError(m.err)
					httpCode := runtime.HTTPStatusFromCode(st.Code())
					b, err := marshaler.Marshal(map[string]proto.Message{
						"error": &internal.StreamError{
							GrpcCode:   int32(st.Code()),
							HttpCode:   int32(httpCode),
							Message:    st.Message(),
							HttpStatus: http.StatusText(httpCode),
							Details:    st.Proto().GetDetails(),
						},
					})
					if err != nil {
						t.Fatalf("marshaler.Marshal() failed %v", err)
					}
					if errBytes := body[len(want):]; string(errBytes) != string(b) {
						t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", errBytes, b)
					}
					return
				}
				b, err := marshaler.Marshal(map[string]proto.Message{"result": m.pb})
				if err != nil {
					t.Fatalf("marshaler.Marshal() failed %v", err)
				}
				want = append(want, b...)
				// Without a Delimiter method, chunks are expected to be
				// newline-separated.
				want = append(want, "\n"...)
			}
			if string(body) != string(want) {
				t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want)
			}
		})
	}
}