Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions pipe/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,6 @@ func (p *Pipeline) Start(ctx context.Context) error {
// Store the stages in the joiners, and verify that the stages'
// requirements are well-formed:
for i, s := range p.stages {
stageJoiners[i].nextStage = s
stageJoiners[i+1].prevStage = s

// Make sure that the stage's requirements are well-formed:
requirements := s.Requirements()
if err := requirements.Stdin.Validate(); err != nil {
Expand All @@ -297,21 +294,25 @@ func (p *Pipeline) Start(ctx context.Context) error {
if err := requirements.Stdout.Validate(); err != nil {
return fmt.Errorf("stdout: %w", err)
}

stageJoiners[i].nextStage = s
stageJoiners[i].nextStageReq = requirements
stageJoiners[i+1].prevStage = s
stageJoiners[i+1].prevStageReq = requirements
}

// Create the "inner" pipes (i.e, all but the first and last
// `stageJoiners`):
for i := 1; i < len(stageJoiners)-1; i++ {
if err := stageJoiners[i].createPipe(); err != nil {
// Check that each of the stages' requirements are satisfiable:
for i := range stageJoiners {
if err := stageJoiners[i].validate(); err != nil {
closePipes()
return err
}
}

// Check that each of the stages' requirements are compatible with
// the pipes that we have created for them:
for i := range stageJoiners {
if err := stageJoiners[i].validate(); err != nil {
// Create the "inner" pipes (i.e, all but the first and last
// `stageJoiners`):
for i := 1; i < len(stageJoiners)-1; i++ {
if err := stageJoiners[i].createPipe(); err != nil {
closePipes()
return err
}
Expand Down
55 changes: 31 additions & 24 deletions pipe/stage_joiner.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ type stageJoiner struct {
// prevStage holds the stage that needs to write to the pipe.
prevStage Stage

// prevStageReq caches `prevStage.Requirements()` so that it
// doesn't have to be recomputed. It is the zero value if
// `prevStage` is nil.
prevStageReq StageRequirements

// prevStdout will be used as the stdout of `prevStage`. It is
// usually the "write" end of the `(nextStdin, prevStdout)` pipe
// pair, with the connected pipe ends in the same `stageJoiner`
Expand All @@ -52,6 +57,11 @@ type stageJoiner struct {
// nextStage holds the stage that needs to read from the pipe.
nextStage Stage

// nextStageReq caches `nextStage.Requirements()` so that it
// doesn't have to be recomputed. It is the zero value if
// `nextStage` is nil.
nextStageReq StageRequirements

// nextStdin will be used as the stdin of `nextStage`. It is
// usually the "read" end of the `(nextStdin, prevStdout)` pipe
// pair.
Expand All @@ -61,13 +71,8 @@ type stageJoiner struct {
// needFilePipe returns `true` if the pipe that joins the two adjacent
// stages should be an `os.Pipe()` rather than an `io.Pipe()`.
func (sj *stageJoiner) needFilePipe() bool {
if sj.prevStage.Requirements().Stdout == StreamPreferFile {
return true
}
if sj.nextStage.Requirements().Stdin == StreamPreferFile {
return true
}
return false
return sj.prevStageReq.Stdout == StreamPreferFile ||
sj.nextStageReq.Stdin == StreamPreferFile
}

func (sj *stageJoiner) createPipe() error {
Expand Down Expand Up @@ -100,26 +105,28 @@ func (sj *stageJoiner) closePipe() error {
)
}

// validate verifies that `sj.prevStdout` and `sj.nextStdin` are
// suitable for the adjacent stages, in particular that no pipe is
// created if the stage requirements are `StreamForbidden`.
// validate verifies that the adjacent stages' stream requirements are
// satisfiable, in particular that a stage that forbids its stdin or
// stdout is not connected to anything.
func (sj *stageJoiner) validate() error {
if sj.prevStage != nil {
stdoutRequirements := sj.prevStage.Requirements().Stdout
if stdoutRequirements == StreamForbidden && sj.prevStdout != nil {
return fmt.Errorf(
"stage %q forbids stdout, but stdout is connected", sj.prevStage.Name(),
)
}
// `prevStage`'s stdout is connected if there is a `nextStage` to
// consume it (in which case an inner pipe will be created) or if
// a stream (`p.stdout`) has already been stored in `prevStdout`.
if sj.prevStage != nil && sj.prevStageReq.Stdout == StreamForbidden &&
(sj.nextStage != nil || sj.prevStdout != nil) {
return fmt.Errorf(
"stage %q forbids stdout, but stdout is connected", sj.prevStage.Name(),
)
}

if sj.nextStage != nil {
stdinRequirements := sj.nextStage.Requirements().Stdin
if stdinRequirements == StreamForbidden && sj.nextStdin != nil {
return fmt.Errorf(
"stage %q forbids stdin, but stdin is connected", sj.nextStage.Name(),
)
}
// `nextStage`'s stdin is connected if there is a `prevStage` to
// produce it (in which case an inner pipe will be created) or if
// a stream (`p.stdin`) has already been stored in `nextStdin`.
if sj.nextStage != nil && sj.nextStageReq.Stdin == StreamForbidden &&
(sj.prevStage != nil || sj.nextStdin != nil) {
return fmt.Errorf(
"stage %q forbids stdin, but stdin is connected", sj.nextStage.Name(),
)
}

return nil
Expand Down
Loading