diff --git a/android/app/build.gradle b/android/app/build.gradle index 7b7214c5..ca1d73c5 100644 --- a/android/app/build.gradle +++ b/android/app/build.gradle @@ -78,6 +78,7 @@ flutter { dependencies { def workVersion = "2.9.1" + implementation "androidx.browser:browser:1.9.0" implementation "androidx.security:security-crypto:1.0.0" implementation "androidx.work:work-runtime-ktx:$workVersion" implementation 'com.google.code.gson:gson:2.11.0' diff --git a/android/app/src/main/kotlin/net/defined/mobile_nebula/APIClient.kt b/android/app/src/main/kotlin/net/defined/mobile_nebula/APIClient.kt index 93e4beb5..33a4849c 100644 --- a/android/app/src/main/kotlin/net/defined/mobile_nebula/APIClient.kt +++ b/android/app/src/main/kotlin/net/defined/mobile_nebula/APIClient.kt @@ -19,10 +19,18 @@ class APIClient(context: Context) { return decodeIncomingSite(res.site) } - fun tryUpdate(siteName: String, hostID: String, privateKey: String, counter: Long, trustedKeys: String): IncomingSite? { - val res: mobileNebula.TryUpdateResult + fun preauth(): mobileNebula.PreAuthResult { + return client.endpointPreAuth() + } + + fun authPoll(pollToken: String): mobileNebula.PollDataResult { + return client.endpointAuthPoll(pollToken) + } + + fun longPollWait(siteName: String, hostID: String, privateKey: String, counter: Long, trustedKeys: String): IncomingSite? { + val res: mobileNebula.LongPollWaitResult try { - res = client.tryUpdate(siteName, hostID, privateKey, counter, trustedKeys) + res = client.longPollWait(siteName, hostID, privateKey, counter, trustedKeys) } catch (e: Exception) { // type information from Go is not available, use string matching instead if (e.message == "invalid credentials") { @@ -42,4 +50,17 @@ class APIClient(context: Context) { private fun decodeIncomingSite(jsonSite: String): IncomingSite { return gson.fromJson(jsonSite, IncomingSite::class.java) } + + fun reauthenticate(creds: DNCredentials): String { + try { + return client.reauthenticate(creds.hostID, creds.privateKey, creds.counter.toLong(), creds.trustedKeys); + } catch (e: Exception) { + // type information from Go is not available, use string matching instead + if (e.message == "invalid credentials") { + throw InvalidCredentialsException() + } + + throw e + } + } } \ No newline at end of file diff --git a/android/app/src/main/kotlin/net/defined/mobile_nebula/DNUpdateWorker.kt b/android/app/src/main/kotlin/net/defined/mobile_nebula/DNUpdateWorker.kt index d640d439..456b06ea 100644 --- a/android/app/src/main/kotlin/net/defined/mobile_nebula/DNUpdateWorker.kt +++ b/android/app/src/main/kotlin/net/defined/mobile_nebula/DNUpdateWorker.kt @@ -39,6 +39,7 @@ class DNUpdateWorker(ctx: Context, params: WorkerParameters) : Worker(ctx, param private fun updateSite(site: Site) { try { + Log.i(TAG, "updateSite for ${site.name}") DNUpdateLock(site).use { val res = updater.updateSite(site) @@ -98,7 +99,7 @@ class DNSiteUpdater( val newSite: IncomingSite? try { - newSite = apiClient.tryUpdate( + newSite = apiClient.longPollWait( site.name, credentials.hostID, credentials.privateKey, @@ -108,11 +109,12 @@ class DNSiteUpdater( } catch (e: InvalidCredentialsException) { if (!credentials.invalid) { site.invalidateDNCredentials(context) - Log.d(TAG, "Invalidated credentials in site ${site.name}") + Log.e(TAG, "Invalidated credentials in site ${site.name}") return Result.CREDENTIALS_UPDATED } return Result.NOOP } + Log.d(TAG, "Updated site ${site.id}: ${site.name}. Update? ${newSite != null}") if (newSite != null) { newSite.save(context) diff --git a/android/app/src/main/kotlin/net/defined/mobile_nebula/MainActivity.kt b/android/app/src/main/kotlin/net/defined/mobile_nebula/MainActivity.kt index 3826999b..b13ae2ee 100644 --- a/android/app/src/main/kotlin/net/defined/mobile_nebula/MainActivity.kt +++ b/android/app/src/main/kotlin/net/defined/mobile_nebula/MainActivity.kt @@ -11,26 +11,31 @@ import android.content.pm.PackageManager import android.net.VpnService import android.os.* import android.util.Log +import androidx.browser.customtabs.CustomTabsIntent import androidx.core.content.ContextCompat +import androidx.core.net.toUri import androidx.work.* import com.google.gson.Gson import io.flutter.embedding.android.FlutterActivity import io.flutter.embedding.engine.FlutterEngine import io.flutter.plugin.common.MethodCall import io.flutter.plugin.common.MethodChannel -import io.flutter.plugins.GeneratedPluginRegistrant +import io.flutter.plugin.common.StandardMethodCodec import java.io.File import java.util.concurrent.TimeUnit + const val TAG = "nebula" const val VPN_START_CODE = 0x10 const val CHANNEL = "net.defined.mobileNebula/NebulaVpnService" +const val BGCHANNEL = "net.defined.mobileNebula/NebulaVpnService/background" const val UPDATE_WORKER = "dnUpdater" class MainActivity: FlutterActivity() { private var ui: MethodChannel? = null + private var bg: MethodChannel? = null - private var inMessenger: Messenger? = Messenger(IncomingHandler()) + private var inMessenger: Messenger = Messenger(IncomingHandler()) private var outMessenger: Messenger? = null private var apiClient: APIClient? = null @@ -48,7 +53,7 @@ class MainActivity: FlutterActivity() { private var activeSiteId: String? = null private val workManager = WorkManager.getInstance(application) - private val refreshReceiver: BroadcastReceiver = RefreshReceiver() + private var refreshReceiver: BroadcastReceiver? = null companion object { const val ACTION_REFRESH_SITES = "net.defined.mobileNebula.REFRESH_SITES" @@ -75,6 +80,11 @@ class MainActivity: FlutterActivity() { "nebula.verifyCertAndKey" -> nebulaVerifyCertAndKey(call, result) "dn.enroll" -> dnEnroll(call, result) + "dn.getPollToken" -> dnGetPollToken(call, result) + "dn.usePollToken" -> dnUsePollToken(call, result) + "dn.popBrowser" -> dnPopBrowser(call, result) + "dn.reauthenticate" -> dnReauthenticate(call, result) + "dn.doUpdate" -> dnDoUpdate(result) "listSites" -> listSites(result) "deleteSite" -> deleteSite(call, result) @@ -96,21 +106,36 @@ class MainActivity: FlutterActivity() { else -> result.notImplemented() } } + + val taskQueue = flutterEngine.dartExecutor.binaryMessenger.makeBackgroundTaskQueue() + bg = MethodChannel(flutterEngine.dartExecutor.binaryMessenger, + BGCHANNEL, + StandardMethodCodec.INSTANCE, + taskQueue) + + bg!!.setMethodCallHandler { call, result -> + when(call.method) { + "dn.enroll" -> dnEnroll(call, result) + "dn.reauthenticate" -> dnReauthenticate(call, result) + "dn.getPollToken" -> dnGetPollToken(call, result) + "dn.usePollToken" -> dnUsePollToken(call, result) + + else -> result.notImplemented() + } + } } override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) apiClient = APIClient(context) - + refreshReceiver = RefreshReceiver() ContextCompat.registerReceiver(context, refreshReceiver, IntentFilter(ACTION_REFRESH_SITES), ContextCompat.RECEIVER_NOT_EXPORTED) - enqueueDNUpdater() } override fun onDestroy() { super.onDestroy() - unregisterReceiver(refreshReceiver) } @@ -179,26 +204,137 @@ class MainActivity: FlutterActivity() { } } + private fun doEnroll(enrollCode: String): Result { + val site: IncomingSite + val siteDir: File + try { + site = apiClient!!.enroll(enrollCode) + siteDir = site.save(context) + } catch (err: Exception) { + return Result.failure(err) + } + + val ok = validateOrDeleteSite(siteDir) + Log.w(TAG,"got site: OK? $ok") + if (!ok) { + return Result.failure(Exception("Enrollment failed due to invalid config")) + } + Handler(Looper.getMainLooper()).post { + doRefresh() + } + return Result.success(true) + } + private fun dnEnroll(call: MethodCall, result: MethodChannel.Result) { val code = call.arguments as String if (code == "") { return result.error("required_argument", "code is a required argument", null) } + val out = doEnroll(code) + return when { + out.isSuccess-> result.success(null) + out.isFailure-> result.error("enroll_failed", out.exceptionOrNull()?.message, null) + else-> result.error("enroll_failed", "unknown", null) + } + } - val site: IncomingSite - val siteDir: File + private fun dnPopBrowser(call: MethodCall, result: MethodChannel.Result) { + val urlToPop = call.arguments as String + if (urlToPop == "") { + return result.error("required_argument", "url is a required argument", null) + } + val customTabsIntent = CustomTabsIntent.Builder() + .setShowTitle(true) + .build() try { - site = apiClient!!.enroll(code) - siteDir = site.save(context) + customTabsIntent.launchUrl(this, urlToPop.toUri()) } catch (err: Exception) { return result.error("unhandled_error", err.message, null) } + return result.success(null) + } - if (!validateOrDeleteSite(siteDir)) { - return result.error("failure", "Enrollment failed due to invalid config", null) + private fun dnGetPollToken(call: MethodCall, result: MethodChannel.Result) { + try { + val resp = apiClient!!.preauth() + val out = mapOf("pollToken" to resp.pollToken, "url" to resp.loginURL) + return result.success(out) + } catch (err: Exception) { + return result.error("unhandled_error", err.message, null) } + } - result.success(null) + private fun dnReauthenticate(call: MethodCall, result: MethodChannel.Result) { + val id = call.argument("id") + if (id == "") { + return result.error("required_argument", "id is a required argument", null) + } + + val site = sites!!.getSite(id!!) ?: return result.error("unknown_site", "No site with that id exists", null) + val creds = site.site.getDNCredentials(context) + try { + val resp = apiClient!!.reauthenticate(creds) + return result.success(resp) + } catch (err: Exception) { + return result.error("unhandled_error", err.message, null) + } + } + + private fun dnDoUpdate(result: MethodChannel.Result) { + val workRequest = OneTimeWorkRequestBuilder().build() + workManager.enqueue(workRequest) + return result.success(null) + } + + private fun usePollToken(pt: String): Result { + if (pt == "") { + return Result.failure(Exception("invalid sequence: pollToken is blank")) + } + return try { + val response = apiClient!!.authPoll(pt) + when (response.status) { + "COMPLETED" -> { + if (response.enrollmentCode == "") { + Result.failure(Exception("auth complete, enroll code empty!")) + } else { + doEnroll(response.enrollmentCode) + } + } + "STARTED" -> Result.failure(Exception( "auth incomplete")) + "WAITING" -> Result.failure(Exception( "auth incomplete")) + else -> { + Result.failure(Exception( "auth incomplete, invalid status")) + } + } + } catch (e: Exception) { + Log.e(TAG, "usePollToken threw an exception $e") + return Result.failure(e) + } + } + + private fun dnUsePollToken(call: MethodCall, result: MethodChannel.Result) { + val pollToken = call.arguments as String + if (pollToken == "") { + return result.error("required_argument", "pollToken is a required argument", null) + } + val out = usePollToken(pollToken) + return when { + out.isSuccess-> result.success(null) + out.isFailure-> { + val msg = out.exceptionOrNull()?.message + if(msg != null) { + if (msg.contains("resource not found")) { + result.error("oidc_enroll_failed", msg, null) + } else { + result.error("oidc_enroll_incomplete", msg, null) + } + } else { + result.error("oidc_enroll_failed", "unknown", null) + } + } + + else-> result.error("oidc_enroll_failed", "unknown", null) + } } private fun listSites(result: MethodChannel.Result) { @@ -453,6 +589,10 @@ class MainActivity: FlutterActivity() { bindService(intent, connection, 0) } + //trigger a doupdate + val workRequest = OneTimeWorkRequestBuilder().build() + workManager.enqueue(workRequest) + return result.success(null) } @@ -549,15 +689,19 @@ class MainActivity: FlutterActivity() { outMessenger = null } + private fun doRefresh() { + if (sites == null) return + + Log.d(TAG, "Refreshing sites in MainActivity") + + sites?.refreshSites(activeSiteId) + ui?.invokeMethod("refreshSites", null) + } + inner class RefreshReceiver : BroadcastReceiver() { override fun onReceive(context: Context, intent: Intent?) { if (intent?.action != ACTION_REFRESH_SITES) return - if (sites == null) return - - Log.d(TAG, "Refreshing sites in MainActivity") - - sites?.refreshSites(activeSiteId) - ui?.invokeMethod("refreshSites", null) + doRefresh() } } diff --git a/android/app/src/main/kotlin/net/defined/mobile_nebula/NebulaVpnService.kt b/android/app/src/main/kotlin/net/defined/mobile_nebula/NebulaVpnService.kt index 4457a9ab..8ff1b007 100644 --- a/android/app/src/main/kotlin/net/defined/mobile_nebula/NebulaVpnService.kt +++ b/android/app/src/main/kotlin/net/defined/mobile_nebula/NebulaVpnService.kt @@ -14,6 +14,7 @@ import androidx.core.content.ContextCompat import androidx.work.* import mobileNebula.CIDR import java.io.File +import java.lang.ref.WeakReference class NebulaVpnService : VpnService() { @@ -42,7 +43,7 @@ class NebulaVpnService : VpnService() { private lateinit var messenger: Messenger private val mClients = ArrayList() - private val reloadReceiver: BroadcastReceiver = ReloadReceiver() + private var reloadReceiver: BroadcastReceiver? = null private var workManager: WorkManager? = null private var path: String? = null @@ -51,14 +52,17 @@ class NebulaVpnService : VpnService() { private var nebula: mobileNebula.Nebula? = null private var vpnInterface: ParcelFileDescriptor? = null private var didSleep = false - private var networkCallback: NetworkCallback = NetworkCallback() + private lateinit var networkCallback: NetworkCallback override fun onCreate() { workManager = WorkManager.getInstance(this) + reloadReceiver = ReloadReceiver() + networkCallback = NetworkCallback() super.onCreate() } override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int { + Log.d(TAG, "onStartCommand") if (intent?.action == ACTION_STOP) { stopVpn() return Service.START_NOT_STICKY @@ -284,33 +288,36 @@ class NebulaVpnService : VpnService() { /** * Handler of incoming messages from clients. */ - inner class IncomingHandler : Handler(Looper.getMainLooper()) { + private class IncomingHandler(service: NebulaVpnService) : Handler(Looper.getMainLooper()) { + private val serviceRef: WeakReference = WeakReference(service) + override fun handleMessage(msg: Message) { + val m = serviceRef.get() ?: return //bail if the service is dead //TODO: how do we limit what can talk to us? //TODO: Make sure replyTo is actually a messenger when (msg.what) { - MSG_REGISTER_CLIENT -> register(msg) - MSG_UNREGISTER_CLIENT -> mClients.remove(msg.replyTo) - MSG_IS_RUNNING -> isRunning() - MSG_LIST_HOSTMAP -> listHostmap(msg) - MSG_LIST_INDEXES -> listIndexes(msg) - MSG_LIST_PENDING_HOSTMAP -> listHostmap(msg) - MSG_GET_HOSTINFO -> getHostInfo(msg) - MSG_CLOSE_TUNNEL -> closeTunnel(msg) - MSG_SET_REMOTE_FOR_TUNNEL -> setRemoteForTunnel(msg) + MSG_REGISTER_CLIENT -> register(m, msg) + MSG_UNREGISTER_CLIENT -> m.mClients.remove(msg.replyTo) + MSG_IS_RUNNING -> isRunning(m) + MSG_LIST_HOSTMAP -> listHostmap(m, msg) + MSG_LIST_INDEXES -> listIndexes(m, msg) + MSG_LIST_PENDING_HOSTMAP -> listHostmap(m, msg) + MSG_GET_HOSTINFO -> getHostInfo(m, msg) + MSG_CLOSE_TUNNEL -> closeTunnel(m, msg) + MSG_SET_REMOTE_FOR_TUNNEL -> setRemoteForTunnel(m, msg) else -> super.handleMessage(msg) } } - private fun register(msg: Message) { - mClients.add(msg.replyTo) - if (!running) { - startVpn() + private fun register(m: NebulaVpnService, msg: Message) { + m.mClients.add(msg.replyTo) + if (!m.running) { + m.startVpn() } } - private fun protect(msg: Message): Boolean { - if (nebula != null) { + private fun protect(m: NebulaVpnService, msg: Message): Boolean { + if (m.nebula != null) { return false } @@ -318,50 +325,50 @@ class NebulaVpnService : VpnService() { return true } - private fun isRunning() { - sendSimple(MSG_IS_RUNNING, if (running) 1 else 0) + private fun isRunning(m: NebulaVpnService) { + m.sendSimple(MSG_IS_RUNNING, if (m.running) 1 else 0) } - private fun listHostmap(msg: Message) { - if (protect(msg)) { return } + private fun listHostmap(m: NebulaVpnService, msg: Message) { + if (protect(m, msg)) { return } - val res = nebula!!.listHostmap(msg.what == MSG_LIST_PENDING_HOSTMAP) + val res = m.nebula!!.listHostmap(msg.what == MSG_LIST_PENDING_HOSTMAP) val m = Message.obtain(null, msg.what) m.data.putString("data", res) msg.replyTo.send(m) } - private fun listIndexes(msg: Message) { - if (protect(msg)) { return } + private fun listIndexes(m: NebulaVpnService, msg: Message) { + if (protect(m, msg)) { return } - val res = nebula!!.listIndexes(false) + val res = m.nebula!!.listIndexes(false) val m = Message.obtain(null, msg.what) m.data.putString("data", res) msg.replyTo.send(m) } - private fun getHostInfo(msg: Message) { - if (protect(msg)) { return } + private fun getHostInfo(m: NebulaVpnService, msg: Message) { + if (protect(m, msg)) { return } - val res = nebula!!.getHostInfoByVpnIp(msg.data.getString("vpnIp"), msg.data.getBoolean("pending")) + val res = m.nebula!!.getHostInfoByVpnIp(msg.data.getString("vpnIp"), msg.data.getBoolean("pending")) val m = Message.obtain(null, msg.what) m.data.putString("data", res) msg.replyTo.send(m) } - private fun setRemoteForTunnel(msg: Message) { - if (protect(msg)) { return } + private fun setRemoteForTunnel(m: NebulaVpnService, msg: Message) { + if (protect(m, msg)) { return } - val res = nebula!!.setRemoteForTunnel(msg.data.getString("vpnIp"), msg.data.getString("addr")) + val res = m.nebula!!.setRemoteForTunnel(msg.data.getString("vpnIp"), msg.data.getString("addr")) val m = Message.obtain(null, msg.what) m.data.putString("data", res) msg.replyTo.send(m) } - private fun closeTunnel(msg: Message) { - if (protect(msg)) { return } + private fun closeTunnel(m: NebulaVpnService, msg: Message) { + if (protect(m, msg)) { return } - val res = nebula!!.closeTunnel(msg.data.getString("vpnIp")) + val res = m.nebula!!.closeTunnel(msg.data.getString("vpnIp")) val m = Message.obtain(null, msg.what) m.data.putBoolean("data", res) msg.replyTo.send(m) @@ -396,7 +403,7 @@ class NebulaVpnService : VpnService() { return super.onBind(intent) } - messenger = Messenger(IncomingHandler()) + messenger = Messenger(IncomingHandler(this)) return messenger.binder } } diff --git a/android/app/src/main/kotlin/net/defined/mobile_nebula/Sites.kt b/android/app/src/main/kotlin/net/defined/mobile_nebula/Sites.kt index 65bb9510..92866979 100644 --- a/android/app/src/main/kotlin/net/defined/mobile_nebula/Sites.kt +++ b/android/app/src/main/kotlin/net/defined/mobile_nebula/Sites.kt @@ -17,10 +17,6 @@ data class SiteContainer( class Sites(private var engine: FlutterEngine) { private var containers: HashMap = HashMap() - init { - refreshSites() - } - fun refreshSites(activeSite: String? = null) { val context = MainActivity.getContext()!! @@ -37,6 +33,9 @@ class Sites(private var engine: FlutterEngine) { if (site.id == activeSite) { updater.setState(true, "Connected") + } else if (site.managed) { + // still show new info for managed sites, since the backend may have changed stuff! + updater.setState(null, null) } containers[site.id] = SiteContainer(site, updater) @@ -122,9 +121,14 @@ class SiteUpdater(private var site: Site, engine: FlutterEngine): EventChannel.S this.site = site } - fun setState(connected: Boolean, status: String, err: String? = null) { - site.connected = connected - site.status = status + fun setState(connected: Boolean?, status: String?, err: String? = null) { + if (connected != null) { + site.connected = connected + } + if (status != null) { + site.status = status + } + if (err != null) { eventSink?.error("", err, gson.toJson(site)) } else { @@ -212,6 +216,8 @@ class Site(context: Context, siteDir: File) { // The following fields are present when managed = true val rawConfig: String? val lastManagedUpdate: String? + val managedOIDCEmail: String? + val managedOIDCExpiry: String? // Path to this site on disk @Transient @@ -241,6 +247,8 @@ class Site(context: Context, siteDir: File) { rawConfig = incomingSite.rawConfig managed = incomingSite.managed ?: false lastManagedUpdate = incomingSite.lastManagedUpdate + managedOIDCEmail = incomingSite.managedOIDCEmail + managedOIDCExpiry = incomingSite.managedOIDCExpiry connected = false status = "Disconnected" @@ -347,6 +355,8 @@ class IncomingSite( val managed: Boolean?, // The following fields are present when managed = true val lastManagedUpdate: String?, + val managedOIDCEmail: String?, + val managedOIDCExpiry: String?, val rawConfig: String?, var dnCredentials: DNCredentials?, ) { diff --git a/android/gradle.properties b/android/gradle.properties index b69a3d14..8da88752 100644 --- a/android/gradle.properties +++ b/android/gradle.properties @@ -1,4 +1,4 @@ -org.gradle.jvmargs=-Xmx1536M +org.gradle.jvmargs=-Xmx4096M android.useAndroidX=true android.enableJetifier=true android.nonTransitiveRClass=false diff --git a/android/settings.gradle b/android/settings.gradle index d6e6aff6..92f2d074 100644 --- a/android/settings.gradle +++ b/android/settings.gradle @@ -19,7 +19,7 @@ pluginManagement { plugins { id "org.gradle.toolchains.foojay-resolver-convention" version "0.8.0" id "dev.flutter.flutter-plugin-loader" version "1.0.0" - id "com.android.application" version '8.13.0' apply false + id "com.android.application" version '8.13.1' apply false id "org.jetbrains.kotlin.android" version "2.0.20" apply false } diff --git a/ios/Runner/APIClient.swift b/ios/Runner/APIClient.swift index 4fe01d35..2a47ba85 100644 --- a/ios/Runner/APIClient.swift +++ b/ios/Runner/APIClient.swift @@ -19,12 +19,12 @@ class APIClient { return try decodeIncomingSite(jsonSite: res.site) } - func tryUpdate( + func longPollWait( siteName: String, hostID: String, privateKey: String, counter: Int, trustedKeys: String ) throws -> IncomingSite? { - let res: MobileNebulaTryUpdateResult + let res: MobileNebulaLongPollWaitResult do { - res = try apiClient.tryUpdate( + res = try apiClient.longPollWait( siteName, hostID: hostID, privateKey: privateKey, diff --git a/ios/Runner/DNUpdate.swift b/ios/Runner/DNUpdate.swift index c59652fa..c3b3675e 100644 --- a/ios/Runner/DNUpdate.swift +++ b/ios/Runner/DNUpdate.swift @@ -47,7 +47,7 @@ class DNUpdater { let newSite: IncomingSite? do { - newSite = try apiClient.tryUpdate( + newSite = try apiClient.longPollWait( siteName: site.name, hostID: credentials.hostID, privateKey: credentials.privateKey, diff --git a/lib/models/Site.dart b/lib/models/Site.dart index 11523657..0840229b 100644 --- a/lib/models/Site.dart +++ b/lib/models/Site.dart @@ -13,6 +13,7 @@ var uuid = Uuid(); class Site { static const platform = MethodChannel('net.defined.mobileNebula/NebulaVpnService'); + static const bgplatform = MethodChannel('net.defined.mobileNebula/NebulaVpnService/background'); late EventChannel _updates; /// Signals that something about this site has changed. onError is called with an error string if there was an error @@ -49,6 +50,8 @@ class Site { // The following fields are present when managed = true late String? rawConfig; late DateTime? lastManagedUpdate; + late String? managedOIDCEmail; + late DateTime? managedOIDCExpiry; // A list of errors encountered while loading the site late List errors; @@ -73,6 +76,8 @@ class Site { this.managed = false, this.rawConfig, this.lastManagedUpdate, + this.managedOIDCEmail, + this.managedOIDCExpiry, }) { this.id = id ?? uuid.v4(); this.staticHostmap = staticHostmap ?? {}; @@ -84,7 +89,7 @@ class Site { _updates.receiveBroadcastStream().listen( (d) { try { - _updateFromJson(d); + updateFromJson(d); _change.add(null); } catch (err) { //TODO: handle the error @@ -92,7 +97,7 @@ class Site { } }, onError: (err) { - _updateFromJson(err.details); + updateFromJson(err.details); var error = err as PlatformException; _change.addError(error.message ?? 'An unexpected error occurred'); }, @@ -121,13 +126,15 @@ class Site { managed: decoded['managed'], rawConfig: decoded['rawConfig'], lastManagedUpdate: decoded['lastManagedUpdate'], + managedOIDCEmail: decoded['managedOIDCEmail'], + managedOIDCExpiry: decoded['managedOIDCExpiry'], ); } - _updateFromJson(String json) { - var decoded = Site._fromJson(jsonDecode(json)); + updateFromMap(Map j) { + final decoded = Site._fromJson(j); name = decoded["name"]; - id = decoded['id']; // TODO update EventChannel + id = decoded['id']; // TODO update EventChannel, or consider this an error staticHostmap = decoded['staticHostmap']; ca = decoded['ca']; certInfo = decoded['certInfo']; @@ -136,8 +143,12 @@ class Site { cipher = decoded['cipher']; sortKey = decoded['sortKey']; mtu = decoded['mtu']; - connected = decoded['connected']; - status = decoded['status']; + if (decoded['connected'] != null) { + connected = decoded['connected']; + } + if (decoded['status'] != null) { + status = decoded['status']; + } logFile = decoded['logFile']; logVerbosity = decoded['logVerbosity']; errors = decoded['errors']; @@ -145,6 +156,13 @@ class Site { managed = decoded['managed']; rawConfig = decoded['rawConfig']; lastManagedUpdate = decoded['lastManagedUpdate']; + managedOIDCEmail = decoded['managedOIDCEmail']; + managedOIDCExpiry = decoded['managedOIDCExpiry']; + } + + updateFromJson(String json) { + var decoded = jsonDecode(json); + updateFromMap(decoded); } static _fromJson(Map json) { @@ -188,8 +206,8 @@ class Site { "cipher": json['cipher'], "sortKey": json['sortKey'], "mtu": json['mtu'], - "connected": json['connected'] ?? false, - "status": json['status'] ?? "", + "connected": json['connected'], + "status": json['status'], "logFile": json['logFile'], "logVerbosity": json['logVerbosity'], "errors": errors, @@ -197,6 +215,8 @@ class Site { "managed": json['managed'] ?? false, "rawConfig": json['rawConfig'], "lastManagedUpdate": json["lastManagedUpdate"] == null ? null : DateTime.parse(json["lastManagedUpdate"]), + "managedOIDCEmail": json["managedOIDCEmail"], + "managedOIDCExpiry": json["managedOIDCExpiry"] == null ? null : DateTime.parse(json["managedOIDCExpiry"]), }; } @@ -351,6 +371,13 @@ class Site { } } + bool isSwitchOnAllowed() { + if (managed) { + return true; + } + return errors.isNotEmpty && !connected; + } + Future setRemoteForTunnel(String vpnIp, String addr) async { try { var ret = await platform.invokeMethod("active.setRemoteForTunnel", { @@ -380,4 +407,15 @@ class Site { throw err.toString(); } } + + // returns loginurl + Future reauthenticate() async { + try { + return await bgplatform.invokeMethod("dn.reauthenticate", {"id": id}); + } on PlatformException catch (err) { + throw err.details ?? err.message ?? err.toString(); + } catch (err) { + throw err.toString(); + } + } } diff --git a/lib/screens/MainScreen.dart b/lib/screens/MainScreen.dart index 52a07369..a829fd9b 100644 --- a/lib/screens/MainScreen.dart +++ b/lib/screens/MainScreen.dart @@ -17,6 +17,8 @@ import 'package:mobile_nebula/models/UnsafeRoute.dart'; import 'package:mobile_nebula/screens/SettingsScreen.dart'; import 'package:mobile_nebula/screens/SiteDetailScreen.dart'; import 'package:mobile_nebula/screens/siteConfig/SiteConfigScreen.dart'; +import 'package:mobile_nebula/services/oidc.dart'; +import 'package:mobile_nebula/services/settings.dart'; import 'package:mobile_nebula/services/utils.dart'; import 'package:pull_to_refresh/pull_to_refresh.dart'; import 'package:uuid/uuid.dart'; @@ -92,6 +94,7 @@ class MainScreen extends StatefulWidget { } class _MainScreenState extends State { + final settings = Settings(); List? sites; // A set of widgets to display in a column that represents an error blocking us from moving forward entirely List? error; @@ -99,6 +102,8 @@ class _MainScreenState extends State { bool supportsQRScanning = false; static const platform = MethodChannel('net.defined.mobileNebula/NebulaVpnService'); + static const bgplatform = MethodChannel('net.defined.mobileNebula/NebulaVpnService/background'); + late final OIDCPoller _authService = OIDCPoller(settings, platform, bgplatform); RefreshController refreshController = RefreshController(); ScrollController scrollController = ScrollController(); @@ -112,6 +117,12 @@ class _MainScreenState extends State { platform.setMethodCallHandler(handleMethodCall); + if (settings.pollCode != "") { + _keepPolling(); + } else { + print("no enroll token I guess"); + } + super.initState(); } @@ -331,34 +342,73 @@ class _MainScreenState extends State { ); } + _keepPolling() async { + try { + print("wow we still need to poll"); + final status = await _authService.pollLoop(); + if (!mounted) return; + if (status) { + _loadSites(); + } else { + print("login failed"); + } + } catch (e) { + print("login failed with exception: $e"); + } + } + _loadSites() async { //TODO: This can throw, we need to show an error dialog Map rawSites = jsonDecode(await platform.invokeMethod('listSites')); + Map foundSites = {}; + Map oldSitesById = {}; + sites?.forEach((s) { + oldSitesById[s.id] = s; + foundSites[s.id] = false; + }); + sites = []; rawSites.forEach((id, rawSite) { - try { - var site = Site.fromJson(rawSite); - - //TODO: we need to cancel change listeners when we rebuild - site.onChange().listen( - (_) { - setState(() {}); - }, - onError: (err) { - setState(() {}); - if (ModalRoute.of(context)!.isCurrent) { - Utils.popError(context, "${site.name} Error", err); - } - }, - ); + final s = oldSitesById[id]; + if (s != null) { + if (s.id == id) { + foundSites[s.id] = true; + try { + s.updateFromMap(rawSite); + sites!.add(s); + } catch (err) { + print("$err site config: $rawSite"); //TODO: handle error + } + } + } else { + try { + var site = Site.fromJson(rawSite); + site.onChange().listen( + (_) { + setState(() {}); + }, + onError: (err) { + setState(() {}); + if (ModalRoute.of(context)!.isCurrent) { + Utils.popError(context, "${site.name} Error", err); + } + }, + ); + + sites!.add(site); + } catch (err) { + print("$err site config: $rawSite"); //TODO: handle error + // Sometimes it is helpful to just nuke these is dev + // platform.invokeMethod('deleteSite', id); + } + } + }); - sites!.add(site); - } catch (err) { - //TODO: handle error - print("$err site config: $rawSite"); - // Sometimes it is helpful to just nuke these is dev - // platform.invokeMethod('deleteSite', id); + //tear down old sites + oldSitesById.forEach((id, oldSite) { + if (foundSites[id] == false) { + oldSite.dispose(); } }); diff --git a/lib/screens/SettingsScreen.dart b/lib/screens/SettingsScreen.dart index 3c4a1a3a..4addcae2 100644 --- a/lib/screens/SettingsScreen.dart +++ b/lib/screens/SettingsScreen.dart @@ -1,11 +1,14 @@ import 'dart:async'; +import 'dart:isolate'; import 'package:flutter/material.dart'; +import 'package:flutter/services.dart'; import 'package:mobile_nebula/components/SimplePage.dart'; import 'package:mobile_nebula/components/config/ConfigItem.dart'; import 'package:mobile_nebula/components/config/ConfigPageItem.dart'; import 'package:mobile_nebula/components/config/ConfigSection.dart'; import 'package:mobile_nebula/screens/EnrollmentScreen.dart'; +import 'package:mobile_nebula/services/oidc.dart'; import 'package:mobile_nebula/services/settings.dart'; import 'package:mobile_nebula/services/utils.dart'; @@ -22,6 +25,9 @@ class SettingsScreen extends StatefulWidget { class _SettingsScreenState extends State { var settings = Settings(); + static const platform = MethodChannel('net.defined.mobileNebula/NebulaVpnService'); + static const bgplatform = MethodChannel('net.defined.mobileNebula/NebulaVpnService/background'); + late final OIDCPoller _authService = OIDCPoller(settings, platform, bgplatform); @override void initState() { @@ -127,6 +133,11 @@ class _SettingsScreenState extends State { () => Utils.openPage(context, (context) => EnrollmentScreen(stream: widget.stream, allowCodeEntry: true)), ), + ConfigPageItem( + label: Text('Enroll with Managed Nebula (SSO)'), + labelWidth: 250, + onPressed: () => onEnrollSSO(), + ), ], ), ); @@ -141,4 +152,27 @@ class _SettingsScreenState extends State { return SimplePage(title: Text('Settings'), child: Column(children: items)); } + + Future onEnrollSSO() async { + try { + final success = await _authService.beginLogin(); + + if (!success) { + print("Failed to open login page"); + return; + } + + print("Waiting for login..."); + final status = await _authService.pollLoop(); + if (!mounted) return; + if (status == true) { + // Login successful, go home + Navigator.of(context).pop(); + } else { + print("login failed"); + } + } catch (e) { + print("login failed with exception: $e"); + } + } } diff --git a/lib/screens/SiteDetailScreen.dart b/lib/screens/SiteDetailScreen.dart index 601b4852..ee7361c7 100644 --- a/lib/screens/SiteDetailScreen.dart +++ b/lib/screens/SiteDetailScreen.dart @@ -15,6 +15,7 @@ import 'package:mobile_nebula/screens/SiteTunnelsScreen.dart'; import 'package:mobile_nebula/screens/siteConfig/SiteConfigScreen.dart'; import 'package:mobile_nebula/services/utils.dart'; import 'package:pull_to_refresh/pull_to_refresh.dart'; +import 'package:duration/duration.dart'; import '../components/DangerButton.dart'; import '../components/SiteTitle.dart'; @@ -38,8 +39,10 @@ class _SiteDetailScreenState extends State { late StreamSubscription onChange; static const platform = MethodChannel('net.defined.mobileNebula/NebulaVpnService'); bool changed = false; + bool reauthSpin = false; List? activeHosts; List? pendingHosts; + String expiresIn = "Unknown"; RefreshController refreshController = RefreshController(initialRefresh: false); @override @@ -60,7 +63,9 @@ class _SiteDetailScreenState extends State { pendingHosts = null; } - setState(() {}); + setState(() { + expiresIn = calcExpiresIn(site.managedOIDCExpiry); + }); }, onError: (err) { setState(() {}); @@ -79,7 +84,7 @@ class _SiteDetailScreenState extends State { @override Widget build(BuildContext context) { - final title = SiteTitle(site: widget.site); + final title = SiteTitle(site: site); return SimplePage( title: title, @@ -94,15 +99,20 @@ class _SiteDetailScreenState extends State { ), refreshController: refreshController, onRefresh: () async { + //await Site.platform.invokeMethod('dn.doUpdate'); //todo? if (site.connected && site.status == "Connected") { await _listHostmap(); } + setState(() { + expiresIn = calcExpiresIn(site.managedOIDCExpiry); + }); refreshController.refreshCompleted(); }, child: Column( children: [ _buildErrors(), _buildConfig(), + site.managed ? _buildManaged() : Container(), site.connected ? _buildHosts() : Container(), _buildSiteDetails(), _buildDelete(), @@ -125,7 +135,7 @@ class _SiteDetailScreenState extends State { ), ); } - + //todo if expired, add reauth button return ConfigSection( label: 'ERRORS', borderColor: CupertinoColors.systemRed.resolveFrom(context), @@ -138,9 +148,9 @@ class _SiteDetailScreenState extends State { void handleChange(v) async { try { if (v) { - await widget.site.start(); + await site.start(); } else { - await widget.site.stop(); + await site.stop(); } } catch (error) { var action = v ? 'start' : 'stop'; @@ -158,30 +168,89 @@ class _SiteDetailScreenState extends State { Padding( padding: EdgeInsets.only(right: 5), child: Text( - widget.site.status, + site.status, style: TextStyle(color: CupertinoColors.secondaryLabel.resolveFrom(context)), ), ), Switch.adaptive( - value: widget.site.connected, + value: site.connected, materialTapTargetSize: MaterialTapTargetSize.shrinkWrap, - onChanged: widget.site.errors.isNotEmpty && !widget.site.connected ? null : handleChange, + onChanged: site.isSwitchOnAllowed() ? handleChange: null, ), ], ), ), - ConfigPageItem( - label: Text('Logs'), - onPressed: () { - Utils.openPage(context, (context) { - return SiteLogsScreen(site: widget.site); - }); - }, - ), ], ); } + String calcExpiresIn(DateTime? expiresAt) { + if (expiresAt == null) { + return "Never"; + } + + final exp = expiresAt.toLocal(); + if (exp.isBefore(DateTime.now())) { + return "NOW"; + } else { + final expAt = exp.difference(DateTime.now()); + return "in ${expAt.pretty(tersity: DurationTersity.second)}"; //todo minute? + } + } + + Widget _buildManaged() { + if (site.managedOIDCEmail == null) { + return Container(); + } + + var out = ConfigSection( + label: "MANAGED CONFIG", + children: [], + ); + + expiresIn = calcExpiresIn(site.managedOIDCExpiry); + + Widget? reauthText = null; + if (reauthSpin) { + reauthText = SizedBox(height: 20, width: 20, child: PlatformCircularProgressIndicator()); + } else { + reauthText = Text(expiresIn); + } + + out.children.add(ConfigItem( + label: Text("Username"), + content: Wrap( + alignment: WrapAlignment.end, + crossAxisAlignment: WrapCrossAlignment.center, + children: [Text(site.managedOIDCEmail!)], + ))); + out.children.add(ConfigPageItem( + label: Text("Reauthenticate"), + onPressed: _reauth, + content: Wrap( + alignment: WrapAlignment.end, + crossAxisAlignment: WrapCrossAlignment.center, + children: [reauthText], + ))); + + return out; + } + + Future _reauth() async { + setState(() { + reauthSpin = true; + }); + try { + final loginUrl = await site.reauthenticate(); + await platform.invokeMethod("dn.popBrowser", loginUrl); + } on PlatformException catch (err) { + print(err); + } + setState(() { + reauthSpin = false; + }); + } + Widget _buildHosts() { Widget active, pending; @@ -257,7 +326,7 @@ class _SiteDetailScreenState extends State { onPressed: () { Utils.openPage(context, (context) { return SiteConfigScreen( - site: widget.site, + site: site, onSave: (site) async { changed = true; setState(() {}); @@ -267,6 +336,14 @@ class _SiteDetailScreenState extends State { }); }, ), + ConfigPageItem( + label: Text('Logs'), + onPressed: () { + Utils.openPage(context, (context) { + return SiteLogsScreen(site: site); + }); + }, + ), ], ); } @@ -302,7 +379,7 @@ class _SiteDetailScreenState extends State { Future _deleteSite() async { try { - var err = await platform.invokeMethod("deleteSite", widget.site.id); + var err = await platform.invokeMethod("deleteSite", site.id); if (err != null) { Utils.popError(context, 'Failed to delete the site', err); return false; diff --git a/lib/screens/siteConfig/SiteConfigScreen.dart b/lib/screens/siteConfig/SiteConfigScreen.dart index 148a9461..382537b0 100644 --- a/lib/screens/siteConfig/SiteConfigScreen.dart +++ b/lib/screens/siteConfig/SiteConfigScreen.dart @@ -134,27 +134,47 @@ class _SiteConfigScreenState extends State { } Widget _managed() { + if (!site.managed) { + return Container(); + } final formatter = DateFormat.yMMMMd('en_US').add_jm(); var lastUpdate = "Unknown"; + var oidcExpiry = "Unknown"; if (site.lastManagedUpdate != null) { lastUpdate = formatter.format(site.lastManagedUpdate!.toLocal()); } - return site.managed - ? ConfigSection( - label: "MANAGED CONFIG", - children: [ - ConfigItem( - label: Text("Last Update"), - content: Wrap( - alignment: WrapAlignment.end, - crossAxisAlignment: WrapCrossAlignment.center, - children: [Text(lastUpdate)], - ), - ), - ], - ) - : Container(); + var out = ConfigSection( + label: "MANAGED CONFIG", + children: [ + ConfigItem( + label: Text("Last Updated"), + content: Wrap( + alignment: WrapAlignment.end, + crossAxisAlignment: WrapCrossAlignment.center, + children: [Text(lastUpdate)], + ), + ), + ], + ); + + if (site.managedOIDCEmail != null) { + if (site.managedOIDCExpiry != null) { + oidcExpiry = formatter.format(site.managedOIDCExpiry!.toLocal()); + } else { + oidcExpiry = "Never"; + } + + out.children.add(ConfigItem( + label: Text("Username"), + content: Wrap( + alignment: WrapAlignment.end, + crossAxisAlignment: WrapCrossAlignment.center, + children: [Text(site.managedOIDCEmail!)], + ))); + } + + return out; } Widget _keys() { diff --git a/lib/services/oidc.dart b/lib/services/oidc.dart new file mode 100644 index 00000000..482b84e2 --- /dev/null +++ b/lib/services/oidc.dart @@ -0,0 +1,94 @@ +import 'package:flutter/services.dart'; +import 'package:mobile_nebula/services/settings.dart'; +import 'dart:async'; + +class _PollTokenResponse { + final String token; + final String url; + _PollTokenResponse(this.token, this.url); +} + +class OIDCPoller { + final Settings settings; //todo thread safety? + final MethodChannel platform; + final MethodChannel bgplatform; + + OIDCPoller(this.settings, this.platform, this.bgplatform); + + Future<_PollTokenResponse?> _getPollToken() async { + try { + //todo put a lil spinny somewhere? + var out = await platform.invokeMethod("dn.getPollToken"); + if (out == null) { + print("getPollToken was null"); + return null; + } + settings.pollCode = out["pollToken"]; + return _PollTokenResponse(out["pollToken"], out["url"]); + } on PlatformException catch (err) { + print(err); + return null; + } + } + + Future beginLogin() async { + final resp = await _getPollToken(); + if (resp == null) { + print('Could not obtain poll token'); + return false; + } + + try { + await platform.invokeMethod("dn.popBrowser", resp.url); + return true; + } on PlatformException catch (err) { + print(err); + return false; + } + } + + Future pollLoginStatus() async { + final pollToken = settings.pollCode; + + if (pollToken == "") { + print('No poll token found'); + return false; + } + + try { + await bgplatform.invokeMethod("dn.usePollToken", pollToken); + print("probably enrolled"); + settings.pollCode = ""; + return true; + } on PlatformException catch (err) { + if (err.code == "oidc_enroll_incomplete") { + print("still thinking! $err"); + return null; //retry I suppose? + } + print(err); + return false; + } + } + + Future pollLoop({ + Duration interval = const Duration(seconds: 2), + Duration timeout = const Duration(minutes: 5), + }) async { + final startTime = DateTime.now(); + + while (DateTime.now().difference(startTime) < timeout) { + final status = await pollLoginStatus(); + if (status != null) { + // Login completed (success or failure) + settings.pollCode = ""; + return status; + } else { + await Future.delayed(interval); + } + } + + // Timeout reached + settings.pollCode = ""; + return false; + } +} diff --git a/lib/services/settings.dart b/lib/services/settings.dart index 5144c3ec..b25f6b2a 100644 --- a/lib/services/settings.dart +++ b/lib/services/settings.dart @@ -54,6 +54,16 @@ class Settings { } } + set pollCode(String code) { + //todo set an expiration time? + _set('pollCode', code); + } + + String get pollCode { + //todo expire + return _getString('pollCode', ''); + } + String _getString(String key, String defaultValue) { final val = _settings[key]; if (val is String) { diff --git a/nebula/api.go b/nebula/api.go index 702d4aa8..73e9e4fc 100644 --- a/nebula/api.go +++ b/nebula/api.go @@ -13,6 +13,7 @@ import ( "github.com/DefinedNet/dnapi" "github.com/DefinedNet/dnapi/keys" + "github.com/DefinedNet/dnapi/message" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" ) @@ -26,7 +27,17 @@ type EnrollResult struct { Site string } -type TryUpdateResult struct { +type PreAuthResult struct { + PollToken string + LoginURL string +} + +type PollDataResult struct { + Status string `json:"state"` + EnrollmentCode string `json:"enrollmentCode"` +} + +type LongPollWaitResult struct { FetchedUpdate bool Site string } @@ -65,7 +76,7 @@ func (c *APIClient) Enroll(code string) (*EnrollResult, error) { return nil, fmt.Errorf("unexpected failure: %s", err) } - site, err := newDNSite(meta.Org.Name, cfg, string(pkey), *creds) + site, err := newDNSite(meta.Org.Name, cfg, string(pkey), *creds, meta) if err != nil { return nil, fmt.Errorf("failure generating site: %s", err) } @@ -78,7 +89,47 @@ func (c *APIClient) Enroll(code string) (*EnrollResult, error) { return &EnrollResult{Site: string(jsonSite)}, nil } -func (c *APIClient) TryUpdate(siteName string, hostID string, privateKey string, counter int, trustedKeys string) (*TryUpdateResult, error) { +func (c *APIClient) EndpointPreAuth() (*PreAuthResult, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + msg, err := c.c.EndpointPreAuth(ctx) + var apiError *dnapi.APIError + switch { + case errors.As(err, &apiError): + return nil, fmt.Errorf("%s (request ID: %s)", apiError, apiError.ReqID) + case errors.Is(err, context.DeadlineExceeded): + return nil, fmt.Errorf("request timed out - try again") + case err != nil: + return nil, fmt.Errorf("unexpected failure: %s", err) + } + + return &PreAuthResult{ + PollToken: msg.PollToken, + LoginURL: msg.LoginURL, + }, nil +} + +func (c *APIClient) EndpointAuthPoll(pollCode string) (*PollDataResult, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + msg, err := c.c.EndpointAuthPoll(ctx, pollCode) + var apiError *dnapi.APIError + switch { + case errors.As(err, &apiError): + return nil, fmt.Errorf("%s (request ID: %s)", apiError, apiError.ReqID) + case errors.Is(err, context.DeadlineExceeded): + return nil, fmt.Errorf("request timed out - try again") + case err != nil: + return nil, fmt.Errorf("unexpected failure: %s", err) + } + + return &PollDataResult{ + Status: string(msg.Status), + EnrollmentCode: msg.EnrollmentCode, + }, nil +} + +func (c *APIClient) keysToCreds(hostID string, privateKey string, counter int, trustedKeys string) (*keys.Credentials, error) { // Build dnapi.Credentials struct from inputs if counter < 0 { return nil, fmt.Errorf("invalid counter value: must be unsigned") @@ -103,26 +154,44 @@ func (c *APIClient) TryUpdate(siteName string, hostID string, privateKey string, Counter: uint(counter), TrustedKeys: tk, } + return &creds, nil +} +func (c *APIClient) LongPollWait(siteName string, hostID string, privateKey string, counter int, trustedKeys string) (*LongPollWaitResult, error) { + creds, err := c.keysToCreds(hostID, privateKey, counter, trustedKeys) + if err != nil { + return nil, err + } // Check for update - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) //todo should this have a small retry loop to deal with mobile-related pain? defer cancel() - updateAvailable, err := c.c.CheckForUpdate(ctx, creds) + msg, err := c.c.LongPollWait(ctx, *creds, []string{message.DoUpdate}) switch { + case errors.Is(ctx.Err(), context.DeadlineExceeded): + return &LongPollWaitResult{FetchedUpdate: false}, nil case errors.Is(err, dnapi.ErrInvalidCredentials): return nil, InvalidCredentialsError{} case err != nil: - return nil, fmt.Errorf("CheckForUpdate error: %s", err) + return nil, fmt.Errorf("LongPollWait error: %s", err) } - - if !updateAvailable { - return &TryUpdateResult{FetchedUpdate: false}, nil + var msgType struct{ Command string } + err = json.Unmarshal(msg.Action, &msgType) + if err != nil { + return nil, fmt.Errorf("failed to parse LongPollWait response: %s", err) } + switch msgType.Command { + case message.DoUpdate: + return c.doUpdate(siteName, *creds) + default: + return &LongPollWaitResult{FetchedUpdate: false}, nil + } +} +func (c *APIClient) doUpdate(siteName string, creds keys.Credentials) (*LongPollWaitResult, error) { // Perform the update and return the new site object updateCtx, updateCancel := context.WithTimeout(context.Background(), 30*time.Second) defer updateCancel() - cfg, pkey, newCreds, _, err := c.c.DoUpdate(updateCtx, creds) + cfg, pkey, newCreds, configMeta, err := c.c.DoUpdate(updateCtx, creds) switch { case errors.Is(err, dnapi.ErrInvalidCredentials): return nil, InvalidCredentialsError{} @@ -130,7 +199,7 @@ func (c *APIClient) TryUpdate(siteName string, hostID string, privateKey string, return nil, fmt.Errorf("DoUpdate error: %s", err) } - site, err := newDNSite(siteName, cfg, string(pkey), *newCreds) + site, err := newDNSite(siteName, cfg, string(pkey), *newCreds, configMeta) if err != nil { return nil, fmt.Errorf("failure generating site: %s", err) } @@ -140,7 +209,25 @@ func (c *APIClient) TryUpdate(siteName string, hostID string, privateKey string, return nil, fmt.Errorf("failed to marshal site: %s", err) } - return &TryUpdateResult{Site: string(jsonSite), FetchedUpdate: true}, nil + return &LongPollWaitResult{Site: string(jsonSite), FetchedUpdate: true}, nil +} + +func (c *APIClient) Reauthenticate(hostID string, privateKey string, counter int, trustedKeys string) (string, error) { + creds, err := c.keysToCreds(hostID, privateKey, counter, trustedKeys) + if err != nil { + return "", err + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + resp, err := c.c.Reauthenticate(ctx, *creds) + switch { + case errors.As(err, &dnapi.ErrInvalidCredentials): + return "", InvalidCredentialsError{} + case err != nil: + return "", fmt.Errorf("reauthenticate error: %s", err) + } + + return resp.LoginURL, nil } func unmarshalHostPrivateKey(b []byte) (keys.PrivateKey, []byte, error) { diff --git a/nebula/go.mod b/nebula/go.mod index b0364d67..671c2f17 100644 --- a/nebula/go.mod +++ b/nebula/go.mod @@ -37,12 +37,13 @@ require ( github.com/vishvananda/netns v0.0.5 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/mod v0.30.0 // indirect - golang.org/x/net v0.47.0 // indirect + golang.org/x/mobile v0.0.0-20251209145715-2553ed8ce294 // indirect + golang.org/x/mod v0.31.0 // indirect + golang.org/x/net v0.48.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.39.0 // indirect golang.org/x/term v0.38.0 // indirect - golang.org/x/tools v0.39.0 // indirect + golang.org/x/tools v0.40.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect diff --git a/nebula/go.sum b/nebula/go.sum index bd1bbc6a..baaafec8 100644 --- a/nebula/go.sum +++ b/nebula/go.sum @@ -1,8 +1,6 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= -github.com/DefinedNet/dnapi v0.0.0-20251117181834-1112b1c2813b h1:qqhoqcZnsLhCsZNFCgAYehtXwX4eyfO/moyMGeD9fpg= -github.com/DefinedNet/dnapi v0.0.0-20251117181834-1112b1c2813b/go.mod h1:N6BTss8f8BEoNdO+rQZJZjIOu3lIbwMgm8B2D2o3fUk= github.com/DefinedNet/dnapi v0.0.0-20251210211559-8ae1e6743199 h1:sYdbeQcXjUyFrlR3KE7rbhPLrSrq0tVAxaHQnfUxaMs= github.com/DefinedNet/dnapi v0.0.0-20251210211559-8ae1e6743199/go.mod h1:vmEciuymyw9SGuI2c7FjNtrp9zSMjux9eFiF8tYPjdc= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= @@ -134,8 +132,6 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/slackhq/nebula v1.10.0 h1:uhu4Cpzw3pXyDJ8G1fMSppsvG7aE9XCt4UaauggHax0= -github.com/slackhq/nebula v1.10.0/go.mod h1:PmYcyoGhAX4X8lCzJjGv7aLTBbFbPy7QeWbpwWvJf+Y= github.com/slackhq/nebula v1.10.1-0.20251210163936-3ec527e42cec h1:F251X4hgG3Fen49ouS7yUVcwYkvvCjb5bmRFAbMnm+c= github.com/slackhq/nebula v1.10.1-0.20251210163936-3ec527e42cec/go.mod h1:mqXWEQjg+I1r5KeCqji83gA0rZPCY9yvP25USUBFGxc= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= @@ -165,16 +161,18 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= -golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mobile v0.0.0-20251209145715-2553ed8ce294 h1:Cr6kbEvA6nqvdHynE4CtVKlqpZB9dS1Jva/6IsHA19g= +golang.org/x/mobile v0.0.0-20251209145715-2553ed8ce294/go.mod h1:RdZ+3sb4CVgpCFnzv+I4haEpwqFfsfzlLHs3L7ok+e0= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= +golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= +golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -187,6 +185,7 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -194,8 +193,6 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -214,13 +211,9 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= -golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -233,6 +226,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= +golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/nebula/site.go b/nebula/site.go index 76591848..eb30e42d 100644 --- a/nebula/site.go +++ b/nebula/site.go @@ -4,6 +4,7 @@ import ( "encoding/json" "time" + "github.com/DefinedNet/dnapi" "github.com/DefinedNet/dnapi/keys" "gopkg.in/yaml.v2" ) @@ -25,6 +26,8 @@ type site struct { Key *string `json:"key"` Managed jsonTrue `json:"managed"` LastManagedUpdate *time.Time `json:"lastManagedUpdate"` + ManagedOIDCEmail *string `json:"managedOIDCEmail"` + ManagedOIDCExpiry *time.Time `json:"managedOIDCExpiry"` RawConfig *string `json:"rawConfig"` DNCredentials *dnCredentials `json:"dnCredentials"` } @@ -54,7 +57,7 @@ func (f jsonTrue) MarshalJSON() ([]byte, error) { return json.Marshal(true) } -func newDNSite(name string, rawCfg []byte, key string, creds keys.Credentials) (*site, error) { +func newDNSite(name string, rawCfg []byte, key string, creds keys.Credentials, configMeta *dnapi.ConfigMeta) (*site, error) { // Convert YAML Nebula config to a JSON Site var cfg config if err := yaml.Unmarshal(rawCfg, &cfg); err != nil { @@ -113,7 +116,7 @@ func newDNSite(name string, rawCfg []byte, key string, creds keys.Credentials) ( return nil, err } - return &site{ + s := &site{ Name: name, ID: creds.HostID, StaticHostmap: shm, @@ -136,5 +139,12 @@ func newDNSite(name string, rawCfg []byte, key string, creds keys.Credentials) ( Counter: int(creds.Counter), TrustedKeys: string(tkm), }, - }, nil + } + + if configMeta != nil && configMeta.EndpointOIDC != nil { + s.ManagedOIDCEmail = &configMeta.EndpointOIDC.Email + s.ManagedOIDCExpiry = configMeta.EndpointOIDC.ExpiresAt + } + + return s, nil } diff --git a/pubspec.lock b/pubspec.lock index a1357440..d1629a81 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -97,14 +97,22 @@ packages: url: "https://pub.dev" source: hosted version: "3.0.4" + duration: + dependency: "direct main" + description: + name: duration + sha256: "13e5d20723c9c1dde8fb318cf86716d10ce294734e81e44ae1a817f3ae714501" + url: "https://pub.dev" + source: hosted + version: "4.0.3" fake_async: dependency: transitive description: name: fake_async - sha256: "6a95e56b2449df2273fd8c45a662d6947ce1ebb7aafe80e550a3f68297f3cacc" + sha256: "5368f224a74523e8d2e7399ea1638b37aecfca824a3cc4dfdf77bf1fa905ac44" url: "https://pub.dev" source: hosted - version: "1.3.2" + version: "1.3.3" ffi: dependency: transitive description: @@ -252,26 +260,26 @@ packages: dependency: transitive description: name: leak_tracker - sha256: c35baad643ba394b40aac41080300150a4f08fd0fd6a10378f8f7c6bc161acec + sha256: "33e2e26bdd85a0112ec15400c8cbffea70d0f9c3407491f672a2fad47915e2de" url: "https://pub.dev" source: hosted - version: "10.0.8" + version: "11.0.2" leak_tracker_flutter_testing: dependency: transitive description: name: leak_tracker_flutter_testing - sha256: f8b613e7e6a13ec79cfdc0e97638fddb3ab848452eff057653abd3edba760573 + sha256: "1dbc140bb5a23c75ea9c4811222756104fbcd1a27173f0c34ca01e16bea473c1" url: "https://pub.dev" source: hosted - version: "3.0.9" + version: "3.0.10" leak_tracker_testing: dependency: transitive description: name: leak_tracker_testing - sha256: "6ba465d5d76e67ddf503e1161d1f4a6bc42306f9d66ca1e8f079a47290fb06d3" + sha256: "8d5a2d49f4a66b49744b23b018848400d23e54caf9463f4eb20df3eb8acb2eb1" url: "https://pub.dev" source: hosted - version: "3.0.1" + version: "3.0.2" lints: dependency: transitive description: @@ -300,10 +308,10 @@ packages: dependency: transitive description: name: meta - sha256: e3641ec5d63ebf0d9b41bd43201a66e3fc79a65db5f61fc181f04cd27aab950c + sha256: "23f08335362185a5ea2ad3a4e597f1375e78bce8a040df5c600c8d3552ef2394" url: "https://pub.dev" source: hosted - version: "1.16.0" + version: "1.17.0" mime: dependency: transitive description: @@ -553,10 +561,10 @@ packages: dependency: transitive description: name: test_api - sha256: fb31f383e2ee25fbbfe06b40fe21e1e458d14080e3c67e7ba0acfde4df4e0bbd + sha256: ab2726c1a94d3176a45960b6234466ec367179b87dd74f1611adb1f3b5fb9d55 url: "https://pub.dev" source: hosted - version: "0.7.4" + version: "0.7.7" typed_data: dependency: transitive description: @@ -665,10 +673,10 @@ packages: dependency: transitive description: name: vector_math - sha256: "80b3257d1492ce4d091729e3a67a60407d227c27241d6927be0130c98e741803" + sha256: d530bd74fea330e6e364cda7a85019c434070188383e1cd8d9777ee586914c5b url: "https://pub.dev" source: hosted - version: "2.1.4" + version: "2.2.0" vm_service: dependency: transitive description: @@ -718,5 +726,5 @@ packages: source: hosted version: "3.1.3" sdks: - dart: ">=3.7.0 <4.0.0" + dart: ">=3.8.0-0 <4.0.0" flutter: ">=3.29.0" diff --git a/pubspec.yaml b/pubspec.yaml index 07ef00a4..17b410cb 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -40,6 +40,7 @@ dependencies: sentry_dart_plugin: ^2.4.1 mobile_scanner: ^7.0.1 path: ^1.9.1 + duration: ^4.0.3 dev_dependencies: flutter_test: