diff --git a/clair/ancestry.go b/clair/ancestry.go index c67cff93..d9d6c526 100644 --- a/clair/ancestry.go +++ b/clair/ancestry.go @@ -13,7 +13,7 @@ var ( ) // GetAncestry displays an ancestry and all of its features and vulnerabilities. -func (c *Clair) GetAncestry(name string) (*clairpb.GetAncestryResponse_Ancestry, error) { +func (c *Clair) GetAncestry(ctx context.Context, name string) (*clairpb.GetAncestryResponse_Ancestry, error) { c.Logf("clair.ancestry.get name=%s", name) if c.grpcConn == nil { @@ -22,7 +22,7 @@ func (c *Clair) GetAncestry(name string) (*clairpb.GetAncestryResponse_Ancestry, client := clairpb.NewAncestryServiceClient(c.grpcConn) - resp, err := client.GetAncestry(context.Background(), &clairpb.GetAncestryRequest{ + resp, err := client.GetAncestry(ctx, &clairpb.GetAncestryRequest{ AncestryName: name, }) if err != nil { @@ -41,7 +41,7 @@ func (c *Clair) GetAncestry(name string) (*clairpb.GetAncestryResponse_Ancestry, } // PostAncestry performs the analysis of all layers from the provided path. -func (c *Clair) PostAncestry(name string, layers []*clairpb.PostAncestryRequest_PostLayer) error { +func (c *Clair) PostAncestry(ctx context.Context, name string, layers []*clairpb.PostAncestryRequest_PostLayer) error { c.Logf("clair.ancestry.post name=%s", name) if c.grpcConn == nil { @@ -50,7 +50,7 @@ func (c *Clair) PostAncestry(name string, layers []*clairpb.PostAncestryRequest_ client := clairpb.NewAncestryServiceClient(c.grpcConn) - resp, err := client.PostAncestry(context.Background(), &clairpb.PostAncestryRequest{ + resp, err := client.PostAncestry(ctx, &clairpb.PostAncestryRequest{ AncestryName: name, Layers: layers, Format: "Docker", diff --git a/clair/clair.go b/clair/clair.go index 9d3242a5..49dc15a3 100644 --- a/clair/clair.go +++ b/clair/clair.go @@ -1,6 +1,7 @@ package clair import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -93,8 +94,12 @@ func (c *Clair) url(pathTemplate string, args ...interface{}) string { return url } -func (c *Clair) getJSON(url string, response interface{}) (http.Header, error) { - resp, err := c.Client.Get(url) +func (c *Clair) getJSON(ctx context.Context, url string, response interface{}) (http.Header, error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + resp, err := c.Client.Do(req.WithContext(ctx)) if err != nil { return nil, err } diff --git a/clair/layer.go b/clair/layer.go index 9cfc65de..e1ccba0a 100644 --- a/clair/layer.go +++ b/clair/layer.go @@ -2,18 +2,19 @@ package clair import ( "bytes" + "context" "encoding/json" "fmt" "net/http" ) // GetLayer displays a Layer and optionally all of its features and vulnerabilities. -func (c *Clair) GetLayer(name string, features, vulnerabilities bool) (*Layer, error) { +func (c *Clair) GetLayer(ctx context.Context, name string, features, vulnerabilities bool) (*Layer, error) { url := c.url("/v1/layers/%s?features=%t&vulnerabilities=%t", name, features, vulnerabilities) c.Logf("clair.layers.get url=%s name=%s", url, name) var respLayer layerEnvelope - if _, err := c.getJSON(url, &respLayer); err != nil { + if _, err := c.getJSON(ctx, url, &respLayer); err != nil { return nil, err } @@ -25,7 +26,7 @@ func (c *Clair) GetLayer(name string, features, vulnerabilities bool) (*Layer, e } // PostLayer performs the analysis of a Layer from the provided path. -func (c *Clair) PostLayer(layer *Layer) (*Layer, error) { +func (c *Clair) PostLayer(ctx context.Context, layer *Layer) (*Layer, error) { url := c.url("/v1/layers") c.Logf("clair.layers.post url=%s name=%s", url, layer.Name) @@ -36,7 +37,14 @@ func (c *Clair) PostLayer(layer *Layer) (*Layer, error) { c.Logf("clair.layers.post req.Body=%s", string(b)) - resp, err := c.Client.Post(url, "application/json", bytes.NewReader(b)) + req, err := http.NewRequest("POST", url, bytes.NewReader(b)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := c.Client.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -57,7 +65,7 @@ func (c *Clair) PostLayer(layer *Layer) (*Layer, error) { } // DeleteLayer removes a layer reference from clair. -func (c *Clair) DeleteLayer(name string) error { +func (c *Clair) DeleteLayer(ctx context.Context, name string) error { url := c.url("/v1/layers/%s", name) c.Logf("clair.layers.delete url=%s name=%s", url, name) @@ -66,7 +74,7 @@ func (c *Clair) DeleteLayer(name string) error { return err } - resp, err := c.Client.Do(req) + resp, err := c.Client.Do(req.WithContext(ctx)) if err != nil { return err } diff --git a/clair/layerutil.go b/clair/layerutil.go index e31afa93..7b9923ae 100644 --- a/clair/layerutil.go +++ b/clair/layerutil.go @@ -1,6 +1,7 @@ package clair import ( + "context" "fmt" "strings" @@ -10,7 +11,7 @@ import ( ) // NewClairLayer will form a layer struct required for a clair scan. -func (c *Clair) NewClairLayer(r *registry.Registry, image string, fsLayers map[int]distribution.Descriptor, index int) (*Layer, error) { +func (c *Clair) NewClairLayer(ctx context.Context, r *registry.Registry, image string, fsLayers map[int]distribution.Descriptor, index int) (*Layer, error) { var parentName string if index < len(fsLayers)-1 { parentName = fsLayers[index+1].Digest.String() @@ -20,7 +21,7 @@ func (c *Clair) NewClairLayer(r *registry.Registry, image string, fsLayers map[i p := strings.Join([]string{r.URL, "v2", image, "blobs", fsLayers[index].Digest.String()}, "/") // Get the headers. - h, err := r.Headers(p) + h, err := r.Headers(ctx, p) if err != nil { return nil, err } @@ -35,12 +36,12 @@ func (c *Clair) NewClairLayer(r *registry.Registry, image string, fsLayers map[i } // NewClairV3Layer will form a layer struct required for a clair scan. -func (c *Clair) NewClairV3Layer(r *registry.Registry, image string, fsLayer distribution.Descriptor) (*clairpb.PostAncestryRequest_PostLayer, error) { +func (c *Clair) NewClairV3Layer(ctx context.Context, r *registry.Registry, image string, fsLayer distribution.Descriptor) (*clairpb.PostAncestryRequest_PostLayer, error) { // Form the path. p := strings.Join([]string{r.URL, "v2", image, "blobs", fsLayer.Digest.String()}, "/") // Get the headers. - h, err := r.Headers(p) + h, err := r.Headers(ctx, p) if err != nil { return nil, err } @@ -52,10 +53,10 @@ func (c *Clair) NewClairV3Layer(r *registry.Registry, image string, fsLayer dist }, nil } -func (c *Clair) getLayers(r *registry.Registry, repo, tag string, filterEmpty bool) (map[int]distribution.Descriptor, string, error) { +func (c *Clair) getLayers(ctx context.Context, r *registry.Registry, repo, tag string, filterEmpty bool) (map[int]distribution.Descriptor, string, error) { ok := true // Get the manifest to pass to clair. - mf, err := r.ManifestV2(repo, tag) + mf, err := r.ManifestV2(ctx, repo, tag) if err != nil { ok = false c.Logf("couldn't retrieve manifest v2, falling back to v1") @@ -76,7 +77,7 @@ func (c *Clair) getLayers(r *registry.Registry, repo, tag string, filterEmpty bo return filteredLayers, mf.Config.Digest.String(), nil } - m, err := r.ManifestV1(repo, tag) + m, err := r.ManifestV1(ctx, repo, tag) if err != nil { return nil, "", fmt.Errorf("getting the v1 manifest for %s:%s failed: %v", repo, tag, err) } diff --git a/clair/vulns.go b/clair/vulns.go index cc06035b..2ee4a057 100644 --- a/clair/vulns.go +++ b/clair/vulns.go @@ -1,6 +1,7 @@ package clair import ( + "context" "errors" "fmt" "time" @@ -10,7 +11,7 @@ import ( ) // Vulnerabilities scans the given repo and tag. -func (c *Clair) Vulnerabilities(r *registry.Registry, repo, tag string) (VulnerabilityReport, error) { +func (c *Clair) Vulnerabilities(ctx context.Context, r *registry.Registry, repo, tag string) (VulnerabilityReport, error) { report := VulnerabilityReport{ RegistryURL: r.Domain, Repo: repo, @@ -19,7 +20,7 @@ func (c *Clair) Vulnerabilities(r *registry.Registry, repo, tag string) (Vulnera VulnsBySeverity: make(map[string][]Vulnerability), } - filteredLayers, _, err := c.getLayers(r, repo, tag, true) + filteredLayers, _, err := c.getLayers(ctx, r, repo, tag, true) if err != nil { return report, fmt.Errorf("getting filtered layers failed: %v", err) } @@ -31,20 +32,20 @@ func (c *Clair) Vulnerabilities(r *registry.Registry, repo, tag string) (Vulnera for i := len(filteredLayers) - 1; i >= 0; i-- { // Form the clair layer. - l, err := c.NewClairLayer(r, repo, filteredLayers, i) + l, err := c.NewClairLayer(ctx, r, repo, filteredLayers, i) if err != nil { return report, err } // Post the layer. - if _, err := c.PostLayer(l); err != nil { + if _, err := c.PostLayer(ctx, l); err != nil { return report, err } } report.Name = filteredLayers[0].Digest.String() - vl, err := c.GetLayer(filteredLayers[0].Digest.String(), true, true) + vl, err := c.GetLayer(ctx, filteredLayers[0].Digest.String(), true, true) if err != nil { return report, err } @@ -76,7 +77,7 @@ func (c *Clair) Vulnerabilities(r *registry.Registry, repo, tag string) (Vulnera } // VulnerabilitiesV3 scans the given repo and tag using the clair v3 API. -func (c *Clair) VulnerabilitiesV3(r *registry.Registry, repo, tag string) (VulnerabilityReport, error) { +func (c *Clair) VulnerabilitiesV3(ctx context.Context, r *registry.Registry, repo, tag string) (VulnerabilityReport, error) { report := VulnerabilityReport{ RegistryURL: r.Domain, Repo: repo, @@ -85,7 +86,7 @@ func (c *Clair) VulnerabilitiesV3(r *registry.Registry, repo, tag string) (Vulne VulnsBySeverity: make(map[string][]Vulnerability), } - layers, reportName, err := c.getLayers(r, repo, tag, false) + layers, reportName, err := c.getLayers(ctx, r, repo, tag, false) if err != nil { return report, fmt.Errorf("getting filtered layers failed: %v", err) } @@ -100,7 +101,7 @@ func (c *Clair) VulnerabilitiesV3(r *registry.Registry, repo, tag string) (Vulne clairLayers := []*clairpb.PostAncestryRequest_PostLayer{} for i := len(layers) - 1; i >= 0; i-- { // Form the clair layer. - l, err := c.NewClairV3Layer(r, repo, layers[i]) + l, err := c.NewClairV3Layer(ctx, r, repo, layers[i]) if err != nil { return report, err } @@ -110,12 +111,12 @@ func (c *Clair) VulnerabilitiesV3(r *registry.Registry, repo, tag string) (Vulne } // Post the ancestry. - if err := c.PostAncestry(reportName, clairLayers); err != nil { + if err := c.PostAncestry(ctx, reportName, clairLayers); err != nil { return report, fmt.Errorf("posting ancestry failed: %v", err) } // Get the ancestry. - vl, err := c.GetAncestry(reportName) + vl, err := c.GetAncestry(ctx, reportName) if err != nil { return report, err } diff --git a/digest.go b/digest.go index 13080ef0..d86a8fff 100644 --- a/digest.go +++ b/digest.go @@ -31,13 +31,13 @@ func (cmd *digestCommand) Run(ctx context.Context, args []string) error { } // Create the registry client. - r, err := createRegistryClient(image.Domain) + r, err := createRegistryClient(ctx, image.Domain) if err != nil { return err } // Get the digest. - digest, err := r.Digest(image) + digest, err := r.Digest(ctx, image) if err != nil { return err } diff --git a/handlers.go b/handlers.go index 221d30be..36204168 100644 --- a/handlers.go +++ b/handlers.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "encoding/json" "fmt" "html/template" @@ -53,7 +54,7 @@ type AnalysisResult struct { UpdateInterval time.Duration } -func (rc *registryController) repositories(staticDir string) error { +func (rc *registryController) repositories(ctx context.Context, staticDir string) error { rc.l.Lock() defer rc.l.Unlock() @@ -65,7 +66,7 @@ func (rc *registryController) repositories(staticDir string) error { UpdateInterval: rc.interval, } - repoList, err := rc.reg.Catalog("") + repoList, err := rc.reg.Catalog(ctx, "") if err != nil { return fmt.Errorf("getting catalog for %s failed: %v", rc.reg.Domain, err) } @@ -94,7 +95,7 @@ func (rc *registryController) repositories(staticDir string) error { // Parse and execute the tags templates. // If we are generating the tags files, disable vulnerability links in the // templates since they won't go anywhere without a server side component. - b, err := rc.generateTagsTemplate(repo, false) + b, err := rc.generateTagsTemplate(ctx, repo, false) if err != nil { logrus.Warnf("generating tags template for repo %q failed: %v", repo, err) } @@ -156,7 +157,7 @@ func (rc *registryController) tagsHandler(w http.ResponseWriter, r *http.Request } // Generate the tags template. - b, err := rc.generateTagsTemplate(repo, rc.cl != nil) + b, err := rc.generateTagsTemplate(context.TODO(), repo, rc.cl != nil) if err != nil { logrus.WithFields(logrus.Fields{ "func": "tags", @@ -173,9 +174,9 @@ func (rc *registryController) tagsHandler(w http.ResponseWriter, r *http.Request fmt.Fprint(w, string(b)) } -func (rc *registryController) generateTagsTemplate(repo string, hasVulns bool) ([]byte, error) { +func (rc *registryController) generateTagsTemplate(ctx context.Context, repo string, hasVulns bool) ([]byte, error) { // Get the tags from the server. - tags, err := rc.reg.Tags(repo) + tags, err := rc.reg.Tags(ctx, repo) if err != nil { return nil, fmt.Errorf("getting tags for %s failed: %v", repo, err) } @@ -196,7 +197,7 @@ func (rc *registryController) generateTagsTemplate(repo string, hasVulns bool) ( for _, tag := range tags { // get the manifest - m1, err := rc.reg.ManifestV1(repo, tag) + m1, err := rc.reg.ManifestV1(ctx, repo, tag) if err != nil { return nil, fmt.Errorf("getting v1 manifest for %s:%s failed: %v", repo, tag, err) } @@ -272,10 +273,10 @@ func (rc *registryController) vulnerabilitiesHandler(w http.ResponseWriter, r *h } // Get the vulnerability report. - result, err := rc.cl.VulnerabilitiesV3(rc.reg, image.Path, image.Reference()) + result, err := rc.cl.VulnerabilitiesV3(context.TODO(), rc.reg, image.Path, image.Reference()) if err != nil { // Fallback to Clair v2 API. - result, err = rc.cl.Vulnerabilities(rc.reg, image.Path, image.Reference()) + result, err = rc.cl.Vulnerabilities(context.TODO(), rc.reg, image.Path, image.Reference()) if err != nil { logrus.WithFields(logrus.Fields{ "func": "vulnerabilities", diff --git a/layer.go b/layer.go index f12864fc..b897ae5f 100644 --- a/layer.go +++ b/layer.go @@ -38,19 +38,19 @@ func (cmd *layerCommand) Run(ctx context.Context, args []string) error { } // Create the registry client. - r, err := createRegistryClient(image.Domain) + r, err := createRegistryClient(ctx, image.Domain) if err != nil { return err } // Get the digest. - digest, err := r.Digest(image) + digest, err := r.Digest(ctx, image) if err != nil { return err } // Download the layer. - layer, err := r.DownloadLayer(image.Path, digest) + layer, err := r.DownloadLayer(ctx, image.Path, digest) if err != nil { return err } diff --git a/list.go b/list.go index b829bc55..404c87e8 100644 --- a/list.go +++ b/list.go @@ -30,12 +30,12 @@ func (cmd *listCommand) Run(ctx context.Context, args []string) error { } // Create the registry client. - r, err := createRegistryClient(args[0]) + r, err := createRegistryClient(ctx, args[0]) if err != nil { return err } // Get the repositories via catalog. - repos, err := r.Catalog("") + repos, err := r.Catalog(ctx, "") if err != nil { if _, ok := err.(*json.SyntaxError); ok { return fmt.Errorf("Domain %s is not a valid registry", r.Domain) @@ -56,7 +56,7 @@ func (cmd *listCommand) Run(ctx context.Context, args []string) error { for _, repo := range repos { go func(repo string) { // Get the tags. - tags, err := r.Tags(repo) + tags, err := r.Tags(ctx, repo) if err != nil { fmt.Printf("Get tags of [%s] error: %s", repo, err) } diff --git a/main.go b/main.go index ca295acd..12968c92 100644 --- a/main.go +++ b/main.go @@ -102,7 +102,7 @@ func main() { p.Run() } -func createRegistryClient(domain string) (*registry.Registry, error) { +func createRegistryClient(ctx context.Context, domain string) (*registry.Registry, error) { // Use the auth-url domain if provided. authDomain := authURL if authDomain == "" { @@ -119,7 +119,7 @@ func createRegistryClient(domain string) (*registry.Registry, error) { } // Create the registry client. - return registry.New(auth, registry.Opt{ + return registry.New(ctx, auth, registry.Opt{ Domain: domain, Insecure: insecure, Debug: debug, diff --git a/manifest.go b/manifest.go index 4da8bb7c..8a5b830a 100644 --- a/manifest.go +++ b/manifest.go @@ -36,7 +36,7 @@ func (cmd *manifestCommand) Run(ctx context.Context, args []string) error { } // Create the registry client. - r, err := createRegistryClient(image.Domain) + r, err := createRegistryClient(ctx, image.Domain) if err != nil { return err } @@ -44,13 +44,13 @@ func (cmd *manifestCommand) Run(ctx context.Context, args []string) error { var manifest interface{} if cmd.v1 { // Get the v1 manifest if it was explicitly asked for. - manifest, err = r.ManifestV1(image.Path, image.Reference()) + manifest, err = r.ManifestV1(ctx, image.Path, image.Reference()) if err != nil { return err } } else { // Get the v2 manifest. - manifest, err = r.Manifest(image.Path, image.Reference()) + manifest, err = r.Manifest(ctx, image.Path, image.Reference()) if err != nil { return err } diff --git a/registry/catalog.go b/registry/catalog.go index 4c4494c4..b9e25448 100644 --- a/registry/catalog.go +++ b/registry/catalog.go @@ -1,6 +1,7 @@ package registry import ( + "context" "net/url" "github.com/peterhellberg/link" @@ -11,7 +12,7 @@ type catalogResponse struct { } // Catalog returns the repositories in a registry. -func (r *Registry) Catalog(u string) ([]string, error) { +func (r *Registry) Catalog(ctx context.Context, u string) ([]string, error) { if u == "" { u = "/v2/_catalog" } @@ -19,7 +20,7 @@ func (r *Registry) Catalog(u string) ([]string, error) { r.Logf("registry.catalog url=%s", uri) var response catalogResponse - h, err := r.getJSON(uri, &response) + h, err := r.getJSON(ctx, uri, &response) if err != nil { return nil, err } @@ -27,7 +28,7 @@ func (r *Registry) Catalog(u string) ([]string, error) { for _, l := range link.ParseHeader(h) { if l.Rel == "next" { unescaped, _ := url.QueryUnescape(l.URI) - repos, err := r.Catalog(unescaped) + repos, err := r.Catalog(ctx, unescaped) if err != nil { return nil, err } diff --git a/registry/delete.go b/registry/delete.go index a3d06814..6bd6a9bd 100644 --- a/registry/delete.go +++ b/registry/delete.go @@ -1,6 +1,7 @@ package registry import ( + "context" "fmt" "net/http" @@ -10,7 +11,7 @@ import ( // Delete removes a repository digest from the registry. // https://docs.docker.com/registry/spec/api/#deleting-an-image -func (r *Registry) Delete(repository string, digest digest.Digest) (err error) { +func (r *Registry) Delete(ctx context.Context, repository string, digest digest.Digest) (err error) { url := r.url("/v2/%s/manifests/%s", repository, digest) r.Logf("registry.manifests.delete url=%s repository=%s digest=%s", url, repository, digest) @@ -21,7 +22,7 @@ func (r *Registry) Delete(repository string, digest digest.Digest) (err error) { } req.Header.Add("Accept", fmt.Sprintf("%s;q=0.9", schema2.MediaTypeManifest)) - resp, err := r.Client.Do(req) + resp, err := r.Client.Do(req.WithContext(ctx)) if err != nil { return err } diff --git a/registry/digest.go b/registry/digest.go index 80d94b65..a1346f31 100644 --- a/registry/digest.go +++ b/registry/digest.go @@ -1,6 +1,7 @@ package registry import ( + "context" "fmt" "net/http" @@ -9,7 +10,7 @@ import ( ) // Digest returns the digest for an image. -func (r *Registry) Digest(image Image) (digest.Digest, error) { +func (r *Registry) Digest(ctx context.Context, image Image) (digest.Digest, error) { if len(image.Digest) > 1 { // return early if we already have an image digest. return image.Digest, nil @@ -25,7 +26,7 @@ func (r *Registry) Digest(image Image) (digest.Digest, error) { } req.Header.Add("Accept", schema2.MediaTypeManifest) - resp, err := r.Client.Do(req) + resp, err := r.Client.Do(req.WithContext(ctx)) if err != nil { return "", err } diff --git a/registry/digest_test.go b/registry/digest_test.go index af3fea75..c6bdccf6 100644 --- a/registry/digest_test.go +++ b/registry/digest_test.go @@ -1,23 +1,25 @@ package registry import ( + "context" "testing" "github.com/genuinetools/reg/repoutils" ) func TestDigestFromDockerHub(t *testing.T) { + ctx := context.Background() auth, err := repoutils.GetAuthConfig("", "", "docker.io") if err != nil { t.Fatalf("Could not get auth config: %s", err) } - r, err := New(auth, Opt{}) + r, err := New(ctx, auth, Opt{}) if err != nil { t.Fatalf("Could not create registry instance: %s", err) } - d, err := r.Digest(Image{Domain: "docker.io", Path: "library/alpine", Tag: "latest"}) + d, err := r.Digest(ctx, Image{Domain: "docker.io", Path: "library/alpine", Tag: "latest"}) if err != nil { t.Fatalf("Could not get digest: %s", err) } @@ -28,17 +30,18 @@ func TestDigestFromDockerHub(t *testing.T) { } func TestDigestFromGCR(t *testing.T) { + ctx := context.Background() auth, err := repoutils.GetAuthConfig("", "", "gcr.io") if err != nil { t.Fatalf("Could not get auth config: %s", err) } - r, err := New(auth, Opt{}) + r, err := New(ctx, auth, Opt{}) if err != nil { t.Fatalf("Could not create registry instance: %s", err) } - d, err := r.Digest(Image{Domain: "gcr.io", Path: "google_containers/hyperkube", Tag: "v1.9.9"}) + d, err := r.Digest(ctx, Image{Domain: "gcr.io", Path: "google_containers/hyperkube", Tag: "v1.9.9"}) if err != nil { t.Fatalf("Could not get digest: %s", err) } diff --git a/registry/layer.go b/registry/layer.go index e382337d..a932dde6 100644 --- a/registry/layer.go +++ b/registry/layer.go @@ -1,21 +1,27 @@ package registry import ( + "context" "io" "net/http" "net/url" "fmt" + "github.com/docker/distribution/reference" "github.com/opencontainers/go-digest" ) // DownloadLayer downloads a specific layer by digest for a repository. -func (r *Registry) DownloadLayer(repository string, digest digest.Digest) (io.ReadCloser, error) { +func (r *Registry) DownloadLayer(ctx context.Context, repository string, digest digest.Digest) (io.ReadCloser, error) { url := r.url("/v2/%s/blobs/%s", repository, digest) r.Logf("registry.layer.download url=%s repository=%s digest=%s", url, repository, digest) - resp, err := r.Client.Get(url) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + resp, err := r.Client.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -24,8 +30,8 @@ func (r *Registry) DownloadLayer(repository string, digest digest.Digest) (io.Re } // UploadLayer uploads a specific layer by digest for a repository. -func (r *Registry) UploadLayer(repository string, digest reference.Reference, content io.Reader) error { - uploadURL, token, err := r.initiateUpload(repository) +func (r *Registry) UploadLayer(ctx context.Context, repository string, digest reference.Reference, content io.Reader) error { + uploadURL, token, err := r.initiateUpload(ctx, repository) if err != nil { return err } @@ -42,16 +48,20 @@ func (r *Registry) UploadLayer(repository string, digest reference.Reference, co upload.Header.Set("Content-Type", "application/octet-stream") upload.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - _, err = r.Client.Do(upload) + _, err = r.Client.Do(upload.WithContext(ctx)) return err } // HasLayer returns if the registry contains the specific digest for a repository. -func (r *Registry) HasLayer(repository string, digest digest.Digest) (bool, error) { +func (r *Registry) HasLayer(ctx context.Context, repository string, digest digest.Digest) (bool, error) { checkURL := r.url("/v2/%s/blobs/%s", repository, digest) r.Logf("registry.layer.check url=%s repository=%s digest=%s", checkURL, repository, digest) - resp, err := r.Client.Head(checkURL) + req, err := http.NewRequest("HEAD", checkURL, nil) + if err != nil { + return false, err + } + resp, err := r.Client.Do(req.WithContext(ctx)) if err == nil { defer resp.Body.Close() return resp.StatusCode == http.StatusOK, nil @@ -72,11 +82,16 @@ func (r *Registry) HasLayer(repository string, digest digest.Digest) (bool, erro return false, err } -func (r *Registry) initiateUpload(repository string) (*url.URL, string, error) { +func (r *Registry) initiateUpload(ctx context.Context, repository string) (*url.URL, string, error) { initiateURL := r.url("/v2/%s/blobs/uploads/", repository) r.Logf("registry.layer.initiate-upload url=%s repository=%s", initiateURL, repository) - resp, err := r.Client.Post(initiateURL, "application/octet-stream", nil) + req, err := http.NewRequest("POST", initiateURL, nil) + if err != nil { + return nil, "", err + } + req.Header.Set("Content-Type", "application/octet-stream") + resp, err := r.Client.Do(req.WithContext(ctx)) if err != nil { return nil, "", err } diff --git a/registry/manifest.go b/registry/manifest.go index 138f6ab7..10e31417 100644 --- a/registry/manifest.go +++ b/registry/manifest.go @@ -2,6 +2,7 @@ package registry import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -20,7 +21,7 @@ var ( ) // Manifest returns the manifest for a specific repository:tag. -func (r *Registry) Manifest(repository, ref string) (distribution.Manifest, error) { +func (r *Registry) Manifest(ctx context.Context, repository, ref string) (distribution.Manifest, error) { uri := r.url("/v2/%s/manifests/%s", repository, ref) r.Logf("registry.manifests uri=%s repository=%s ref=%s", uri, repository, ref) @@ -31,7 +32,7 @@ func (r *Registry) Manifest(repository, ref string) (distribution.Manifest, erro req.Header.Add("Accept", fmt.Sprintf("%s;q=0.9", schema2.MediaTypeManifest)) - resp, err := r.Client.Do(req) + resp, err := r.Client.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -52,12 +53,12 @@ func (r *Registry) Manifest(repository, ref string) (distribution.Manifest, erro } // ManifestList gets the registry v2 manifest list. -func (r *Registry) ManifestList(repository, ref string) (manifestlist.ManifestList, error) { +func (r *Registry) ManifestList(ctx context.Context, repository, ref string) (manifestlist.ManifestList, error) { uri := r.url("/v2/%s/manifests/%s", repository, ref) r.Logf("registry.manifests uri=%s repository=%s ref=%s", uri, repository, ref) var m manifestlist.ManifestList - if _, err := r.getJSON(uri, &m); err != nil { + if _, err := r.getJSON(ctx, uri, &m); err != nil { r.Logf("registry.manifests response=%v", m) return m, err } @@ -66,12 +67,12 @@ func (r *Registry) ManifestList(repository, ref string) (manifestlist.ManifestLi } // ManifestV2 gets the registry v2 manifest. -func (r *Registry) ManifestV2(repository, ref string) (schema2.Manifest, error) { +func (r *Registry) ManifestV2(ctx context.Context, repository, ref string) (schema2.Manifest, error) { uri := r.url("/v2/%s/manifests/%s", repository, ref) r.Logf("registry.manifests uri=%s repository=%s ref=%s", uri, repository, ref) var m schema2.Manifest - if _, err := r.getJSON(uri, &m); err != nil { + if _, err := r.getJSON(ctx, uri, &m); err != nil { r.Logf("registry.manifests response=%v", m) return m, err } @@ -84,12 +85,12 @@ func (r *Registry) ManifestV2(repository, ref string) (schema2.Manifest, error) } // ManifestV1 gets the registry v1 manifest. -func (r *Registry) ManifestV1(repository, ref string) (schema1.SignedManifest, error) { +func (r *Registry) ManifestV1(ctx context.Context, repository, ref string) (schema1.SignedManifest, error) { uri := r.url("/v2/%s/manifests/%s", repository, ref) r.Logf("registry.manifests uri=%s repository=%s ref=%s", uri, repository, ref) var m schema1.SignedManifest - if _, err := r.getJSON(uri, &m); err != nil { + if _, err := r.getJSON(ctx, uri, &m); err != nil { r.Logf("registry.manifests response=%v", m) return m, err } @@ -102,7 +103,7 @@ func (r *Registry) ManifestV1(repository, ref string) (schema1.SignedManifest, e } // PutManifest calls a PUT for the specific manifest for an image. -func (r *Registry) PutManifest(repository, ref string, manifest distribution.Manifest) error { +func (r *Registry) PutManifest(ctx context.Context, repository, ref string, manifest distribution.Manifest) error { url := r.url("/v2/%s/manifests/%s", repository, ref) r.Logf("registry.manifest.put url=%s repository=%s reference=%s", url, repository, ref) @@ -117,7 +118,7 @@ func (r *Registry) PutManifest(repository, ref string, manifest distribution.Man } req.Header.Set("Content-Type", schema2.MediaTypeManifest) - resp, err := r.Client.Do(req) + resp, err := r.Client.Do(req.WithContext(ctx)) if resp != nil { defer resp.Body.Close() } diff --git a/registry/ping.go b/registry/ping.go index 2e984ace..d2628bbf 100644 --- a/registry/ping.go +++ b/registry/ping.go @@ -1,10 +1,19 @@ package registry +import ( + "context" + "net/http" +) + // Ping tries to contact a registry URL to make sure it is up and accessible. -func (r *Registry) Ping() error { +func (r *Registry) Ping(ctx context.Context) error { url := r.url("/v2/") r.Logf("registry.ping url=%s", url) - resp, err := r.Client.Get(url) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return err + } + resp, err := r.Client.Do(req.WithContext(ctx)) if resp != nil { defer resp.Body.Close() } diff --git a/registry/registry.go b/registry/registry.go index ba63727e..3e49fb9a 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -1,6 +1,7 @@ package registry import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -51,7 +52,7 @@ type Opt struct { } // New creates a new Registry struct with the given URL and credentials. -func New(auth types.AuthConfig, opt Opt) (*Registry, error) { +func New(ctx context.Context, auth types.AuthConfig, opt Opt) (*Registry, error) { transport := http.DefaultTransport if opt.Insecure { @@ -62,10 +63,10 @@ func New(auth types.AuthConfig, opt Opt) (*Registry, error) { } } - return newFromTransport(auth, transport, opt) + return newFromTransport(ctx, auth, transport, opt) } -func newFromTransport(auth types.AuthConfig, transport http.RoundTripper, opt Opt) (*Registry, error) { +func newFromTransport(ctx context.Context, auth types.AuthConfig, transport http.RoundTripper, opt Opt) (*Registry, error) { if len(opt.Domain) < 1 { opt.Domain = auth.ServerAddress } @@ -127,7 +128,7 @@ func newFromTransport(auth types.AuthConfig, transport http.RoundTripper, opt Op } if !opt.SkipPing { - if err := registry.Ping(); err != nil { + if err := registry.Ping(ctx); err != nil { return nil, err } } @@ -142,7 +143,7 @@ func (r *Registry) url(pathTemplate string, args ...interface{}) string { return url } -func (r *Registry) getJSON(url string, response interface{}) (http.Header, error) { +func (r *Registry) getJSON(ctx context.Context, url string, response interface{}) (http.Header, error) { req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err @@ -155,7 +156,7 @@ func (r *Registry) getJSON(url string, response interface{}) (http.Header, error req.Header.Add("Accept", fmt.Sprintf("%s;q=0.9", manifestlist.MediaTypeManifestList)) } - resp, err := r.Client.Do(req) + resp, err := r.Client.Do(req.WithContext(ctx)) if err != nil { return nil, err } diff --git a/registry/tags.go b/registry/tags.go index 8c320134..8f2493eb 100644 --- a/registry/tags.go +++ b/registry/tags.go @@ -1,16 +1,18 @@ package registry +import "context" + type tagsResponse struct { Tags []string `json:"tags"` } // Tags returns the tags for a specific repository. -func (r *Registry) Tags(repository string) ([]string, error) { +func (r *Registry) Tags(ctx context.Context, repository string) ([]string, error) { url := r.url("/v2/%s/tags/list", repository) r.Logf("registry.tags url=%s repository=%s", url, repository) var response tagsResponse - if _, err := r.getJSON(url, &response); err != nil { + if _, err := r.getJSON(ctx, url, &response); err != nil { return nil, err } diff --git a/registry/tokentransport.go b/registry/tokentransport.go index 3b9b901d..0dea8d92 100644 --- a/registry/tokentransport.go +++ b/registry/tokentransport.go @@ -1,6 +1,7 @@ package registry import ( + "context" "crypto/tls" "encoding/base64" "encoding/json" @@ -58,7 +59,7 @@ func (t authToken) String() (string, error) { } func (t *TokenTransport) authAndRetry(authService *authService, req *http.Request) (*http.Response, error) { - token, authResp, err := t.auth(authService) + token, authResp, err := t.auth(req.Context(), authService) if err != nil { return authResp, err } @@ -70,7 +71,7 @@ func (t *TokenTransport) authAndRetry(authService *authService, req *http.Reques return response, err } -func (t *TokenTransport) auth(authService *authService) (string, *http.Response, error) { +func (t *TokenTransport) auth(ctx context.Context, authService *authService) (string, *http.Response, error) { authReq, err := authService.Request(t.Username, t.Password) if err != nil { return "", nil, err @@ -80,7 +81,7 @@ func (t *TokenTransport) auth(authService *authService) (string, *http.Response, Transport: t.Transport, } - resp, err := c.Do(authReq) + resp, err := c.Do(authReq.WithContext(ctx)) if err != nil { return "", nil, err } @@ -140,7 +141,7 @@ func isTokenDemand(resp *http.Response) (*authService, error) { // Token returns the required token for the specific resource url. If the registry requires basic authentication, this // function returns ErrBasicAuth. -func (r *Registry) Token(url string) (string, error) { +func (r *Registry) Token(ctx context.Context, url string) (string, error) { r.Logf("registry.token url=%s", url) req, err := http.NewRequest("GET", url, nil) @@ -160,7 +161,7 @@ func (r *Registry) Token(url string) (string, error) { } } - resp, err := client.Do(req) + resp, err := client.Do(req.WithContext(ctx)) if err != nil { return "", err } @@ -187,7 +188,7 @@ func (r *Registry) Token(url string) (string, error) { if err != nil { return "", err } - resp, err = http.DefaultClient.Do(authReq) + resp, err = http.DefaultClient.Do(authReq.WithContext(ctx)) if err != nil { return "", err } @@ -206,9 +207,9 @@ func (r *Registry) Token(url string) (string, error) { } // Headers returns the authorization headers for a specific uri. -func (r *Registry) Headers(uri string) (map[string]string, error) { +func (r *Registry) Headers(ctx context.Context, uri string) (map[string]string, error) { // Get the token. - token, err := r.Token(uri) + token, err := r.Token(ctx, uri) if err != nil { if err == ErrBasicAuth { // If we couldn't get a token because the server requires basic auth, just return basic auth headers. diff --git a/registry/tokentransport_test.go b/registry/tokentransport_test.go index 9e138090..a824a2dd 100644 --- a/registry/tokentransport_test.go +++ b/registry/tokentransport_test.go @@ -1,6 +1,7 @@ package registry import ( + "context" "errors" "net/http" "net/http/httptest" @@ -11,6 +12,7 @@ import ( ) func TestErrBasicAuth(t *testing.T) { + ctx := context.Background() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/" { w.Header().Set("www-authenticate", `Basic realm="Registry Realm",service="Docker registry"`) @@ -26,11 +28,11 @@ func TestErrBasicAuth(t *testing.T) { Password: "ss3j", ServerAddress: ts.URL, } - r, err := New(authConfig, Opt{Insecure: true, Debug: true}) + r, err := New(ctx, authConfig, Opt{Insecure: true, Debug: true}) if err != nil { t.Fatalf("expected no error creating client, got %v", err) } - token, err := r.Token(ts.URL) + token, err := r.Token(ctx, ts.URL) if err != ErrBasicAuth { t.Fatalf("expected ErrBasicAuth getting token, got %v", err) } @@ -72,6 +74,7 @@ func TestBothTokenAndAccessTokenWork(t *testing.T) { defer ts.Close() for _, which := range []string{"token", "accesstoken"} { + ctx := context.Background() authURI = ts.URL + "/oauth2/" + which + "?service=my.endpoint.here" authConfig := types.AuthConfig{ Username: "abc", @@ -79,11 +82,11 @@ func TestBothTokenAndAccessTokenWork(t *testing.T) { ServerAddress: ts.URL, } authConfig.Email = "me@email.com" - r, err := New(authConfig, Opt{Insecure: true, Debug: true}) + r, err := New(ctx, authConfig, Opt{Insecure: true, Debug: true}) if err != nil { t.Fatalf("expected no error creating client, got %v", err) } - token, err := r.Token(ts.URL) + token, err := r.Token(ctx, ts.URL) if err != nil { t.Fatalf("err getting token from url: %v err: %v", ts.URL, err) } diff --git a/remove.go b/remove.go index 4b6deca8..572a4da3 100644 --- a/remove.go +++ b/remove.go @@ -31,13 +31,13 @@ func (cmd *removeCommand) Run(ctx context.Context, args []string) error { } // Create the registry client. - r, err := createRegistryClient(image.Domain) + r, err := createRegistryClient(ctx, image.Domain) if err != nil { return err } // Get the digest. - digest, err := r.Digest(image) + digest, err := r.Digest(ctx, image) if err != nil { return err } @@ -47,7 +47,7 @@ func (cmd *removeCommand) Run(ctx context.Context, args []string) error { } // Delete the reference. - if err := r.Delete(image.Path, digest); err != nil { + if err := r.Delete(ctx, image.Path, digest); err != nil { return err } fmt.Printf("Deleted %s\n", image.String()) diff --git a/server.go b/server.go index 8bc21a59..fcd4cd71 100644 --- a/server.go +++ b/server.go @@ -61,7 +61,7 @@ type serverCommand struct { func (cmd *serverCommand) Run(ctx context.Context, args []string) error { // Create the registry client. - r, err := createRegistryClient(cmd.registryServer) + r, err := createRegistryClient(ctx, cmd.registryServer) if err != nil { return err } @@ -127,7 +127,7 @@ func (cmd *serverCommand) Run(ctx context.Context, args []string) error { // Create the initial index. logrus.Info("creating initial static index") - if err := rc.repositories(staticDir); err != nil { + if err := rc.repositories(ctx, staticDir); err != nil { return fmt.Errorf("creating index failed: %v", err) } @@ -142,7 +142,7 @@ func (cmd *serverCommand) Run(ctx context.Context, args []string) error { // Create more indexes every X minutes based off interval. for range ticker.C { logrus.Info("creating timer based static index") - if err := rc.repositories(staticDir); err != nil { + if err := rc.repositories(ctx, staticDir); err != nil { logrus.Warnf("creating static index failed: %v", err) } } diff --git a/tags.go b/tags.go index fd0fb96b..78201297 100644 --- a/tags.go +++ b/tags.go @@ -33,12 +33,12 @@ func (cmd *tagsCommand) Run(ctx context.Context, args []string) error { } // Create the registry client. - r, err := createRegistryClient(image.Domain) + r, err := createRegistryClient(ctx, image.Domain) if err != nil { return err } - tags, err := r.Tags(image.Path) + tags, err := r.Tags(ctx, image.Path) if err != nil { return err } diff --git a/vulns.go b/vulns.go index b9e89e4b..e01594b3 100644 --- a/vulns.go +++ b/vulns.go @@ -49,7 +49,7 @@ func (cmd *vulnsCommand) Run(ctx context.Context, args []string) error { } // Create the registry client. - r, err := createRegistryClient(image.Domain) + r, err := createRegistryClient(ctx, image.Domain) if err != nil { return err } @@ -65,10 +65,10 @@ func (cmd *vulnsCommand) Run(ctx context.Context, args []string) error { } // Get the vulnerability report. - report, err := cr.VulnerabilitiesV3(r, image.Path, image.Reference()) + report, err := cr.VulnerabilitiesV3(ctx, r, image.Path, image.Reference()) if err != nil { // Fallback to Clair v2 API. - report, err = cr.Vulnerabilities(r, image.Path, image.Reference()) + report, err = cr.Vulnerabilities(ctx, r, image.Path, image.Reference()) if err != nil { return err }