package authorization

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/docker/docker/pkg/plugingetter"
	"github.com/stretchr/testify/require"
)

// TestMiddleware checks that a Middleware built from a list of plugin names
// reports one authorization plugin per name, in the same order.
func TestMiddleware(t *testing.T) {
	pluginNames := []string{"testPlugin1", "testPlugin2"}

	// The getter is intentionally left as its nil zero value; this test only
	// inspects the plugin names returned by the middleware.
	var pluginGetter plugingetter.PluginGetter

	m := NewMiddleware(pluginNames, pluginGetter)
	authPlugins := m.getAuthzPlugins()

	// Guard the length first so the indexed checks below cannot panic.
	require.Len(t, authPlugins, len(pluginNames))
	require.EqualValues(t, pluginNames[0], authPlugins[0].Name())
	require.EqualValues(t, pluginNames[1], authPlugins[1].Name())
}

// TestNewResponseModifier drives a response modifier wrapped around an
// httptest recorder: it sets a header, writes a body and a status code,
// inspects the raw body and raw headers, and finally verifies that the header
// is visible on the underlying recorder after flushing.
func TestNewResponseModifier(t *testing.T) {
	recorder := httptest.NewRecorder()
	modifier := NewResponseModifier(recorder)

	modifier.Header().Set("H1", "V1")

	// A failed write would make every later assertion meaningless, so stop
	// right away instead of discarding the error.
	_, err := modifier.Write([]byte("body"))
	require.NoError(t, err)

	require.False(t, modifier.Hijacked())
	modifier.WriteHeader(http.StatusInternalServerError)
	require.NotNil(t, modifier.RawBody())

	// Check the error before looking at the returned bytes.
	raw, err := modifier.RawHeaders()
	require.NoError(t, err)
	require.NotNil(t, raw)

	// Split into at most name and value, and assert both parts exist before
	// indexing so malformed output fails the test instead of panicking.
	headerData := strings.SplitN(strings.TrimSpace(string(raw)), ":", 2)
	require.Len(t, headerData, 2)
	require.EqualValues(t, "H1", strings.TrimSpace(headerData[0]))
	require.EqualValues(t, "V1", strings.TrimSpace(headerData[1]))

	modifier.Flush()
	modifier.FlushAll()

	// After flushing, the header set on the modifier must be observable on
	// the wrapped recorder.
	if got := recorder.Header().Get("H1"); got != "V1" {
		t.Fatalf("Header value must exists %s", got)
	}
}

// setAuthzPlugins replaces the middleware's plugin list while holding its
// mutex. It is a helper for tests that need to inject plugins directly.
func setAuthzPlugins(m *Middleware, plugins []Plugin) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.plugins = plugins
}