Skip to content

Commit

Permalink
Implement Series.plot.box
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed Jan 14, 2021
1 parent 053d1eb commit d872bb2
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 6 deletions.
88 changes: 82 additions & 6 deletions databricks/koalas/plot/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, Union

import pandas as pd

from databricks.koalas.plot import HistogramPlotBase, name_like_string, KoalasPlotAccessor
from databricks.koalas.plot import (
HistogramPlotBase,
name_like_string,
KoalasPlotAccessor,
BoxPlotBase,
)

if TYPE_CHECKING:
import databricks.koalas as ks


def plot_koalas(data, kind, **kwargs):
def plot_koalas(data: Union["ks.DataFrame", "ks.Series"], kind: str, **kwargs):
import plotly

# Koalas specific plots
if kind == "pie":
return plot_pie(data, **kwargs)
if kind == "hist":
# Note that here data is a Koalas DataFrame or Series unlike other type of plots.
return plot_histogram(data, **kwargs)
if kind == "box":
return plot_box(data, **kwargs)

# Other plots.
return plotly.plot(KoalasPlotAccessor.pandas_plot_data_map[kind](data), kind, **kwargs)


def plot_pie(data, **kwargs):
def plot_pie(data: Union["ks.DataFrame", "ks.Series"], **kwargs):
from plotly import express

data = KoalasPlotAccessor.pandas_plot_data_map["pie"](data)
Expand All @@ -50,13 +61,13 @@ def plot_pie(data, **kwargs):
data,
values=kwargs.pop("values", values),
names=kwargs.pop("names", default_names),
**kwargs
**kwargs,
)
else:
raise RuntimeError("Unexpected type: [%s]" % type(data))


def plot_histogram(data, **kwargs):
def plot_histogram(data: Union["ks.DataFrame", "ks.Series"], **kwargs):
import plotly.graph_objs as go

bins = kwargs.get("bins", 10)
Expand Down Expand Up @@ -92,3 +103,68 @@ def plot_histogram(data, **kwargs):
fig["layout"]["xaxis"]["title"] = "value"
fig["layout"]["yaxis"]["title"] = "count"
return fig


def plot_box(data: Union["ks.DataFrame", "ks.Series"], **kwargs):
import plotly.graph_objs as go

if isinstance(data, pd.DataFrame):
raise RuntimeError(
"plotly does not support a box plot with Koalas DataFrame. Use Series instead."
)

# this isn't actually an argument in plotly. But seems like plotly doesn't expose this
# parameter?
whis = kwargs.pop("whis", 1.5)
# This one is Koalas specific to control precision for approx_percentile
precision = kwargs.pop("precision", 0.01)

# Plotly options
boxpoints = kwargs.pop("boxpoints", "suspectedoutliers")
notched = kwargs.pop("notched", False)
if boxpoints not in ["suspectedoutliers", False]:
raise ValueError(
"plotly plotting backend does not support 'boxpoints' set to '%s'. "
"Set to 'suspectedoutliers' or False." % boxpoints
)
if notched:
raise ValueError(
"plotly plotting backend does not support 'notched' set to '%s'. "
"Set to False." % notched
)

colname = name_like_string(data.name)
spark_column_name = data._internal.spark_column_name_for(data._column_label)

# Computes mean, median, Q1 and Q3 with approx_percentile and precision
col_stats, col_fences = BoxPlotBase.compute_stats(data, spark_column_name, whis, precision)

# Creates a column to flag rows as outliers or not
outliers = BoxPlotBase.outliers(data, spark_column_name, *col_fences)

# Computes min and max values of non-outliers - the whiskers
whiskers = BoxPlotBase.calc_whiskers(spark_column_name, outliers)

fliers = None
if boxpoints:
fliers = BoxPlotBase.get_fliers(spark_column_name, outliers, whiskers[0])

fig = go.Figure()
fig.add_trace(
go.Box(
name=colname,
q1=[col_stats["q1"]],
median=[col_stats["med"]],
q3=[col_stats["q3"]],
mean=[col_stats["mean"]],
lowerfence=[whiskers[0]],
upperfence=[whiskers[1]],
y=[fliers],
boxpoints=boxpoints,
notched=notched,
**kwargs, # this is for workarounds. Box takes different options from express.box.
)
)
fig["layout"]["xaxis"]["title"] = colname
fig["layout"]["yaxis"]["title"] = "value"
return fig
38 changes: 38 additions & 0 deletions databricks/koalas/tests/plot/test_series_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,41 @@ def check_hist_plot(kser):
columns = pd.MultiIndex.from_tuples([("x", "y")])
kdf1.columns = columns
check_hist_plot(kdf1[("x", "y")])

def test_pox_plot(self):
def check_pox_plot(kser):
fig = go.Figure()
fig.add_trace(
go.Box(
name=name_like_string(kser.name),
q1=[3],
median=[6],
q3=[9],
mean=[10.0],
lowerfence=[1],
upperfence=[15],
y=[[50]],
boxpoints="suspectedoutliers",
notched=False,
)
)
fig["layout"]["xaxis"]["title"] = name_like_string(kser.name)
fig["layout"]["yaxis"]["title"] = "value"

self.assertEqual(
pprint.pformat(kser.plot(kind="box").to_dict()), pprint.pformat(fig.to_dict())
)

kdf1 = self.kdf1
check_pox_plot(kdf1["a"])

columns = pd.MultiIndex.from_tuples([("x", "y")])
kdf1.columns = columns
check_pox_plot(kdf1[("x", "y")])

def test_pox_plot_arguments(self):
with self.assertRaisesRegex(ValueError, "does not support"):
self.kdf1.a.plot.box(boxpoints="all")
with self.assertRaisesRegex(ValueError, "does not support"):
self.kdf1.a.plot.box(notched=True)
self.kdf1.a.plot.box(hovertext="abc") # other arguments should not throw an exception

0 comments on commit d872bb2

Please sign in to comment.