diff --git a/grr/server/grr_response_server/gui/api_call_router_with_approval_checks.py b/grr/server/grr_response_server/gui/api_call_router_with_approval_checks.py index ae1c340ad5..dd8c8281ff 100644 --- a/grr/server/grr_response_server/gui/api_call_router_with_approval_checks.py +++ b/grr/server/grr_response_server/gui/api_call_router_with_approval_checks.py @@ -709,6 +709,7 @@ def ListScheduledFlows( args: api_flow_pb2.ApiListScheduledFlowsArgs, context: Optional[api_call_context.ApiCallContext] = None, ) -> api_flow.ApiListScheduledFlowsHandler: + self.approval_checker.CheckClientAccess(context, args.client_id) return self.delegate.ListScheduledFlows(args, context=context) def UnscheduleFlow( diff --git a/grr/server/grr_response_server/gui/api_call_router_with_approval_checks_test.py b/grr/server/grr_response_server/gui/api_call_router_with_approval_checks_test.py index 184a7ad650..3ef240081e 100644 --- a/grr/server/grr_response_server/gui/api_call_router_with_approval_checks_test.py +++ b/grr/server/grr_response_server/gui/api_call_router_with_approval_checks_test.py @@ -382,6 +382,7 @@ def testGetOsqueryResultsChecksClientAccessIfNotPartOfHunt(self): "ListAllFlowOutputPluginLogs", "ListFlowLogs", "ScheduleFlow", + "ListScheduledFlows", ]) def testClientFlowsMethodsAreAccessChecked(self): @@ -415,6 +416,13 @@ def testClientFlowsMethodsAreAccessChecked(self): args=args, ) + args = api_flow_pb2.ApiListScheduledFlowsArgs( + client_id=self.client_id, creator="test" + ) + self.CheckMethodIsAccessChecked( + self.router.ListScheduledFlows, "CheckClientAccess", args=args + ) + args = api_flow_pb2.ApiCancelFlowArgs(client_id=self.client_id) self.CheckMethodIsAccessChecked( self.router.CancelFlow, "CheckClientAccess", args=args diff --git a/grr/server/grr_response_server/gui/api_plugins/flow.py b/grr/server/grr_response_server/gui/api_plugins/flow.py index 265c6c60cd..84f0b333a5 100644 --- a/grr/server/grr_response_server/gui/api_plugins/flow.py +++ b/grr/server/grr_response_server/gui/api_plugins/flow.py @@ -1411,8 +1411,10 @@ def Handle( args: flow_pb2.ApiListScheduledFlowsArgs, context: Optional[api_call_context.ApiCallContext] = None, ) -> flow_pb2.ApiListScheduledFlowsResult: + assert context is not None + results = flow.ListScheduledFlows( - client_id=args.client_id, creator=args.creator + client_id=args.client_id, creator=context.username ) results = sorted(results, key=lambda sf: sf.create_time) results = [InitApiScheduledFlowFromScheduledFlow(sf) for sf in results] diff --git a/grr/server/grr_response_server/gui/api_plugins/flow_test.py b/grr/server/grr_response_server/gui/api_plugins/flow_test.py index 4a02c79c8c..8eb5d3f545 100644 --- a/grr/server/grr_response_server/gui/api_plugins/flow_test.py +++ b/grr/server/grr_response_server/gui/api_plugins/flow_test.py @@ -866,6 +866,34 @@ def testListScheduledFlows(self, db: abstract_db.Database): self.assertCountEqual(results.scheduled_flows, [sf1, sf2]) + @db_test_lib.WithDatabase + def testListScheduledFlowsUsesContextUsername( + self, db: abstract_db.Database + ): + """Verify that handler uses context.username, not args.creator.""" + context = _CreateContext(db) + client_id = db_test_utils.InitializeClient(db) + + # Schedule a flow as the authenticated user. + handler = flow_plugin.ApiScheduleFlowHandler() + args = flow_pb2.ApiCreateFlowArgs() + args.client_id = client_id + args.flow.name = file.CollectFilesByKnownPath.__name__ + args.flow.args.Pack(flows_pb2.CollectFilesByKnownPathArgs(paths=["/foo"])) + args.flow.runner_args.CopyFrom(flows_pb2.FlowRunnerArgs(cpu_limit=60)) + sf = handler.Handle(args, context=context) + + # List with args.creator set to a different user -- the handler should + # ignore it and use context.username instead. + handler = flow_plugin.ApiListScheduledFlowsHandler() + args = flow_pb2.ApiListScheduledFlowsArgs( + client_id=client_id, creator="someotheruser" + ) + results = handler.Handle(args, context=context) + + # Should still return the authenticated user's flows, not "someotheruser". + self.assertCountEqual(results.scheduled_flows, [sf]) + @db_test_lib.WithDatabase def testUnscheduleFlowRemovesScheduledFlow(self, db: abstract_db.Database): context = _CreateContext(db)