diff --git a/pkg/media/output.go b/pkg/media/output.go index edef1e15..543b5099 100644 --- a/pkg/media/output.go +++ b/pkg/media/output.go @@ -24,6 +24,7 @@ import ( "github.com/go-gst/go-gst/gst" "github.com/go-gst/go-gst/gst/app" "github.com/pion/webrtc/v3/pkg/media" + "google.golang.org/protobuf/proto" "github.com/livekit/ingress/pkg/errors" "github.com/livekit/ingress/pkg/stats" @@ -76,7 +77,7 @@ type AudioOutput struct { codec livekit.AudioCodec } -func NewVideoOutput(codec livekit.VideoCodec, layer *livekit.VideoLayer, outputSync *utils.TrackOutputSynchronizer, statsGatherer *stats.LocalMediaStatsGatherer) (*VideoOutput, error) { +func NewVideoOutput(codec livekit.VideoCodec, layer *ScaledVideoLayer, outputSync *utils.TrackOutputSynchronizer, statsGatherer *stats.LocalMediaStatsGatherer) (*VideoOutput, error) { e, err := newVideoOutput(codec, outputSync) if err != nil { return nil, err @@ -86,7 +87,7 @@ func NewVideoOutput(codec livekit.VideoCodec, layer *livekit.VideoLayer, outputS e.trackStatsGatherer = statsGatherer.RegisterTrackStats(fmt.Sprintf("%s.%s", stats.OutputVideo, layer.Quality.String())) - threadCount := getVideoEncoderThreadCount(layer) + threadCount := getVideoEncoderThreadCount(layer.VideoLayer) e.logger.Infow("video layer", "width", layer.Width, "height", layer.Height, "threads", threadCount) @@ -132,6 +133,33 @@ func NewVideoOutput(codec livekit.VideoCodec, layer *livekit.VideoLayer, outputS return nil, err } + queueIn.GetStaticPad("sink").AddProbe(gst.PadProbeTypeEventDownstream, func(self *gst.Pad, info *gst.PadProbeInfo) gst.PadProbeReturn { + ev := info.GetEvent() + if ev.Type() != gst.EventTypeCaps { + return gst.PadProbeOK + } + + caps := ev.ParseCaps() + w, h, err := getResolution(caps) + if err != nil { + logger.Errorw("input queue caps failed to get resolution", err) + } + + l := proto.Clone(layer.VideoLayer).(*livekit.VideoLayer) + l.Width = layer.MaxW + l.Height = layer.MaxH + applyResolutionToLayer(l, w, h) + + err = inputCaps.SetProperty("caps", gst.NewCapsFromString( + fmt.Sprintf("video/x-raw,width=%d,height=%d", l.Width, l.Height), + )) + if err != nil { + logger.Errorw("failed to set input capsfilter caps", err) + } + + return gst.PadProbeOK + }) + queueEnc, err := gst.NewElementWithName("queue", fmt.Sprintf("video_%s_enc", layer.Quality.String())) if err != nil { return nil, err diff --git a/pkg/media/pipeline.go b/pkg/media/pipeline.go index 3119842e..bc94e0f5 100644 --- a/pkg/media/pipeline.go +++ b/pkg/media/pipeline.go @@ -48,8 +48,10 @@ type Pipeline struct { sink *WebRTCSink input *Input - closed core.Fuse - cancel atomic.Pointer[context.CancelFunc] + audioSinkCreated atomic.Bool + videoSinkCreated atomic.Bool + closed core.Fuse + cancel atomic.Pointer[context.CancelFunc] pipelineErr chan error } @@ -119,6 +121,19 @@ func (p *Pipeline) onOutputReady(pad *gst.Pad, kind types.StreamKind) { } func (p *Pipeline) onParamsReady(kind types.StreamKind, gPad *gst.GhostPad, param *glib.ParamSpec) { + // Keep track of whether we've already created audio/video outputs, and skip + // adding more if so + switch kind { + case types.Audio: + if !p.audioSinkCreated.CompareAndSwap(false, true) { + return + } + case types.Video: + if !p.videoSinkCreated.CompareAndSwap(false, true) { + return + } + } + var err error // TODO fix go-gst to not create non nil gst.Caps for a NULL native caps pointer? @@ -139,7 +154,7 @@ func (p *Pipeline) onParamsReady(kind types.StreamKind, gPad *gst.GhostPad, para p.SendStateUpdate(context.Background()) }() - bin, err := p.sink.AddTrack(kind, caps.(*gst.Caps)) + bin, err := p.sink.AddTrack(kind, caps.(*gst.Caps), p.Params) if err != nil { return } diff --git a/pkg/media/webrtc_sink.go b/pkg/media/webrtc_sink.go index 30e79e4d..3948911f 100644 --- a/pkg/media/webrtc_sink.go +++ b/pkg/media/webrtc_sink.go @@ -180,7 +180,12 @@ func (s *WebRTCSink) addVideoTrack(w, h int) ([]*Output, error) { var tracks []*lksdk.LocalTrack var pliHandlers []*lksdk_output.RTCPHandler - tracks, pliHandlers, err = sdkOut.AddVideoTrack(sortedLayers, putils.GetMimeTypeForVideoCodec(s.params.VideoEncodingOptions.VideoCodec)) + layers := make([]*livekit.VideoLayer, 0, len(sortedLayers)) + for _, l := range sortedLayers { + layers = append(layers, l.VideoLayer) + } + + tracks, pliHandlers, err = sdkOut.AddVideoTrack(layers, putils.GetMimeTypeForVideoCodec(s.params.VideoEncodingOptions.VideoCodec)) if err != nil { return } @@ -198,7 +203,7 @@ func (s *WebRTCSink) addVideoTrack(w, h int) ([]*Output, error) { return outputs, nil } -func (s *WebRTCSink) AddTrack(kind types.StreamKind, caps *gst.Caps) (*gst.Bin, error) { +func (s *WebRTCSink) AddTrack(kind types.StreamKind, caps *gst.Caps, p *params.Params) (*gst.Bin, error) { var bin *gst.Bin switch kind { @@ -279,29 +284,30 @@ func getResolution(caps *gst.Caps) (w int, h int, err error) { return wObj.(int), hObj.(int), nil } -func filterAndSortLayersByQuality(layers []*livekit.VideoLayer, sourceW, sourceH int) []*livekit.VideoLayer { - layersByQuality := make(map[livekit.VideoQuality]*livekit.VideoLayer) +type ScaledVideoLayer struct { + *livekit.VideoLayer + + MaxW uint32 + MaxH uint32 +} + +func filterAndSortLayersByQuality(layers []*livekit.VideoLayer, sourceW, sourceH int) []*ScaledVideoLayer { + layersByQuality := make(map[livekit.VideoQuality]*ScaledVideoLayer) for _, layer := range layers { - layersByQuality[layer.Quality] = layer + layersByQuality[layer.Quality] = &ScaledVideoLayer{VideoLayer: layer, MaxW: layer.Width, MaxH: layer.Height} } - var ret []*livekit.VideoLayer + var ret []*ScaledVideoLayer for q := livekit.VideoQuality_LOW; q <= livekit.VideoQuality_HIGH; q++ { layer, ok := layersByQuality[q] if !ok { continue } - applyResolutionToLayer(layer, sourceW, sourceH) + applyResolutionToLayer(layer.VideoLayer, sourceW, sourceH) ret = append(ret, layer) - - if layer.Width >= uint32(sourceW) && layer.Height >= uint32(sourceH) { - // Next quality layer would be duplicate of current one - break - } - } return ret } @@ -320,7 +326,7 @@ func applyResolutionToLayer(layer *livekit.VideoLayer, sourceW, sourceH int) { w = uint32((int64(h) * int64(sourceW)) / int64(sourceH)) } - // Roubd up to the next even dimension + // Round up to the next even dimension w = ((w + 1) >> 1) << 1 h = ((h + 1) >> 1) << 1