wake-up-neo.com

Die Legende zeigt beim Plotten mit Pandas nur ein Etikett

Ich habe zwei Pandas-Datenrahmen, von denen ich hoffe, sie in einer einzelnen Abbildung zu zeichnen. Ich verwende ein IPython-Notebook.

Ich möchte, dass die Legende das Label für beide DataFrames anzeigt, aber bisher konnte ich nur den letzteren zeigen. Wir freuen uns auch über Vorschläge, wie Sie den Code sinnvoller schreiben können. Ich bin neu bei all dem und verstehe objektorientiertes Plotten nicht wirklich.

%pylab inline
import pandas as pd

#creating data

prng = pd.period_range('1/1/2011', '1/1/2012', freq='M')
var=pd.DataFrame(randn(len(prng)),index=prng,columns=['total'])
shares=pd.DataFrame(randn(len(prng)),index=index,columns=['average'])

#plotting

ax=var.total.plot(label='Variance')
ax=shares.average.plot(secondary_y=True,label='Average Age')
ax.left_ax.set_ylabel('Variance of log wages')
ax.right_ax.set_ylabel('Average age')
plt.legend(loc='upper center')
plt.title('Wage Variance and Mean Age')
plt.show()

Legend is missing one of the labels

17
Artturi Björk

Das ist in der Tat etwas verwirrend. Ich denke, es läuft darauf hinaus, wie Matplotlib mit den Sekundärachsen umgeht. Pandas ruft wahrscheinlich ax.twinx() irgendwo auf, was der ersten eine sekundäre Achse überlagert, aber dies ist eigentlich eine separate Achse. Daher auch mit separaten Zeilen & Beschriftungen und separater Legende. Der Aufruf von plt.legend() gilt nur für eine der Achsen (die aktive Achse), die in Ihrem Beispiel die zweite Achse ist.

Pandas speichert zum Glück beide Achsen. Sie können also alle Linienobjekte von beiden Achsen übernehmen und selbst an den Befehl .legend() übergeben. Angenommen Ihre Beispieldaten:

Sie können genau wie Sie zeichnen:

ax = var.total.plot(label='Variance')
ax = shares.average.plot(secondary_y=True, label='Average Age')

ax.set_ylabel('Variance of log wages')
ax.right_ax.set_ylabel('Average age')

Beide Achsenobjekte stehen mit ax (linke Axt) und ax.right_ax zur Verfügung, so dass Sie die Linienobjekte daraus ziehen können. Matplotlibs .get_lines() gibt eine Liste zurück, damit Sie sie durch einfaches Hinzufügen zusammenführen können.

lines = ax.get_lines() + ax.right_ax.get_lines()

Die Linienobjekte verfügen über eine Label-Eigenschaft, mit der das Label gelesen und an den Befehl .legend() übergeben werden kann.

ax.legend(lines, [l.get_label() for l in lines], loc='upper center')

Und der Rest der Verschwörung:

ax.set_title('Wage Variance and Mean Age')
plt.show()

enter image description here

bearbeiten:

Es ist weniger verwirrend, wenn Sie die Teile Pandas (Daten) und Matplotlib (Plotten) strenger voneinander trennen. Vermeiden Sie daher die Verwendung des eingebauten Pandas-Plottings (das Matplotlib sowieso nur umschließt):

fig, ax = plt.subplots()

ax.plot(var.index.to_datetime(), var.total, 'b', label='Variance')
ax.set_ylabel('Variance of log wages')

ax2 = ax.twinx()
ax2.plot(shares.index.to_datetime(), shares.average, 'g' , label='Average Age')
ax2.set_ylabel('Average age')

lines = ax.get_lines() + ax2.get_lines()
ax.legend(lines, [line.get_label() for line in lines], loc='upper center')

ax.set_title('Wage Variance and Mean Age')
plt.show()
32
Rutger Kassies

Wenn mehrere Serien gezeichnet werden, wird die Legende standardmäßig nicht angezeigt.
Benutzerdefinierte Legenden lassen sich einfach anzeigen, indem Sie die Achse der letzten aufgezeichneten Serie/Datenframes verwenden (mein Code aus IPython Notebook ):

%matplotlib inline  # Embed the plot
import matplotlib.pyplot as plt

...
rates[rates.MovieID <= 25].groupby('MovieID').Rating.count().plot()  # blue
(rates[rates.MovieID <= 25].groupby('MovieID').Rating.median() * 1000).plot()  # green
(rates[rates.MovieID <= 25][rates.RateDelta <= 10].groupby('MovieID').Rating.count() * 2000).plot()  # red
ax = (rates[rates.MovieID <= 25][rates.RateDelta <= 10].groupby('MovieID').Rating.median() * 1000).plot()  # cyan

ax.legend(['Popularity', 'RateMedian', 'FirstPpl', 'FirstRM'])

The plot with custom legends

3
luart

Sie können pd.concat verwenden, um die beiden Datenrahmen zusammenzuführen, und dann verwendet der Plot eine sekundäre Y-Achse:

import numpy as np  # For generating random data.
import pandas as pd

# Creating data.
np.random.seed(0)
prng = pd.period_range('1/1/2011', '1/1/2012', freq='M')
var = pd.DataFrame(np.random.randn(len(prng)), index=prng, columns=['total'])
shares = pd.DataFrame(np.random.randn(len(prng)), index=prng, columns=['average'])

# Plotting.
ax = (
    pd.concat([var, shares], axis=1)
    .rename(columns={
        'total': 'Variance of Low Wages',
        'average': 'Average Age'
    })
    .plot(
        title='Wage Variance and Mean Age',
        secondary_y='Average Age')
)
ax.set_ylabel('Variance of Low Wages')
ax.right_ax.set_ylabel('Average Age', rotation=-90)

 chart

0
Alexander