diff --git a/temporalio/lib/temporalio/client.rb b/temporalio/lib/temporalio/client.rb index 9f9a2f14..2b73185c 100644 --- a/temporalio/lib/temporalio/client.rb +++ b/temporalio/lib/temporalio/client.rb @@ -6,6 +6,8 @@ require 'temporalio/client/async_activity_handle' require 'temporalio/client/connection' require 'temporalio/client/interceptor' +require 'temporalio/client/metadata_provider' +require 'temporalio/client/metadata_injection_interceptor' require 'temporalio/client/plugin' require 'temporalio/client/schedule' require 'temporalio/client/schedule_handle' diff --git a/temporalio/lib/temporalio/client/metadata_injection_interceptor.rb b/temporalio/lib/temporalio/client/metadata_injection_interceptor.rb new file mode 100644 index 00000000..afa9ffb4 --- /dev/null +++ b/temporalio/lib/temporalio/client/metadata_injection_interceptor.rb @@ -0,0 +1,107 @@ +# frozen_string_literal: true + +require 'temporalio/client/metadata_provider' + +module Temporalio + class Client + # Interceptor that injects metadata from a provider into all requests + class MetadataInjectionInterceptor + include Interceptor + + def initialize(metadata_provider) + @metadata_provider = metadata_provider + end + + def intercept_client(next_interceptor) + MetadataInjectingOutbound.new(next_interceptor, @metadata_provider) + end + + # Outbound that injects metadata + class MetadataInjectingOutbound + def initialize(next_outbound, metadata_provider) + @next_outbound = next_outbound + @metadata_provider = metadata_provider + end + + def start_workflow(input) + inject_metadata(input) + @next_outbound.start_workflow(input) + end + + def signal_with_start_workflow(input) + inject_metadata(input) + @next_outbound.signal_with_start_workflow(input) + end + + def start_update_with_start_workflow(input) + inject_metadata(input) + @next_outbound.start_update_with_start_workflow(input) + end + + def list_workflow_page(input) + inject_metadata(input) + @next_outbound.list_workflow_page(input) + end + + def count_workflows(input) + inject_metadata(input) + @next_outbound.count_workflows(input) + end + + def describe_workflow(input) + inject_metadata(input) + @next_outbound.describe_workflow(input) + end + + def fetch_workflow_history_events(input) + inject_metadata(input) + @next_outbound.fetch_workflow_history_events(input) + end + + def query_workflow(input) + inject_metadata(input) + @next_outbound.query_workflow(input) + end + + def signal_workflow(input) + inject_metadata(input) + @next_outbound.signal_workflow(input) + end + + def signal_workflow_with_start(input) + inject_metadata(input) + @next_outbound.signal_workflow_with_start(input) + end + + def request_cancel_workflow(input) + inject_metadata(input) + @next_outbound.request_cancel_workflow(input) + end + + def terminate_workflow(input) + inject_metadata(input) + @next_outbound.terminate_workflow(input) + end + + def update_workflow(input) + inject_metadata(input) + @next_outbound.update_workflow(input) + end + + private + + def inject_metadata(input) + provider_metadata = @metadata_provider.metadata + return if provider_metadata.empty? + + if input.respond_to?(:rpc_options=) + current_rpc_options = input.rpc_options || {} + current_metadata = current_rpc_options.is_a?(Hash) ? current_rpc_options : {} + merged_metadata = current_metadata.merge(provider_metadata) + input.rpc_options = merged_metadata + end + end + end + end + end +end diff --git a/temporalio/lib/temporalio/client/metadata_provider.rb b/temporalio/lib/temporalio/client/metadata_provider.rb new file mode 100644 index 00000000..9a576ce4 --- /dev/null +++ b/temporalio/lib/temporalio/client/metadata_provider.rb @@ -0,0 +1,95 @@ +# frozen_string_literal: true + +module Temporalio + class Client + # Base class for providing metadata (headers) to gRPC requests + class MetadataProvider + def metadata + raise NotImplementedError, 'Subclasses must implement metadata method' + end + end + + # Simple JWT token provider + class JwtTokenProvider < MetadataProvider + def initialize(token) + @token = token + end + + def metadata + { 'authorization' => "Bearer #{@token}" } + end + end + + # JWT supplier provider (calls lambda for token) + class JwtTokenSupplierProvider < MetadataProvider + def initialize(supplier) + @supplier = supplier + end + + def metadata + token = @supplier.call + token ? { 'authorization' => "Bearer #{token}" } : {} + end + end + + # Keycloak integration + class KeycloakJwtProvider < MetadataProvider + require 'net/http' + require 'json' + + def initialize(keycloak_url, realm, client_id, client_secret, cache_duration: 3600) + @keycloak_url = keycloak_url + @realm = realm + @client_id = client_id + @client_secret = client_secret + @cache_duration = cache_duration + @cached_token = nil + @token_expiry = nil + @mutex = Mutex.new + end + + def metadata + token = ensure_token + token ? { 'authorization' => "Bearer #{token}" } : {} + end + + private + + def ensure_token + @mutex.synchronize do + fetch_new_token if @cached_token.nil? || token_expired? + @cached_token + end + end + + def token_expired? + @token_expiry.nil? || Time.now > @token_expiry + end + + def fetch_new_token + token_url = "#{@keycloak_url}/realms/#{@realm}/protocol/openid-connect/token" + uri = URI(token_url) + + request = Net::HTTP::Post.new(uri.path) + request['Content-Type'] = 'application/x-www-form-urlencoded' + request.set_form_data( + grant_type: 'client_credentials', + client_id: @client_id, + client_secret: @client_secret + ) + + http = Net::HTTP.new(uri.host, uri.port) + http.use_ssl = uri.scheme == 'https' + response = http.request(request) + + if response.is_a?(Net::HTTPSuccess) + body = JSON.parse(response.body) + @cached_token = body['access_token'] + @token_expiry = Time.now + body['expires_in'] - 60 + else + raise "Keycloak token fetch failed: #{response.code}" + end + end + end + end +end