summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorCole Robinson <crobinso@redhat.com>2013-09-06 23:36:09 (GMT)
committerCole Robinson <crobinso@redhat.com>2013-09-06 23:36:09 (GMT)
commit5bf63759b6c18f4b814b52d8ed671fbb428af1f8 (patch)
tree3410d7e2580fd1f4c3b4d72479e34c6de4833e8d
parent67cc81f6b13311269e87200d89cd1c60729afd3d (diff)
downloadvirt-manager-5bf63759b6c18f4b814b52d8ed671fbb428af1f8.zip
virt-manager-5bf63759b6c18f4b814b52d8ed671fbb428af1f8.tar.gz
virt-manager-5bf63759b6c18f4b814b52d8ed671fbb428af1f8.tar.xz
console: Fix issues with spice and askpass (bz 811346)
Spice opens many FDs to handle different channels (display, usb, sound, etc.). For remote SSH URIs, this means we launch multiple SSH proceses. We do so by forking off the process, and when SSH has successfully authenticated, the data starts flowing. If using spice + remote SSH w/o SSH keys, you need to put your data into ssh askpass. askpass wants to own the display for security reasons. When all the channel requests start coming in, we were launching multiple ssh processes one after another. This upset askpass and generally caused havoc in the app. Add some infrastructure to serialize launching ssh processes. We only launch the next ssh process if spice/vnc have conclusively connected or errored out the connection. This makes connection a bit slower for the non-askpass ssh case (about 1.5 seconds), but will ignore avoid this oft reported problem.
-rw-r--r--virtManager/baseclass.py25
-rw-r--r--virtManager/console.py261
2 files changed, 185 insertions, 101 deletions
diff --git a/virtManager/baseclass.py b/virtManager/baseclass.py
index f769650..84c9c79 100644
--- a/virtManager/baseclass.py
+++ b/virtManager/baseclass.py
@@ -36,6 +36,19 @@ from gi.repository import Gtk
class vmmGObject(GObject.GObject):
_leak_check = True
+ @staticmethod
+ def idle_add(func, *args, **kwargs):
+ """
+ Make sure idle functions are run thread safe
+ """
+ def cb():
+ try:
+ return func(*args, **kwargs)
+ except:
+ print traceback.format_exc()
+ return False
+ return GLib.idle_add(cb)
+
def __init__(self):
GObject.GObject.__init__(self)
self.config = config.running_config
@@ -141,18 +154,6 @@ class vmmGObject(GObject.GObject):
self.idle_add(emitwrap, signal, *args)
- def idle_add(self, func, *args, **kwargs):
- """
- Make sure idle functions are run thread safe
- """
- def cb():
- try:
- return func(*args, **kwargs)
- except:
- print traceback.format_exc()
- return False
- return GLib.idle_add(cb)
-
def timeout_add(self, timeout, func, *args):
"""
Make sure timeout functions are run thread safe
diff --git a/virtManager/console.py b/virtManager/console.py
index 5a043b5..c31d35b 100644
--- a/virtManager/console.py
+++ b/virtManager/console.py
@@ -31,10 +31,12 @@ from gi.repository import SpiceClientGLib
import libvirt
+import logging
import os
+import Queue
import signal
import socket
-import logging
+import threading
import virtManager.uihelpers as uihelpers
from virtManager.autodrawer import AutoDrawer
@@ -114,14 +116,105 @@ class ConnectionInfo(object):
return int(self.gport) == -1
-class Tunnel(object):
+class _TunnelScheduler(object):
+ """
+ If the user is using Spice + SSH URI + no SSH keys, we need to
+ serialize connection opening otherwise ssh-askpass gets all angry.
+ This handles the locking and scheduling.
+
+ It's only instantiated once for the whole app, because we serialize
+ independent of connection, vm, etc.
+ """
+ def __init__(self):
+ self._thread = threading.Thread(name="Tunnel thread",
+ target=self._handle_queue,
+ args=())
+ self._thread.daemon = True
+ self._queue = Queue.Queue()
+ self._lock = threading.Lock()
+
+ def _handle_queue(self):
+ while True:
+ cb, args, = self._queue.get()
+ self.lock()
+ vmmGObject.idle_add(cb, *args)
+
+ def schedule(self, cb, *args):
+ if not self._thread.is_alive():
+ self._thread.start()
+ self._queue.put((cb, args))
+
+ def lock(self):
+ self._lock.acquire()
+ def unlock(self):
+ self._lock.release()
+
+_tunnel_sched = _TunnelScheduler()
+
+
+class _Tunnel(object):
def __init__(self):
self.outfd = None
self.errfd = None
self.pid = None
+ self._outfds = None
+ self._errfds = None
+ self.closed = False
def open(self, ginfo):
- if self.outfd is not None:
+ self._outfds = socket.socketpair()
+ self._errfds = socket.socketpair()
+
+ return self._outfds[0].fileno(), self._launch_tunnel, ginfo
+
+ def close(self):
+ if self.closed:
+ return
+ self.closed = True
+
+ logging.debug("Close tunnel PID=%s OUTFD=%s ERRFD=%s",
+ self.pid,
+ self.outfd and self.outfd.fileno() or self._outfds,
+ self.errfd and self.errfd.fileno() or self._errfds)
+
+ if self.outfd:
+ self.outfd.close()
+ elif self._outfds:
+ self._outfds[0].close()
+ self._outfds[1].close()
+ self.outfd = None
+ self._outfds = None
+
+ if self.errfd:
+ self.errfd.close()
+ elif self._errfds:
+ self._errfds[0].close()
+ self._errfds[1].close()
+ self.errfd = None
+ self._errfds = None
+
+ if self.pid:
+ os.kill(self.pid, signal.SIGKILL)
+ os.waitpid(self.pid, 0)
+ self.pid = None
+
+ def get_err_output(self):
+ errout = ""
+ while True:
+ try:
+ new = self.errfd.recv(1024)
+ except:
+ break
+
+ if not new:
+ break
+
+ errout += new
+
+ return errout
+
+ def _launch_tunnel(self, ginfo):
+ if self.closed:
return -1
host, port, ignore = ginfo.get_conn_host()
@@ -168,70 +261,33 @@ class Tunnel(object):
argv_str = reduce(lambda x, y: x + " " + y, argv[1:])
logging.debug("Creating SSH tunnel: %s", argv_str)
- fds = socket.socketpair()
- errorfds = socket.socketpair()
-
pid = os.fork()
if pid == 0:
- fds[0].close()
- errorfds[0].close()
+ self._outfds[0].close()
+ self._errfds[0].close()
os.close(0)
os.close(1)
os.close(2)
- os.dup(fds[1].fileno())
- os.dup(fds[1].fileno())
- os.dup(errorfds[1].fileno())
+ os.dup(self._outfds[1].fileno())
+ os.dup(self._outfds[1].fileno())
+ os.dup(self._errfds[1].fileno())
os.execlp(*argv)
os._exit(1) # pylint: disable=W0212
else:
- fds[1].close()
- errorfds[1].close()
+ self._outfds[1].close()
+ self._errfds[1].close()
- logging.debug("Tunnel PID=%d OUTFD=%d ERRFD=%d",
- pid, fds[0].fileno(), errorfds[0].fileno())
- errorfds[0].setblocking(0)
+ logging.debug("Open tunnel PID=%d OUTFD=%d ERRFD=%d",
+ pid, self._outfds[0].fileno(), self._errfds[0].fileno())
+ self._errfds[0].setblocking(0)
- self.outfd = fds[0]
- self.errfd = errorfds[0]
+ self.outfd = self._outfds[0]
+ self.errfd = self._errfds[0]
+ self._outfds = None
+ self._errfds = None
self.pid = pid
- fd = fds[0].fileno()
- if fd < 0:
- raise SystemError("can't open a new tunnel: fd=%d" % fd)
- return fd
-
- def close(self):
- if self.outfd is None:
- return
-
- logging.debug("Shutting down tunnel PID=%d OUTFD=%d ERRFD=%d",
- self.pid, self.outfd.fileno(),
- self.errfd.fileno())
- self.outfd.close()
- self.outfd = None
- self.errfd.close()
- self.errfd = None
-
- os.kill(self.pid, signal.SIGKILL)
- os.waitpid(self.pid, 0)
- self.pid = None
-
- def get_err_output(self):
- errout = ""
- while True:
- try:
- new = self.errfd.recv(1024)
- except:
- break
-
- if not new:
- break
-
- errout += new
-
- return errout
-
class Tunnels(object):
def __init__(self, ginfo):
@@ -239,9 +295,11 @@ class Tunnels(object):
self._tunnels = []
def open_new(self):
- t = Tunnel()
- fd = t.open(self.ginfo)
+ t = _Tunnel()
+ fd, cb, args = t.open(self.ginfo)
self._tunnels.append(t)
+ _tunnel_sched.schedule(cb, args)
+
return fd
def close_all(self):
@@ -254,6 +312,9 @@ class Tunnels(object):
errout += l.get_err_output()
return errout
+ lock = _tunnel_sched.lock
+ unlock = _tunnel_sched.unlock
+
class Viewer(vmmGObject):
def __init__(self, console):
@@ -275,6 +336,12 @@ class Viewer(vmmGObject):
def get_pixbuf(self):
return self.display.get_pixbuf()
+ def open_ginfo(self, ginfo):
+ if ginfo.need_tunnel():
+ self.open_fd(self.console.tunnels.open_new())
+ else:
+ self.open_host(ginfo)
+
def get_grab_keys(self):
raise NotImplementedError()
@@ -284,10 +351,10 @@ class Viewer(vmmGObject):
def send_keys(self, keys):
raise NotImplementedError()
- def open_host(self, ginfo, password=None):
+ def open_host(self, ginfo):
raise NotImplementedError()
- def open_fd(self, fd, password=None):
+ def open_fd(self, fd):
raise NotImplementedError()
def get_desktop_resolution(self):
@@ -306,6 +373,8 @@ class VNCViewer(Viewer):
# Last noticed desktop resolution
self.desktop_resolution = None
+ self._tunnel_unlocked = False
+
def init_widget(self):
self.set_grab_keys()
@@ -320,18 +389,32 @@ class VNCViewer(Viewer):
self.display.set_pointer_grab(True)
self.display.connect("vnc-pointer-grab", self.console.pointer_grabbed)
- self.display.connect("vnc-pointer-ungrab", self.console.pointer_ungrabbed)
+ self.display.connect("vnc-pointer-ungrab",
+ self.console.pointer_ungrabbed)
self.display.connect("vnc-auth-credential", self._auth_credential)
- self.display.connect("vnc-initialized",
- lambda src: self.console.connected())
- self.display.connect("vnc-disconnected",
- lambda src: self.console.disconnected())
+ self.display.connect("vnc-initialized", self._connected_cb)
+ self.display.connect("vnc-disconnected", self._disconnected_cb)
self.display.connect("vnc-desktop-resize", self._desktop_resize)
- self.display.connect("focus-in-event", self.console.viewer_focus_changed)
- self.display.connect("focus-out-event", self.console.viewer_focus_changed)
+ self.display.connect("focus-in-event",
+ self.console.viewer_focus_changed)
+ self.display.connect("focus-out-event",
+ self.console.viewer_focus_changed)
self.display.show()
+ def _unlock_tunnel(self):
+ if self.console.tunnels and not self._tunnel_unlocked:
+ self.console.tunnels.unlock()
+ self._tunnel_unlocked = True
+
+ def _connected_cb(self, ignore):
+ self._unlock_tunnel()
+ self.console.connected()
+
+ def _disconnected_cb(self, ignore):
+ self._unlock_tunnel()
+ self.console.disconnected()
+
def get_grab_keys(self):
return self.display.get_grab_keys().as_string()
@@ -421,7 +504,7 @@ class VNCViewer(Viewer):
def is_open(self):
return self.display.is_open()
- def open_host(self, ginfo, password=None):
+ def open_host(self, ginfo):
host, port, ignore = ginfo.get_conn_host()
if not ginfo.gsocket:
@@ -444,8 +527,7 @@ class VNCViewer(Viewer):
ginfo.gsocket) + " fd=%s" % fd)
self.open_fd(fd)
- def open_fd(self, fd, password=None):
- ignore = password
+ def open_fd(self, fd):
self.display.open_fd(fd)
def set_credential_username(self, cred):
@@ -469,8 +551,10 @@ class SpiceViewer(Viewer):
self.console.refresh_scaling()
self.display.realize()
- self.display.connect("mouse-grab", lambda src, g: g and self.console.pointer_grabbed(src))
- self.display.connect("mouse-grab", lambda src, g: g or self.console.pointer_ungrabbed(src))
+ self.display.connect("mouse-grab",
+ lambda src, g: g and self.console.pointer_grabbed(src))
+ self.display.connect("mouse-grab",
+ lambda src, g: g or self.console.pointer_ungrabbed(src))
self.display.connect("focus-in-event",
self.console.viewer_focus_changed)
@@ -534,11 +618,19 @@ class SpiceViewer(Viewer):
logging.debug("Spice channel event error: %s", event)
self.console.disconnected()
+ def _fd_channel_event_cb(self, channel, event):
+ # When we see any event from the channel, release the
+ # associated tunnel lock
+ channel.disconnect_by_func(self._fd_channel_event_cb)
+ self.console.tunnels.unlock()
+
def _channel_open_fd_request(self, channel, tls_ignore):
if not self.console.tunnels:
raise SystemError("Got fd request with no configured tunnel!")
logging.debug("Opening tunnel for channel: %s", channel)
+ channel.connect_after("channel-event", self._fd_channel_event_cb)
+
fd = self.console.tunnels.open_new()
channel.open_fd(fd)
@@ -547,6 +639,8 @@ class SpiceViewer(Viewer):
self._channel_open_fd_request)
if type(channel) == SpiceClientGLib.MainChannel:
+ if self.console.tunnels:
+ self.console.tunnels.unlock()
channel.connect_after("channel-event", self._main_channel_event_cb)
return
@@ -584,6 +678,9 @@ class SpiceViewer(Viewer):
gtk_session = SpiceClientGtk.GtkSession.get(self.spice_session)
gtk_session.set_property("auto-clipboard", True)
+ GObject.GObject.connect(self.spice_session, "channel-new",
+ self._channel_new_cb)
+
self.usbdev_manager = SpiceClientGLib.UsbDeviceManager.get(
self.spice_session)
self.usbdev_manager.connect("auto-connect-failed",
@@ -595,26 +692,19 @@ class SpiceViewer(Viewer):
if autoredir:
gtk_session.set_property("auto-usbredir", True)
- def open_host(self, ginfo, password=None):
+ def open_host(self, ginfo):
host, port, tlsport = ginfo.get_conn_host()
-
self._create_spice_session()
+
self.spice_session.set_property("host", str(host))
self.spice_session.set_property("port", str(port))
if tlsport:
self.spice_session.set_property("tls-port", str(tlsport))
- if password:
- self.spice_session.set_property("password", password)
- GObject.GObject.connect(self.spice_session, "channel-new",
- self._channel_new_cb)
+
self.spice_session.connect()
- def open_fd(self, fd, password=None):
+ def open_fd(self, fd):
self._create_spice_session()
- if password:
- self.spice_session.set_property("password", password)
- GObject.GObject.connect(self.spice_session, "channel-new",
- self._channel_new_cb)
self.spice_session.open_fd(fd)
def set_credential_password(self, cred):
@@ -1254,15 +1344,8 @@ class vmmConsolePages(vmmGObjectUI):
self.set_enable_accel()
if ginfo.need_tunnel():
- if self.tunnels:
- # Tunnel already open, no need to continue
- return
-
self.tunnels = Tunnels(ginfo)
- self.viewer.open_fd(self.tunnels.open_new())
- else:
- self.viewer.open_host(ginfo)
-
+ self.viewer.open_ginfo(ginfo)
except Exception, e:
logging.exception("Error connection to graphical console")
self.activate_unavailable_page(