Browse Source

[livy] Strip SparkR out the ANSI control characters before checking marker

This was breaking when running the "should report an error
if accessing an unknown variable" test when the tests were run
on the commandline.
Erick Tryzelaar 10 years ago
parent
commit
2e44a33

+ 48 - 13
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/sparkr/SparkRInterpreter.scala

@@ -18,7 +18,6 @@
 
 
 package com.cloudera.hue.livy.repl.sparkr
 package com.cloudera.hue.livy.repl.sparkr
 
 
-import java.io.File
 import java.nio.file.Files
 import java.nio.file.Files
 import java.util.concurrent.locks.ReentrantLock
 import java.util.concurrent.locks.ReentrantLock
 
 
@@ -151,20 +150,56 @@ private class SparkRInterpreter(process: Process)
 
 
   @tailrec
   @tailrec
   private def readTo(marker: String, output: StringBuilder = StringBuilder.newBuilder): String = {
   private def readTo(marker: String, output: StringBuilder = StringBuilder.newBuilder): String = {
-    val char = stdout.read()
-    if (char == -1) {
+    var char = readChar(output)
+
+    // Remove any ANSI color codes which match the pattern "\u001b\\[[0-9;]*[mG]".
+    // It would be easier to do this with a regex, but unfortunately I don't see an easy way to do
+    // without copying the StringBuilder into a string for each character.
+    if (char == '\u001b') {
+      if (readChar(output) == '[') {
+        char = readDigits(output)
+
+        if (char == 'm' || char == 'G') {
+          output.delete(output.lastIndexOf('\u001b'), output.length)
+        }
+      }
+    }
+
+    if (output.endsWith(marker)) {
+      val result = output.toString()
+      result.substring(0, result.length - marker.length)
+        .stripPrefix("\n")
+        .stripSuffix("\n")
+    } else {
+      readTo(marker, output)
+    }
+  }
+
+  private def readChar(output: StringBuilder): Char = {
+    val byte = stdout.read()
+    if (byte == -1) {
       throw new Exited(output.toString())
       throw new Exited(output.toString())
     } else {
     } else {
-      output.append(char.toChar)
-      if (output.endsWith(marker)) {
-        val result = output.toString()
-        result.substring(0, result.length - marker.length)
-          .replaceAll("\033\\[[0-9;]*[mG]", "") // Remove any ANSI color codes
-          .stripPrefix("\n")
-          .stripSuffix("\n")
-      } else {
-        readTo(marker, output)
-      }
+      val char = byte.toChar
+      output.append(char)
+      char
+    }
+  }
+
+  @tailrec
+  private def readDigits(output: StringBuilder): Char = {
+    val byte = stdout.read()
+    if (byte == -1) {
+      throw new Exited(output.toString())
+    }
+
+    val char = byte.toChar
+
+    if (('0' to '9').contains(char)) {
+      output.append(char)
+      readDigits(output)
+    } else {
+      char
     }
     }
   }
   }