Passing context (#163)

* passing context in layer calls

* more contexting

* clair folder and context in handlers

* fixed token transport to reuse request context

* tests

* taking out context pass in server handlers
This commit is contained in:
Jessica Tracy 2018-12-29 12:09:10 -05:00 committed by Jess Frazelle
parent 5635f17ffc
commit 32589e90be
26 changed files with 168 additions and 114 deletions

View file

@ -13,7 +13,7 @@ var (
)
// GetAncestry displays an ancestry and all of its features and vulnerabilities.
func (c *Clair) GetAncestry(name string) (*clairpb.GetAncestryResponse_Ancestry, error) {
func (c *Clair) GetAncestry(ctx context.Context, name string) (*clairpb.GetAncestryResponse_Ancestry, error) {
c.Logf("clair.ancestry.get name=%s", name)
if c.grpcConn == nil {
@ -22,7 +22,7 @@ func (c *Clair) GetAncestry(name string) (*clairpb.GetAncestryResponse_Ancestry,
client := clairpb.NewAncestryServiceClient(c.grpcConn)
resp, err := client.GetAncestry(context.Background(), &clairpb.GetAncestryRequest{
resp, err := client.GetAncestry(ctx, &clairpb.GetAncestryRequest{
AncestryName: name,
})
if err != nil {
@ -41,7 +41,7 @@ func (c *Clair) GetAncestry(name string) (*clairpb.GetAncestryResponse_Ancestry,
}
// PostAncestry performs the analysis of all layers from the provided path.
func (c *Clair) PostAncestry(name string, layers []*clairpb.PostAncestryRequest_PostLayer) error {
func (c *Clair) PostAncestry(ctx context.Context, name string, layers []*clairpb.PostAncestryRequest_PostLayer) error {
c.Logf("clair.ancestry.post name=%s", name)
if c.grpcConn == nil {
@ -50,7 +50,7 @@ func (c *Clair) PostAncestry(name string, layers []*clairpb.PostAncestryRequest_
client := clairpb.NewAncestryServiceClient(c.grpcConn)
resp, err := client.PostAncestry(context.Background(), &clairpb.PostAncestryRequest{
resp, err := client.PostAncestry(ctx, &clairpb.PostAncestryRequest{
AncestryName: name,
Layers: layers,
Format: "Docker",

View file

@ -1,6 +1,7 @@
package clair
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
@ -93,8 +94,12 @@ func (c *Clair) url(pathTemplate string, args ...interface{}) string {
return url
}
func (c *Clair) getJSON(url string, response interface{}) (http.Header, error) {
resp, err := c.Client.Get(url)
func (c *Clair) getJSON(ctx context.Context, url string, response interface{}) (http.Header, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
resp, err := c.Client.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}

View file

@ -2,18 +2,19 @@ package clair
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
)
// GetLayer displays a Layer and optionally all of its features and vulnerabilities.
func (c *Clair) GetLayer(name string, features, vulnerabilities bool) (*Layer, error) {
func (c *Clair) GetLayer(ctx context.Context, name string, features, vulnerabilities bool) (*Layer, error) {
url := c.url("/v1/layers/%s?features=%t&vulnerabilities=%t", name, features, vulnerabilities)
c.Logf("clair.layers.get url=%s name=%s", url, name)
var respLayer layerEnvelope
if _, err := c.getJSON(url, &respLayer); err != nil {
if _, err := c.getJSON(ctx, url, &respLayer); err != nil {
return nil, err
}
@ -25,7 +26,7 @@ func (c *Clair) GetLayer(name string, features, vulnerabilities bool) (*Layer, e
}
// PostLayer performs the analysis of a Layer from the provided path.
func (c *Clair) PostLayer(layer *Layer) (*Layer, error) {
func (c *Clair) PostLayer(ctx context.Context, layer *Layer) (*Layer, error) {
url := c.url("/v1/layers")
c.Logf("clair.layers.post url=%s name=%s", url, layer.Name)
@ -36,7 +37,14 @@ func (c *Clair) PostLayer(layer *Layer) (*Layer, error) {
c.Logf("clair.layers.post req.Body=%s", string(b))
resp, err := c.Client.Post(url, "application/json", bytes.NewReader(b))
req, err := http.NewRequest("POST", url, bytes.NewReader(b))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.Client.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -57,7 +65,7 @@ func (c *Clair) PostLayer(layer *Layer) (*Layer, error) {
}
// DeleteLayer removes a layer reference from clair.
func (c *Clair) DeleteLayer(name string) error {
func (c *Clair) DeleteLayer(ctx context.Context, name string) error {
url := c.url("/v1/layers/%s", name)
c.Logf("clair.layers.delete url=%s name=%s", url, name)
@ -66,7 +74,7 @@ func (c *Clair) DeleteLayer(name string) error {
return err
}
resp, err := c.Client.Do(req)
resp, err := c.Client.Do(req.WithContext(ctx))
if err != nil {
return err
}

View file

@ -1,6 +1,7 @@
package clair
import (
"context"
"fmt"
"strings"
@ -10,7 +11,7 @@ import (
)
// NewClairLayer will form a layer struct required for a clair scan.
func (c *Clair) NewClairLayer(r *registry.Registry, image string, fsLayers map[int]distribution.Descriptor, index int) (*Layer, error) {
func (c *Clair) NewClairLayer(ctx context.Context, r *registry.Registry, image string, fsLayers map[int]distribution.Descriptor, index int) (*Layer, error) {
var parentName string
if index < len(fsLayers)-1 {
parentName = fsLayers[index+1].Digest.String()
@ -20,7 +21,7 @@ func (c *Clair) NewClairLayer(r *registry.Registry, image string, fsLayers map[i
p := strings.Join([]string{r.URL, "v2", image, "blobs", fsLayers[index].Digest.String()}, "/")
// Get the headers.
h, err := r.Headers(p)
h, err := r.Headers(ctx, p)
if err != nil {
return nil, err
}
@ -35,12 +36,12 @@ func (c *Clair) NewClairLayer(r *registry.Registry, image string, fsLayers map[i
}
// NewClairV3Layer will form a layer struct required for a clair scan.
func (c *Clair) NewClairV3Layer(r *registry.Registry, image string, fsLayer distribution.Descriptor) (*clairpb.PostAncestryRequest_PostLayer, error) {
func (c *Clair) NewClairV3Layer(ctx context.Context, r *registry.Registry, image string, fsLayer distribution.Descriptor) (*clairpb.PostAncestryRequest_PostLayer, error) {
// Form the path.
p := strings.Join([]string{r.URL, "v2", image, "blobs", fsLayer.Digest.String()}, "/")
// Get the headers.
h, err := r.Headers(p)
h, err := r.Headers(ctx, p)
if err != nil {
return nil, err
}
@ -52,10 +53,10 @@ func (c *Clair) NewClairV3Layer(r *registry.Registry, image string, fsLayer dist
}, nil
}
func (c *Clair) getLayers(r *registry.Registry, repo, tag string, filterEmpty bool) (map[int]distribution.Descriptor, string, error) {
func (c *Clair) getLayers(ctx context.Context, r *registry.Registry, repo, tag string, filterEmpty bool) (map[int]distribution.Descriptor, string, error) {
ok := true
// Get the manifest to pass to clair.
mf, err := r.ManifestV2(repo, tag)
mf, err := r.ManifestV2(ctx, repo, tag)
if err != nil {
ok = false
c.Logf("couldn't retrieve manifest v2, falling back to v1")
@ -76,7 +77,7 @@ func (c *Clair) getLayers(r *registry.Registry, repo, tag string, filterEmpty bo
return filteredLayers, mf.Config.Digest.String(), nil
}
m, err := r.ManifestV1(repo, tag)
m, err := r.ManifestV1(ctx, repo, tag)
if err != nil {
return nil, "", fmt.Errorf("getting the v1 manifest for %s:%s failed: %v", repo, tag, err)
}

View file

@ -1,6 +1,7 @@
package clair
import (
"context"
"errors"
"fmt"
"time"
@ -10,7 +11,7 @@ import (
)
// Vulnerabilities scans the given repo and tag.
func (c *Clair) Vulnerabilities(r *registry.Registry, repo, tag string) (VulnerabilityReport, error) {
func (c *Clair) Vulnerabilities(ctx context.Context, r *registry.Registry, repo, tag string) (VulnerabilityReport, error) {
report := VulnerabilityReport{
RegistryURL: r.Domain,
Repo: repo,
@ -19,7 +20,7 @@ func (c *Clair) Vulnerabilities(r *registry.Registry, repo, tag string) (Vulnera
VulnsBySeverity: make(map[string][]Vulnerability),
}
filteredLayers, _, err := c.getLayers(r, repo, tag, true)
filteredLayers, _, err := c.getLayers(ctx, r, repo, tag, true)
if err != nil {
return report, fmt.Errorf("getting filtered layers failed: %v", err)
}
@ -31,20 +32,20 @@ func (c *Clair) Vulnerabilities(r *registry.Registry, repo, tag string) (Vulnera
for i := len(filteredLayers) - 1; i >= 0; i-- {
// Form the clair layer.
l, err := c.NewClairLayer(r, repo, filteredLayers, i)
l, err := c.NewClairLayer(ctx, r, repo, filteredLayers, i)
if err != nil {
return report, err
}
// Post the layer.
if _, err := c.PostLayer(l); err != nil {
if _, err := c.PostLayer(ctx, l); err != nil {
return report, err
}
}
report.Name = filteredLayers[0].Digest.String()
vl, err := c.GetLayer(filteredLayers[0].Digest.String(), true, true)
vl, err := c.GetLayer(ctx, filteredLayers[0].Digest.String(), true, true)
if err != nil {
return report, err
}
@ -76,7 +77,7 @@ func (c *Clair) Vulnerabilities(r *registry.Registry, repo, tag string) (Vulnera
}
// VulnerabilitiesV3 scans the given repo and tag using the clair v3 API.
func (c *Clair) VulnerabilitiesV3(r *registry.Registry, repo, tag string) (VulnerabilityReport, error) {
func (c *Clair) VulnerabilitiesV3(ctx context.Context, r *registry.Registry, repo, tag string) (VulnerabilityReport, error) {
report := VulnerabilityReport{
RegistryURL: r.Domain,
Repo: repo,
@ -85,7 +86,7 @@ func (c *Clair) VulnerabilitiesV3(r *registry.Registry, repo, tag string) (Vulne
VulnsBySeverity: make(map[string][]Vulnerability),
}
layers, reportName, err := c.getLayers(r, repo, tag, false)
layers, reportName, err := c.getLayers(ctx, r, repo, tag, false)
if err != nil {
return report, fmt.Errorf("getting filtered layers failed: %v", err)
}
@ -100,7 +101,7 @@ func (c *Clair) VulnerabilitiesV3(r *registry.Registry, repo, tag string) (Vulne
clairLayers := []*clairpb.PostAncestryRequest_PostLayer{}
for i := len(layers) - 1; i >= 0; i-- {
// Form the clair layer.
l, err := c.NewClairV3Layer(r, repo, layers[i])
l, err := c.NewClairV3Layer(ctx, r, repo, layers[i])
if err != nil {
return report, err
}
@ -110,12 +111,12 @@ func (c *Clair) VulnerabilitiesV3(r *registry.Registry, repo, tag string) (Vulne
}
// Post the ancestry.
if err := c.PostAncestry(reportName, clairLayers); err != nil {
if err := c.PostAncestry(ctx, reportName, clairLayers); err != nil {
return report, fmt.Errorf("posting ancestry failed: %v", err)
}
// Get the ancestry.
vl, err := c.GetAncestry(reportName)
vl, err := c.GetAncestry(ctx, reportName)
if err != nil {
return report, err
}

View file

@ -31,13 +31,13 @@ func (cmd *digestCommand) Run(ctx context.Context, args []string) error {
}
// Create the registry client.
r, err := createRegistryClient(image.Domain)
r, err := createRegistryClient(ctx, image.Domain)
if err != nil {
return err
}
// Get the digest.
digest, err := r.Digest(image)
digest, err := r.Digest(ctx, image)
if err != nil {
return err
}

View file

@ -2,6 +2,7 @@ package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"html/template"
@ -53,7 +54,7 @@ type AnalysisResult struct {
UpdateInterval time.Duration
}
func (rc *registryController) repositories(staticDir string) error {
func (rc *registryController) repositories(ctx context.Context, staticDir string) error {
rc.l.Lock()
defer rc.l.Unlock()
@ -65,7 +66,7 @@ func (rc *registryController) repositories(staticDir string) error {
UpdateInterval: rc.interval,
}
repoList, err := rc.reg.Catalog("")
repoList, err := rc.reg.Catalog(ctx, "")
if err != nil {
return fmt.Errorf("getting catalog for %s failed: %v", rc.reg.Domain, err)
}
@ -94,7 +95,7 @@ func (rc *registryController) repositories(staticDir string) error {
// Parse and execute the tags templates.
// If we are generating the tags files, disable vulnerability links in the
// templates since they won't go anywhere without a server side component.
b, err := rc.generateTagsTemplate(repo, false)
b, err := rc.generateTagsTemplate(ctx, repo, false)
if err != nil {
logrus.Warnf("generating tags template for repo %q failed: %v", repo, err)
}
@ -156,7 +157,7 @@ func (rc *registryController) tagsHandler(w http.ResponseWriter, r *http.Request
}
// Generate the tags template.
b, err := rc.generateTagsTemplate(repo, rc.cl != nil)
b, err := rc.generateTagsTemplate(context.TODO(), repo, rc.cl != nil)
if err != nil {
logrus.WithFields(logrus.Fields{
"func": "tags",
@ -173,9 +174,9 @@ func (rc *registryController) tagsHandler(w http.ResponseWriter, r *http.Request
fmt.Fprint(w, string(b))
}
func (rc *registryController) generateTagsTemplate(repo string, hasVulns bool) ([]byte, error) {
func (rc *registryController) generateTagsTemplate(ctx context.Context, repo string, hasVulns bool) ([]byte, error) {
// Get the tags from the server.
tags, err := rc.reg.Tags(repo)
tags, err := rc.reg.Tags(ctx, repo)
if err != nil {
return nil, fmt.Errorf("getting tags for %s failed: %v", repo, err)
}
@ -196,7 +197,7 @@ func (rc *registryController) generateTagsTemplate(repo string, hasVulns bool) (
for _, tag := range tags {
// get the manifest
m1, err := rc.reg.ManifestV1(repo, tag)
m1, err := rc.reg.ManifestV1(ctx, repo, tag)
if err != nil {
return nil, fmt.Errorf("getting v1 manifest for %s:%s failed: %v", repo, tag, err)
}
@ -272,10 +273,10 @@ func (rc *registryController) vulnerabilitiesHandler(w http.ResponseWriter, r *h
}
// Get the vulnerability report.
result, err := rc.cl.VulnerabilitiesV3(rc.reg, image.Path, image.Reference())
result, err := rc.cl.VulnerabilitiesV3(context.TODO(), rc.reg, image.Path, image.Reference())
if err != nil {
// Fallback to Clair v2 API.
result, err = rc.cl.Vulnerabilities(rc.reg, image.Path, image.Reference())
result, err = rc.cl.Vulnerabilities(context.TODO(), rc.reg, image.Path, image.Reference())
if err != nil {
logrus.WithFields(logrus.Fields{
"func": "vulnerabilities",

View file

@ -38,19 +38,19 @@ func (cmd *layerCommand) Run(ctx context.Context, args []string) error {
}
// Create the registry client.
r, err := createRegistryClient(image.Domain)
r, err := createRegistryClient(ctx, image.Domain)
if err != nil {
return err
}
// Get the digest.
digest, err := r.Digest(image)
digest, err := r.Digest(ctx, image)
if err != nil {
return err
}
// Download the layer.
layer, err := r.DownloadLayer(image.Path, digest)
layer, err := r.DownloadLayer(ctx, image.Path, digest)
if err != nil {
return err
}

View file

@ -30,12 +30,12 @@ func (cmd *listCommand) Run(ctx context.Context, args []string) error {
}
// Create the registry client.
r, err := createRegistryClient(args[0])
r, err := createRegistryClient(ctx, args[0])
if err != nil {
return err
}
// Get the repositories via catalog.
repos, err := r.Catalog("")
repos, err := r.Catalog(ctx, "")
if err != nil {
if _, ok := err.(*json.SyntaxError); ok {
return fmt.Errorf("Domain %s is not a valid registry", r.Domain)
@ -56,7 +56,7 @@ func (cmd *listCommand) Run(ctx context.Context, args []string) error {
for _, repo := range repos {
go func(repo string) {
// Get the tags.
tags, err := r.Tags(repo)
tags, err := r.Tags(ctx, repo)
if err != nil {
fmt.Printf("Get tags of [%s] error: %s", repo, err)
}

View file

@ -102,7 +102,7 @@ func main() {
p.Run()
}
func createRegistryClient(domain string) (*registry.Registry, error) {
func createRegistryClient(ctx context.Context, domain string) (*registry.Registry, error) {
// Use the auth-url domain if provided.
authDomain := authURL
if authDomain == "" {
@ -119,7 +119,7 @@ func createRegistryClient(domain string) (*registry.Registry, error) {
}
// Create the registry client.
return registry.New(auth, registry.Opt{
return registry.New(ctx, auth, registry.Opt{
Domain: domain,
Insecure: insecure,
Debug: debug,

View file

@ -36,7 +36,7 @@ func (cmd *manifestCommand) Run(ctx context.Context, args []string) error {
}
// Create the registry client.
r, err := createRegistryClient(image.Domain)
r, err := createRegistryClient(ctx, image.Domain)
if err != nil {
return err
}
@ -44,13 +44,13 @@ func (cmd *manifestCommand) Run(ctx context.Context, args []string) error {
var manifest interface{}
if cmd.v1 {
// Get the v1 manifest if it was explicitly asked for.
manifest, err = r.ManifestV1(image.Path, image.Reference())
manifest, err = r.ManifestV1(ctx, image.Path, image.Reference())
if err != nil {
return err
}
} else {
// Get the v2 manifest.
manifest, err = r.Manifest(image.Path, image.Reference())
manifest, err = r.Manifest(ctx, image.Path, image.Reference())
if err != nil {
return err
}

View file

@ -1,6 +1,7 @@
package registry
import (
"context"
"net/url"
"github.com/peterhellberg/link"
@ -11,7 +12,7 @@ type catalogResponse struct {
}
// Catalog returns the repositories in a registry.
func (r *Registry) Catalog(u string) ([]string, error) {
func (r *Registry) Catalog(ctx context.Context, u string) ([]string, error) {
if u == "" {
u = "/v2/_catalog"
}
@ -19,7 +20,7 @@ func (r *Registry) Catalog(u string) ([]string, error) {
r.Logf("registry.catalog url=%s", uri)
var response catalogResponse
h, err := r.getJSON(uri, &response)
h, err := r.getJSON(ctx, uri, &response)
if err != nil {
return nil, err
}
@ -27,7 +28,7 @@ func (r *Registry) Catalog(u string) ([]string, error) {
for _, l := range link.ParseHeader(h) {
if l.Rel == "next" {
unescaped, _ := url.QueryUnescape(l.URI)
repos, err := r.Catalog(unescaped)
repos, err := r.Catalog(ctx, unescaped)
if err != nil {
return nil, err
}

View file

@ -1,6 +1,7 @@
package registry
import (
"context"
"fmt"
"net/http"
@ -10,7 +11,7 @@ import (
// Delete removes a repository digest from the registry.
// https://docs.docker.com/registry/spec/api/#deleting-an-image
func (r *Registry) Delete(repository string, digest digest.Digest) (err error) {
func (r *Registry) Delete(ctx context.Context, repository string, digest digest.Digest) (err error) {
url := r.url("/v2/%s/manifests/%s", repository, digest)
r.Logf("registry.manifests.delete url=%s repository=%s digest=%s",
url, repository, digest)
@ -21,7 +22,7 @@ func (r *Registry) Delete(repository string, digest digest.Digest) (err error) {
}
req.Header.Add("Accept", fmt.Sprintf("%s;q=0.9", schema2.MediaTypeManifest))
resp, err := r.Client.Do(req)
resp, err := r.Client.Do(req.WithContext(ctx))
if err != nil {
return err
}

View file

@ -1,6 +1,7 @@
package registry
import (
"context"
"fmt"
"net/http"
@ -9,7 +10,7 @@ import (
)
// Digest returns the digest for an image.
func (r *Registry) Digest(image Image) (digest.Digest, error) {
func (r *Registry) Digest(ctx context.Context, image Image) (digest.Digest, error) {
if len(image.Digest) > 1 {
// return early if we already have an image digest.
return image.Digest, nil
@ -25,7 +26,7 @@ func (r *Registry) Digest(image Image) (digest.Digest, error) {
}
req.Header.Add("Accept", schema2.MediaTypeManifest)
resp, err := r.Client.Do(req)
resp, err := r.Client.Do(req.WithContext(ctx))
if err != nil {
return "", err
}

View file

@ -1,23 +1,25 @@
package registry
import (
"context"
"testing"
"github.com/genuinetools/reg/repoutils"
)
func TestDigestFromDockerHub(t *testing.T) {
ctx := context.Background()
auth, err := repoutils.GetAuthConfig("", "", "docker.io")
if err != nil {
t.Fatalf("Could not get auth config: %s", err)
}
r, err := New(auth, Opt{})
r, err := New(ctx, auth, Opt{})
if err != nil {
t.Fatalf("Could not create registry instance: %s", err)
}
d, err := r.Digest(Image{Domain: "docker.io", Path: "library/alpine", Tag: "latest"})
d, err := r.Digest(ctx, Image{Domain: "docker.io", Path: "library/alpine", Tag: "latest"})
if err != nil {
t.Fatalf("Could not get digest: %s", err)
}
@ -28,17 +30,18 @@ func TestDigestFromDockerHub(t *testing.T) {
}
func TestDigestFromGCR(t *testing.T) {
ctx := context.Background()
auth, err := repoutils.GetAuthConfig("", "", "gcr.io")
if err != nil {
t.Fatalf("Could not get auth config: %s", err)
}
r, err := New(auth, Opt{})
r, err := New(ctx, auth, Opt{})
if err != nil {
t.Fatalf("Could not create registry instance: %s", err)
}
d, err := r.Digest(Image{Domain: "gcr.io", Path: "google_containers/hyperkube", Tag: "v1.9.9"})
d, err := r.Digest(ctx, Image{Domain: "gcr.io", Path: "google_containers/hyperkube", Tag: "v1.9.9"})
if err != nil {
t.Fatalf("Could not get digest: %s", err)
}

View file

@ -1,21 +1,27 @@
package registry
import (
"context"
"io"
"net/http"
"net/url"
"fmt"
"github.com/docker/distribution/reference"
"github.com/opencontainers/go-digest"
)
// DownloadLayer downloads a specific layer by digest for a repository.
func (r *Registry) DownloadLayer(repository string, digest digest.Digest) (io.ReadCloser, error) {
func (r *Registry) DownloadLayer(ctx context.Context, repository string, digest digest.Digest) (io.ReadCloser, error) {
url := r.url("/v2/%s/blobs/%s", repository, digest)
r.Logf("registry.layer.download url=%s repository=%s digest=%s", url, repository, digest)
resp, err := r.Client.Get(url)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
resp, err := r.Client.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -24,8 +30,8 @@ func (r *Registry) DownloadLayer(repository string, digest digest.Digest) (io.Re
}
// UploadLayer uploads a specific layer by digest for a repository.
func (r *Registry) UploadLayer(repository string, digest reference.Reference, content io.Reader) error {
uploadURL, token, err := r.initiateUpload(repository)
func (r *Registry) UploadLayer(ctx context.Context, repository string, digest reference.Reference, content io.Reader) error {
uploadURL, token, err := r.initiateUpload(ctx, repository)
if err != nil {
return err
}
@ -42,16 +48,20 @@ func (r *Registry) UploadLayer(repository string, digest reference.Reference, co
upload.Header.Set("Content-Type", "application/octet-stream")
upload.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
_, err = r.Client.Do(upload)
_, err = r.Client.Do(upload.WithContext(ctx))
return err
}
// HasLayer returns if the registry contains the specific digest for a repository.
func (r *Registry) HasLayer(repository string, digest digest.Digest) (bool, error) {
func (r *Registry) HasLayer(ctx context.Context, repository string, digest digest.Digest) (bool, error) {
checkURL := r.url("/v2/%s/blobs/%s", repository, digest)
r.Logf("registry.layer.check url=%s repository=%s digest=%s", checkURL, repository, digest)
resp, err := r.Client.Head(checkURL)
req, err := http.NewRequest("HEAD", checkURL, nil)
if err != nil {
return false, err
}
resp, err := r.Client.Do(req.WithContext(ctx))
if err == nil {
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK, nil
@ -72,11 +82,16 @@ func (r *Registry) HasLayer(repository string, digest digest.Digest) (bool, erro
return false, err
}
func (r *Registry) initiateUpload(repository string) (*url.URL, string, error) {
func (r *Registry) initiateUpload(ctx context.Context, repository string) (*url.URL, string, error) {
initiateURL := r.url("/v2/%s/blobs/uploads/", repository)
r.Logf("registry.layer.initiate-upload url=%s repository=%s", initiateURL, repository)
resp, err := r.Client.Post(initiateURL, "application/octet-stream", nil)
req, err := http.NewRequest("POST", initiateURL, nil)
if err != nil {
return nil, "", err
}
req.Header.Set("Content-Type", "application/octet-stream")
resp, err := r.Client.Do(req.WithContext(ctx))
if err != nil {
return nil, "", err
}

View file

@ -2,6 +2,7 @@ package registry
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@ -20,7 +21,7 @@ var (
)
// Manifest returns the manifest for a specific repository:tag.
func (r *Registry) Manifest(repository, ref string) (distribution.Manifest, error) {
func (r *Registry) Manifest(ctx context.Context, repository, ref string) (distribution.Manifest, error) {
uri := r.url("/v2/%s/manifests/%s", repository, ref)
r.Logf("registry.manifests uri=%s repository=%s ref=%s", uri, repository, ref)
@ -31,7 +32,7 @@ func (r *Registry) Manifest(repository, ref string) (distribution.Manifest, erro
req.Header.Add("Accept", fmt.Sprintf("%s;q=0.9", schema2.MediaTypeManifest))
resp, err := r.Client.Do(req)
resp, err := r.Client.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -52,12 +53,12 @@ func (r *Registry) Manifest(repository, ref string) (distribution.Manifest, erro
}
// ManifestList gets the registry v2 manifest list.
func (r *Registry) ManifestList(repository, ref string) (manifestlist.ManifestList, error) {
func (r *Registry) ManifestList(ctx context.Context, repository, ref string) (manifestlist.ManifestList, error) {
uri := r.url("/v2/%s/manifests/%s", repository, ref)
r.Logf("registry.manifests uri=%s repository=%s ref=%s", uri, repository, ref)
var m manifestlist.ManifestList
if _, err := r.getJSON(uri, &m); err != nil {
if _, err := r.getJSON(ctx, uri, &m); err != nil {
r.Logf("registry.manifests response=%v", m)
return m, err
}
@ -66,12 +67,12 @@ func (r *Registry) ManifestList(repository, ref string) (manifestlist.ManifestLi
}
// ManifestV2 gets the registry v2 manifest.
func (r *Registry) ManifestV2(repository, ref string) (schema2.Manifest, error) {
func (r *Registry) ManifestV2(ctx context.Context, repository, ref string) (schema2.Manifest, error) {
uri := r.url("/v2/%s/manifests/%s", repository, ref)
r.Logf("registry.manifests uri=%s repository=%s ref=%s", uri, repository, ref)
var m schema2.Manifest
if _, err := r.getJSON(uri, &m); err != nil {
if _, err := r.getJSON(ctx, uri, &m); err != nil {
r.Logf("registry.manifests response=%v", m)
return m, err
}
@ -84,12 +85,12 @@ func (r *Registry) ManifestV2(repository, ref string) (schema2.Manifest, error)
}
// ManifestV1 gets the registry v1 manifest.
func (r *Registry) ManifestV1(repository, ref string) (schema1.SignedManifest, error) {
func (r *Registry) ManifestV1(ctx context.Context, repository, ref string) (schema1.SignedManifest, error) {
uri := r.url("/v2/%s/manifests/%s", repository, ref)
r.Logf("registry.manifests uri=%s repository=%s ref=%s", uri, repository, ref)
var m schema1.SignedManifest
if _, err := r.getJSON(uri, &m); err != nil {
if _, err := r.getJSON(ctx, uri, &m); err != nil {
r.Logf("registry.manifests response=%v", m)
return m, err
}
@ -102,7 +103,7 @@ func (r *Registry) ManifestV1(repository, ref string) (schema1.SignedManifest, e
}
// PutManifest calls a PUT for the specific manifest for an image.
func (r *Registry) PutManifest(repository, ref string, manifest distribution.Manifest) error {
func (r *Registry) PutManifest(ctx context.Context, repository, ref string, manifest distribution.Manifest) error {
url := r.url("/v2/%s/manifests/%s", repository, ref)
r.Logf("registry.manifest.put url=%s repository=%s reference=%s", url, repository, ref)
@ -117,7 +118,7 @@ func (r *Registry) PutManifest(repository, ref string, manifest distribution.Man
}
req.Header.Set("Content-Type", schema2.MediaTypeManifest)
resp, err := r.Client.Do(req)
resp, err := r.Client.Do(req.WithContext(ctx))
if resp != nil {
defer resp.Body.Close()
}

View file

@ -1,10 +1,19 @@
package registry
import (
"context"
"net/http"
)
// Ping tries to contact a registry URL to make sure it is up and accessible.
func (r *Registry) Ping() error {
func (r *Registry) Ping(ctx context.Context) error {
url := r.url("/v2/")
r.Logf("registry.ping url=%s", url)
resp, err := r.Client.Get(url)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return err
}
resp, err := r.Client.Do(req.WithContext(ctx))
if resp != nil {
defer resp.Body.Close()
}

View file

@ -1,6 +1,7 @@
package registry
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
@ -51,7 +52,7 @@ type Opt struct {
}
// New creates a new Registry struct with the given URL and credentials.
func New(auth types.AuthConfig, opt Opt) (*Registry, error) {
func New(ctx context.Context, auth types.AuthConfig, opt Opt) (*Registry, error) {
transport := http.DefaultTransport
if opt.Insecure {
@ -62,10 +63,10 @@ func New(auth types.AuthConfig, opt Opt) (*Registry, error) {
}
}
return newFromTransport(auth, transport, opt)
return newFromTransport(ctx, auth, transport, opt)
}
func newFromTransport(auth types.AuthConfig, transport http.RoundTripper, opt Opt) (*Registry, error) {
func newFromTransport(ctx context.Context, auth types.AuthConfig, transport http.RoundTripper, opt Opt) (*Registry, error) {
if len(opt.Domain) < 1 {
opt.Domain = auth.ServerAddress
}
@ -127,7 +128,7 @@ func newFromTransport(auth types.AuthConfig, transport http.RoundTripper, opt Op
}
if !opt.SkipPing {
if err := registry.Ping(); err != nil {
if err := registry.Ping(ctx); err != nil {
return nil, err
}
}
@ -142,7 +143,7 @@ func (r *Registry) url(pathTemplate string, args ...interface{}) string {
return url
}
func (r *Registry) getJSON(url string, response interface{}) (http.Header, error) {
func (r *Registry) getJSON(ctx context.Context, url string, response interface{}) (http.Header, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
@ -155,7 +156,7 @@ func (r *Registry) getJSON(url string, response interface{}) (http.Header, error
req.Header.Add("Accept", fmt.Sprintf("%s;q=0.9", manifestlist.MediaTypeManifestList))
}
resp, err := r.Client.Do(req)
resp, err := r.Client.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}

View file

@ -1,16 +1,18 @@
package registry
import "context"
type tagsResponse struct {
Tags []string `json:"tags"`
}
// Tags returns the tags for a specific repository.
func (r *Registry) Tags(repository string) ([]string, error) {
func (r *Registry) Tags(ctx context.Context, repository string) ([]string, error) {
url := r.url("/v2/%s/tags/list", repository)
r.Logf("registry.tags url=%s repository=%s", url, repository)
var response tagsResponse
if _, err := r.getJSON(url, &response); err != nil {
if _, err := r.getJSON(ctx, url, &response); err != nil {
return nil, err
}

View file

@ -1,6 +1,7 @@
package registry
import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
@ -58,7 +59,7 @@ func (t authToken) String() (string, error) {
}
func (t *TokenTransport) authAndRetry(authService *authService, req *http.Request) (*http.Response, error) {
token, authResp, err := t.auth(authService)
token, authResp, err := t.auth(req.Context(), authService)
if err != nil {
return authResp, err
}
@ -70,7 +71,7 @@ func (t *TokenTransport) authAndRetry(authService *authService, req *http.Reques
return response, err
}
func (t *TokenTransport) auth(authService *authService) (string, *http.Response, error) {
func (t *TokenTransport) auth(ctx context.Context, authService *authService) (string, *http.Response, error) {
authReq, err := authService.Request(t.Username, t.Password)
if err != nil {
return "", nil, err
@ -80,7 +81,7 @@ func (t *TokenTransport) auth(authService *authService) (string, *http.Response,
Transport: t.Transport,
}
resp, err := c.Do(authReq)
resp, err := c.Do(authReq.WithContext(ctx))
if err != nil {
return "", nil, err
}
@ -140,7 +141,7 @@ func isTokenDemand(resp *http.Response) (*authService, error) {
// Token returns the required token for the specific resource url. If the registry requires basic authentication, this
// function returns ErrBasicAuth.
func (r *Registry) Token(url string) (string, error) {
func (r *Registry) Token(ctx context.Context, url string) (string, error) {
r.Logf("registry.token url=%s", url)
req, err := http.NewRequest("GET", url, nil)
@ -160,7 +161,7 @@ func (r *Registry) Token(url string) (string, error) {
}
}
resp, err := client.Do(req)
resp, err := client.Do(req.WithContext(ctx))
if err != nil {
return "", err
}
@ -187,7 +188,7 @@ func (r *Registry) Token(url string) (string, error) {
if err != nil {
return "", err
}
resp, err = http.DefaultClient.Do(authReq)
resp, err = http.DefaultClient.Do(authReq.WithContext(ctx))
if err != nil {
return "", err
}
@ -206,9 +207,9 @@ func (r *Registry) Token(url string) (string, error) {
}
// Headers returns the authorization headers for a specific uri.
func (r *Registry) Headers(uri string) (map[string]string, error) {
func (r *Registry) Headers(ctx context.Context, uri string) (map[string]string, error) {
// Get the token.
token, err := r.Token(uri)
token, err := r.Token(ctx, uri)
if err != nil {
if err == ErrBasicAuth {
// If we couldn't get a token because the server requires basic auth, just return basic auth headers.

View file

@ -1,6 +1,7 @@
package registry
import (
"context"
"errors"
"net/http"
"net/http/httptest"
@ -11,6 +12,7 @@ import (
)
func TestErrBasicAuth(t *testing.T) {
ctx := context.Background()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
w.Header().Set("www-authenticate", `Basic realm="Registry Realm",service="Docker registry"`)
@ -26,11 +28,11 @@ func TestErrBasicAuth(t *testing.T) {
Password: "ss3j",
ServerAddress: ts.URL,
}
r, err := New(authConfig, Opt{Insecure: true, Debug: true})
r, err := New(ctx, authConfig, Opt{Insecure: true, Debug: true})
if err != nil {
t.Fatalf("expected no error creating client, got %v", err)
}
token, err := r.Token(ts.URL)
token, err := r.Token(ctx, ts.URL)
if err != ErrBasicAuth {
t.Fatalf("expected ErrBasicAuth getting token, got %v", err)
}
@ -72,6 +74,7 @@ func TestBothTokenAndAccessTokenWork(t *testing.T) {
defer ts.Close()
for _, which := range []string{"token", "accesstoken"} {
ctx := context.Background()
authURI = ts.URL + "/oauth2/" + which + "?service=my.endpoint.here"
authConfig := types.AuthConfig{
Username: "abc",
@ -79,11 +82,11 @@ func TestBothTokenAndAccessTokenWork(t *testing.T) {
ServerAddress: ts.URL,
}
authConfig.Email = "me@email.com"
r, err := New(authConfig, Opt{Insecure: true, Debug: true})
r, err := New(ctx, authConfig, Opt{Insecure: true, Debug: true})
if err != nil {
t.Fatalf("expected no error creating client, got %v", err)
}
token, err := r.Token(ts.URL)
token, err := r.Token(ctx, ts.URL)
if err != nil {
t.Fatalf("err getting token from url: %v err: %v", ts.URL, err)
}

View file

@ -31,13 +31,13 @@ func (cmd *removeCommand) Run(ctx context.Context, args []string) error {
}
// Create the registry client.
r, err := createRegistryClient(image.Domain)
r, err := createRegistryClient(ctx, image.Domain)
if err != nil {
return err
}
// Get the digest.
digest, err := r.Digest(image)
digest, err := r.Digest(ctx, image)
if err != nil {
return err
}
@ -47,7 +47,7 @@ func (cmd *removeCommand) Run(ctx context.Context, args []string) error {
}
// Delete the reference.
if err := r.Delete(image.Path, digest); err != nil {
if err := r.Delete(ctx, image.Path, digest); err != nil {
return err
}
fmt.Printf("Deleted %s\n", image.String())

View file

@ -61,7 +61,7 @@ type serverCommand struct {
func (cmd *serverCommand) Run(ctx context.Context, args []string) error {
// Create the registry client.
r, err := createRegistryClient(cmd.registryServer)
r, err := createRegistryClient(ctx, cmd.registryServer)
if err != nil {
return err
}
@ -127,7 +127,7 @@ func (cmd *serverCommand) Run(ctx context.Context, args []string) error {
// Create the initial index.
logrus.Info("creating initial static index")
if err := rc.repositories(staticDir); err != nil {
if err := rc.repositories(ctx, staticDir); err != nil {
return fmt.Errorf("creating index failed: %v", err)
}
@ -142,7 +142,7 @@ func (cmd *serverCommand) Run(ctx context.Context, args []string) error {
// Create more indexes every X minutes based off interval.
for range ticker.C {
logrus.Info("creating timer based static index")
if err := rc.repositories(staticDir); err != nil {
if err := rc.repositories(ctx, staticDir); err != nil {
logrus.Warnf("creating static index failed: %v", err)
}
}

View file

@ -33,12 +33,12 @@ func (cmd *tagsCommand) Run(ctx context.Context, args []string) error {
}
// Create the registry client.
r, err := createRegistryClient(image.Domain)
r, err := createRegistryClient(ctx, image.Domain)
if err != nil {
return err
}
tags, err := r.Tags(image.Path)
tags, err := r.Tags(ctx, image.Path)
if err != nil {
return err
}

View file

@ -49,7 +49,7 @@ func (cmd *vulnsCommand) Run(ctx context.Context, args []string) error {
}
// Create the registry client.
r, err := createRegistryClient(image.Domain)
r, err := createRegistryClient(ctx, image.Domain)
if err != nil {
return err
}
@ -65,10 +65,10 @@ func (cmd *vulnsCommand) Run(ctx context.Context, args []string) error {
}
// Get the vulnerability report.
report, err := cr.VulnerabilitiesV3(r, image.Path, image.Reference())
report, err := cr.VulnerabilitiesV3(ctx, r, image.Path, image.Reference())
if err != nil {
// Fallback to Clair v2 API.
report, err = cr.Vulnerabilities(r, image.Path, image.Reference())
report, err = cr.Vulnerabilities(ctx, r, image.Path, image.Reference())
if err != nil {
return err
}